-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpreprocess_sc_data.py
124 lines (90 loc) · 3.73 KB
/
preprocess_sc_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# -*- coding: utf-8 -*-
"""preprocess_human_challenge_celltype_data.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1enEpTGtjboogV4792YHd1yELQifxHNWz
"""
import scanpy as sc
import numpy as np
import torch
import argparse
from config import load_config, override_config
from path import dataset_paths
"""## Read the file"""
def _encode_labels(labels):
    """Map each distinct label to an integer id, in order of first appearance.

    Args:
        labels: pandas Series of label names (one entry per cell).

    Returns:
        Tuple ``(name_to_id, ids)``: the ``{name: int}`` mapping and a numpy
        array holding the id of every entry of ``labels``.
    """
    name_to_id = {name: idx for idx, name in enumerate(labels.unique())}
    return name_to_id, labels.map(name_to_id).to_numpy()


def process_h5ad_data(cfg):
    """Read an ``.h5ad`` file, preprocess it, and pack the result into tensors.

    Steps, in order:
      1. read the AnnData object from ``cfg.DATASET.DATAROOT``;
      2. keep a copy of the unprocessed matrix in ``adata.layers["counts"]``;
      3. drop cells with fewer than 200 expressed genes, then genes present
         in fewer than 5 cells;
      4. optionally normalize per-cell totals, then apply ``log1p``;
      5. subset to the 2000 most variable genes (``seurat_v3`` flavor,
         computed on the ``counts`` layer, batched by ``patient_id``);
      6. integer-encode the ``cell_state`` and ``cell_type`` annotations.

    Args:
        cfg: Config object. Reads ``cfg.DATASET.DATAROOT`` (path of the
            ``.h5ad`` file) and ``cfg.PREPROCESS.NORMALIZE_TOTAL`` (falsy to
            skip total-count normalization, otherwise used as ``target_sum``).

    Returns:
        dict with keys:
          * ``'inputs'``: float tensor of the processed expression matrix;
          * ``'fine_labels'``: long tensor of ``cell_state`` ids;
          * ``'coarse_labels'``: long tensor of ``cell_type`` ids;
          * ``'fine_label_2_name'`` / ``'coarse_label_2_name'``: dicts mapping
            label *name -> id* (see the note at the return statement).
    """
    # --- Read the file ---
    adata = sc.read_h5ad(cfg.DATASET.DATAROOT)

    # Save the raw counts. An explicit copy is required: the normalization
    # steps below operate on adata.X in place, and the HVG selection must see
    # the untouched counts.
    adata.layers["counts"] = adata.X.copy()
    print(adata.X.sum(-1)[:2])

    # --- Preprocess the data ---
    # Filter out cells which have fewer than 200 genes expressed.
    sc.pp.filter_cells(adata, min_genes=200)
    print(f"Number of cells after filtering: {adata.n_obs}")
    sc.pp.filter_genes(adata, min_cells=5)
    print(f"Number of genes after filtering: {adata.n_vars}")

    # Normalize the counts (skipped when NORMALIZE_TOTAL is falsy).
    if cfg.PREPROCESS.NORMALIZE_TOTAL:
        sc.pp.normalize_total(adata, target_sum=cfg.PREPROCESS.NORMALIZE_TOTAL)
        print(f"Data normalized to {cfg.PREPROCESS.NORMALIZE_TOTAL}")

    # log1p transform.
    sc.pp.log1p(adata)
    print("Data log1p normalized")

    # Keep the 2k highly variable genes; the patient_id column is the batch key.
    sc.pp.highly_variable_genes(
        adata,
        flavor="seurat_v3",
        n_top_genes=2000,
        layer="counts",
        batch_key="patient_id",
        subset=True,
    )

    # --- Access the data from AnnData ---
    # adata.X is the processed expression matrix. It may be sparse or already
    # dense depending on the input file, so only densify when needed.
    x = adata.X
    X = x.toarray() if hasattr(x, "toarray") else np.asarray(x)

    # Cell labels live in adata.obs, row-aligned with the expression matrix.
    # Two granularities are exported: cell_state (fine) and cell_type (coarse).
    # The coarser cell_compartment annotation is not part of the output, so it
    # is not read here.
    cell_state_2_id, cell_state_labels = _encode_labels(adata.obs['cell_state'])
    cell_type_2_id, cell_type_labels = _encode_labels(adata.obs['cell_type'])

    # NOTE(review): despite the "*_label_2_name" key names, both dicts map
    # name -> integer id. Kept as is so existing consumers of the saved file
    # keep working; invert the dict to look up a name from an id.
    return {
        'inputs': torch.from_numpy(X).float(),
        'fine_labels': torch.from_numpy(cell_state_labels).long(),
        'coarse_labels': torch.from_numpy(cell_type_labels).long(),
        'fine_label_2_name': cell_state_2_id,
        'coarse_label_2_name': cell_type_2_id,
    }
if __name__ == '__main__':
    # Command line: --cfg_file is required; --override_cfg optionally takes
    # one or more values that are handed to override_config.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_file', type=str, required=True)
    parser.add_argument('--override_cfg', type=str, nargs='+', required=False)
    args = parser.parse_args()

    cfg = load_config(args.cfg_file)
    if args.override_cfg:
        cfg = override_config(cfg, args.override_cfg)

    # Resolve the dataset file from its configured name.
    cfg.DATASET.DATAROOT = dataset_paths[cfg.DATASET.NAME]
    cfg.OUTPUT_DIR = '/dev/null'

    data = process_h5ad_data(cfg)

    # Write next to the input as "<stem>_preprocessed.pth". Only a trailing
    # ".h5ad" is stripped, so the output path can never equal the input path
    # (a plain str.replace would be a no-op for any other extension and the
    # save would overwrite the dataset) and directory names are left alone.
    source_path = cfg.DATASET.DATAROOT
    h5ad_suffix = '.h5ad'
    if source_path.endswith(h5ad_suffix):
        stem = source_path[:-len(h5ad_suffix)]
    else:
        stem = source_path
    torch.save(data, stem + '_preprocessed.pth')