Skip to content

Commit 6229c2b

Browse files
committed
update get_pstrata to handle chunks
1 parent 5ee026d commit 6229c2b

File tree

2 files changed

+101
-76
lines changed

2 files changed

+101
-76
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ dependencies = [
5252
"wget>=3.2"
5353
]
5454

55+
[project.optional-dependencies]
56+
gpu = [
57+
"rapids-singlecell 0.11.1"
58+
]
59+
5560
[project.urls]
5661
"Homepage" = "https://github.com/kullrich/oggmap"
5762
"Bug Tracker" = "https://github.com/kullrich/oggmap/issues"

src/oggmap/orthomap2tei.py

Lines changed: 96 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ def get_tei(adata,
560560
:param normalize_total: Normalize counts per cell prior TEI calculation.
561561
:param log1p: Logarithmize the data matrix prior TEI calculation.
562562
:param target_sum: After normalization, each observation (cell) has a total count equal to target_sum.
563-
:param chunk_size: Number of chuns.
563+
:param chunk_size: Number of chunks.
564564
:return: Transcriptome evolutionary index (TEI) values.
565565
566566
:type adata: AnnData
@@ -729,7 +729,7 @@ def get_pmatrix(adata,
729729
:param normalize_total: Normalize counts per cell prior TEI calculation.
730730
:param log1p: Logarithmize the data matrix prior TEI calculation.
731731
:param target_sum: After normalization, each observation (cell) has a total count equal to target_sum.
732-
:param chunk_size: Number of chuns.
732+
:param chunk_size: Number of chunks.
733733
:return: Partial transcriptome evolutionary index (TEI) values.
734734
735735
:type adata: AnnData
@@ -811,7 +811,6 @@ def get_pmatrix(adata,
811811
right=adata.var[kv][adata.var_names.isin(all_id_age_df_keep_subset_chunks[0]['GeneID'])],
812812
left_index=True,
813813
right_index=True)[kv]
814-
print(all_var_names_df_chunks[0])
815814
adata_pmatrix.var['Phylostrata'] = list(pd.merge(left=pd.DataFrame(adata_pmatrix.var_names.values,
816815
columns=['GeneID']),
817816
right=all_var_names_df_chunks[0],
@@ -832,7 +831,8 @@ def get_pstrata(adata,
832831
standard_scale=None,
833832
normalize_total=True,
834833
log1p=True,
835-
target_sum=1e6):
834+
target_sum=1e6,
835+
chunk_size=100000):
836836
"""
837837
This function computes the partial transcriptome evolutionary index (TEI) values combined for each stratum.
838838
@@ -869,6 +869,7 @@ def get_pstrata(adata,
869869
:param normalize_total: Normalize counts per cell prior TEI calculation.
870870
:param log1p: Logarithmize the data matrix prior TEI calculation.
871871
:param target_sum: After normalization, each observation (cell) has a total count equal to target_sum.
872+
:param chunk_size: Number of chunks.
872873
:return: List of two DataFrame. First DataFrame contains the summed partial TEI values per strata.
873874
Second DataFrame contains summed partial TEI values divided by the corresponding global TEI value,
874875
which represent percentage of global TEI per strata.
@@ -886,6 +887,7 @@ def get_pstrata(adata,
886887
:type normalize_total: bool
887888
:type log1p: bool
888889
:type target_sum: float
890+
:type chunk_size: int
889891
:rtype: list
890892
891893
Example
@@ -928,75 +930,83 @@ def get_pstrata(adata,
928930
>>> sns.heatmap(packer19_small_pstrata_grouped[1], annot=True, cmap='viridis')
929931
>>> plt.show()
930932
"""
931-
var_names_df,\
932-
id_age_df_keep_subset,\
933-
adata_counts,\
934-
var_names_subset,\
935-
sumx,\
936-
sumx_recd,\
937-
ps,\
938-
psd = _get_psd(adata=adata,
939-
gene_id=gene_id,
940-
gene_age=gene_age,
941-
keep=keep,
942-
layer=layer,
943-
normalize_total=normalize_total,
944-
log1p=log1p,
945-
target_sum=target_sum)
946-
wmatrix = psd.dot(adata_counts.transpose()).transpose()
947-
pmatrix = sumx_recd.dot(wmatrix)
948-
tei = pmatrix.sum(1)
949-
phylostrata = list(set(id_age_df_keep_subset['Phylostrata']))
950-
pstrata_norm_by_sumx = np.zeros((len(phylostrata), pmatrix.shape[0]))
951-
pstrata_norm_by_pmatrix_sum = np.zeros((len(phylostrata), pmatrix.shape[0]))
952-
for pk_idx, pk in enumerate(phylostrata):
953-
pstrata_norm_by_sumx[pk_idx, ] = np.array(pmatrix[:, id_age_df_keep_subset['Phylostrata'].isin([pk])]
954-
.sum(1)).flatten()
955-
pstrata_norm_by_pmatrix_sum[pk_idx, ] = np.array(pmatrix[:, id_age_df_keep_subset['Phylostrata']
956-
.isin([pk])].sum(1) / tei).flatten()
957-
pstrata_norm_by_sumx_df = pd.DataFrame(pstrata_norm_by_sumx)
958-
pstrata_norm_by_sumx_df['ps'] = phylostrata
959-
pstrata_norm_by_sumx_df.set_index('ps',
960-
inplace=True)
961-
pstrata_norm_by_sumx_df.columns = adata.obs_names
962-
pstrata_norm_by_pmatrix_sum_df = pd.DataFrame(pstrata_norm_by_pmatrix_sum)
963-
pstrata_norm_by_pmatrix_sum_df['ps'] = phylostrata
964-
pstrata_norm_by_pmatrix_sum_df.set_index('ps',
965-
inplace=True)
966-
pstrata_norm_by_pmatrix_sum_df.columns = adata.obs_names
967-
if cumsum:
968-
pstrata_norm_by_sumx_df = pstrata_norm_by_sumx_df.cumsum(0)
969-
pstrata_norm_by_pmatrix_sum_df = pstrata_norm_by_pmatrix_sum_df.cumsum(0)
970-
if group_by_obs is not None:
971-
if adata.obs[group_by_obs].dtype.name == 'category':
972-
obs_group_nan = pd.DataFrame(adata.obs[group_by_obs])
973-
else:
974-
obs_group_nan = pd.DataFrame(adata.obs[group_by_obs].fillna(obs_fillna))
975-
if obs_type == 'mean':
976-
pstrata_norm_by_sumx_df =\
977-
pstrata_norm_by_sumx_df.transpose().groupby(obs_group_nan[group_by_obs]).mean().transpose()
978-
pstrata_norm_by_pmatrix_sum_df =\
979-
pstrata_norm_by_pmatrix_sum_df.transpose().groupby(obs_group_nan[group_by_obs]).mean().transpose()
980-
if obs_type == 'median':
981-
pstrata_norm_by_sumx_df =\
982-
pstrata_norm_by_sumx_df.transpose().groupby(obs_group_nan[group_by_obs]).median().transpose()
983-
pstrata_norm_by_pmatrix_sum_df =\
984-
pstrata_norm_by_pmatrix_sum_df.transpose().groupby(obs_group_nan[group_by_obs]).median().transpose()
985-
if obs_type == 'sum':
986-
pstrata_norm_by_sumx_df =\
987-
pstrata_norm_by_sumx_df.transpose().groupby(obs_group_nan[group_by_obs]).sum().transpose()
988-
pstrata_norm_by_pmatrix_sum_df =\
989-
pstrata_norm_by_pmatrix_sum_df.transpose().groupby(obs_group_nan[group_by_obs]).sum().transpose()
990-
if obs_type == 'min':
991-
pstrata_norm_by_sumx_df =\
992-
pstrata_norm_by_sumx_df.transpose().groupby(obs_group_nan[group_by_obs]).min().transpose()
993-
pstrata_norm_by_pmatrix_sum_df =\
994-
pstrata_norm_by_pmatrix_sum_df.transpose().groupby(obs_group_nan[group_by_obs]).min().transpose()
995-
if obs_type == 'max':
996-
pstrata_norm_by_sumx_df =\
997-
pstrata_norm_by_sumx_df.transpose().groupby(obs_group_nan[group_by_obs]).max().transpose()
998-
pstrata_norm_by_pmatrix_sum_df =\
999-
pstrata_norm_by_pmatrix_sum_df.transpose().groupby(obs_group_nan[group_by_obs]).max().transpose()
933+
adata_pstrata_norm_by_sumx_df_chunks = []
934+
adata_pstrata_norm_by_pmatrix_sum_df_chunks = []
935+
for i in range(0, adata.shape[0], chunk_size):
936+
adata_subset = adata[i:i+chunk_size]
937+
var_names_df_chunk,\
938+
id_age_df_keep_subset_chunk,\
939+
adata_counts_chunk,\
940+
var_names_subset_chunk,\
941+
sumx_chunk,\
942+
sumx_recd_chunk,\
943+
ps_chunk,\
944+
psd_chunk = _get_psd(adata=adata_subset,
945+
gene_id=gene_id,
946+
gene_age=gene_age,
947+
keep=keep,
948+
layer=layer,
949+
normalize_total=normalize_total,
950+
log1p=log1p,
951+
target_sum=target_sum)
952+
wmatrix_chunk = psd_chunk.dot(adata_counts_chunk.transpose()).transpose()
953+
pmatrix_chunk = sumx_recd_chunk.dot(wmatrix_chunk)
954+
tei_chunk = pmatrix_chunk.sum(1)
955+
phylostrata_chunk = list(set(id_age_df_keep_subset_chunk['Phylostrata']))
956+
pstrata_norm_by_sumx_chunk = np.zeros((len(phylostrata_chunk), pmatrix_chunk.shape[0]))
957+
pstrata_norm_by_pmatrix_sum_chunk = np.zeros((len(phylostrata_chunk), pmatrix_chunk.shape[0]))
958+
for pk_idx, pk in enumerate(phylostrata_chunk):
959+
pstrata_norm_by_sumx_chunk[pk_idx, ] = np.array(pmatrix_chunk[:, id_age_df_keep_subset_chunk['Phylostrata'].isin([pk]).values]
960+
.sum(1)).flatten()
961+
pstrata_norm_by_pmatrix_sum_chunk[pk_idx, ] = np.array(pmatrix_chunk[:, id_age_df_keep_subset_chunk['Phylostrata']
962+
.isin([pk]).values].sum(1) / tei_chunk).flatten()
963+
pstrata_norm_by_sumx_df_chunk = pd.DataFrame(pstrata_norm_by_sumx_chunk)
964+
pstrata_norm_by_sumx_df_chunk['ps'] = phylostrata_chunk
965+
pstrata_norm_by_sumx_df_chunk.set_index('ps',
966+
inplace=True)
967+
pstrata_norm_by_sumx_df_chunk.columns = adata_subset.obs_names
968+
pstrata_norm_by_pmatrix_sum_df_chunk = pd.DataFrame(pstrata_norm_by_pmatrix_sum_chunk)
969+
pstrata_norm_by_pmatrix_sum_df_chunk['ps'] = phylostrata_chunk
970+
pstrata_norm_by_pmatrix_sum_df_chunk.set_index('ps',
971+
inplace=True)
972+
pstrata_norm_by_pmatrix_sum_df_chunk.columns = adata_subset.obs_names
973+
if cumsum:
974+
pstrata_norm_by_sumx_df_chunk = pstrata_norm_by_sumx_df_chunk.cumsum(0)
975+
pstrata_norm_by_pmatrix_sum_df_chunk = pstrata_norm_by_pmatrix_sum_df_chunk.cumsum(0)
976+
if group_by_obs is not None:
977+
if adata_subset.obs[group_by_obs].dtype.name == 'category':
978+
obs_group_nan = pd.DataFrame(adata_subset.obs[group_by_obs])
979+
else:
980+
obs_group_nan = pd.DataFrame(adata_subset.obs[group_by_obs].fillna(obs_fillna))
981+
if obs_type == 'mean':
982+
pstrata_norm_by_sumx_df_chunk =\
983+
pstrata_norm_by_sumx_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).mean().transpose()
984+
pstrata_norm_by_pmatrix_sum_df_chunk =\
985+
pstrata_norm_by_pmatrix_sum_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).mean().transpose()
986+
if obs_type == 'median':
987+
pstrata_norm_by_sumx_df_chunk =\
988+
pstrata_norm_by_sumx_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).median().transpose()
989+
pstrata_norm_by_pmatrix_sum_df_chunk =\
990+
pstrata_norm_by_pmatrix_sum_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).median().transpose()
991+
if obs_type == 'sum':
992+
pstrata_norm_by_sumx_df_chunk =\
993+
pstrata_norm_by_sumx_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).sum().transpose()
994+
pstrata_norm_by_pmatrix_sum_df_chunk =\
995+
pstrata_norm_by_pmatrix_sum_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).sum().transpose()
996+
if obs_type == 'min':
997+
pstrata_norm_by_sumx_df_chunk =\
998+
pstrata_norm_by_sumx_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).min().transpose()
999+
pstrata_norm_by_pmatrix_sum_df_chunk =\
1000+
pstrata_norm_by_pmatrix_sum_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).min().transpose()
1001+
if obs_type == 'max':
1002+
pstrata_norm_by_sumx_df_chunk =\
1003+
pstrata_norm_by_sumx_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).max().transpose()
1004+
pstrata_norm_by_pmatrix_sum_df_chunk =\
1005+
pstrata_norm_by_pmatrix_sum_df_chunk.transpose().groupby(obs_group_nan[group_by_obs]).max().transpose()
1006+
adata_pstrata_norm_by_sumx_df_chunks.append(pstrata_norm_by_sumx_df_chunk)
1007+
adata_pstrata_norm_by_pmatrix_sum_df_chunks.append(pstrata_norm_by_pmatrix_sum_df_chunk)
1008+
pstrata_norm_by_sumx_df = pd.concat(adata_pstrata_norm_by_sumx_df_chunks, axis=1)
1009+
pstrata_norm_by_pmatrix_sum_df = pd.concat(adata_pstrata_norm_by_pmatrix_sum_df_chunks, axis=1)
10001010
if standard_scale is not None:
10011011
if standard_scale == 0:
10021012
pstrata_norm_by_sumx_df = pstrata_norm_by_sumx_df.apply(_min_max_to_01,
@@ -1054,7 +1064,8 @@ def get_ematrix(adata,
10541064
standard_scale=None,
10551065
normalize_total=True,
10561066
log1p=True,
1057-
target_sum=1e6):
1067+
target_sum=1e6,
1068+
chunk_size=100000):
10581069
"""
10591070
This function computes expression profiles for all genes or group of genes 'group_by_var' (default: None).
10601071
@@ -1107,6 +1118,7 @@ def get_ematrix(adata,
11071118
:param normalize_total: Normalize counts per cell.
11081119
:param log1p: Logarithmize the data matrix.
11091120
:param target_sum: After normalization, each observation (cell) has a total count equal to target_sum.
1121+
:param chunk_size: Number of chunks.
11101122
:return: Expression profile DataFrame.
11111123
11121124
:type adata: AnnData
@@ -1121,6 +1133,7 @@ def get_ematrix(adata,
11211133
:type normalize_total: bool
11221134
:type log1p: bool
11231135
:type target_sum: float
1136+
:type chunk_size: int
11241137
:rtype: pandas.DataFrame
11251138
11261139
Example
@@ -1312,7 +1325,8 @@ def get_rematrix(adata,
13121325
standard_scale=None,
13131326
normalize_total=True,
13141327
log1p=True,
1315-
target_sum=1e6):
1328+
target_sum=1e6,
1329+
chunk_size=100000):
13161330
"""
13171331
This function computes relative expression profiles.
13181332
@@ -1364,6 +1378,7 @@ def get_rematrix(adata,
13641378
:param normalize_total: Normalize counts per cell prior TEI calculation.
13651379
:param log1p: Logarithmize the data matrix prior TEI calculation.
13661380
:param target_sum: After normalization, each observation (cell) has a total count equal to target_sum.
1381+
:param chunk_size: Number of chunks.
13671382
:return: Relative expression profile DataFrame.
13681383
13691384
:type adata: AnnData
@@ -1380,6 +1395,7 @@ def get_rematrix(adata,
13801395
:type normalize_total: bool
13811396
:type log1p: bool
13821397
:type target_sum: float
1398+
:type chunk_size: int
13831399
:rtype: pandas.DataFrame
13841400
13851401
Example
@@ -1695,7 +1711,8 @@ def mergeby_from_counts(adata,
16951711
max_expr=None,
16961712
normalize_total=False,
16971713
log1p=False,
1698-
target_sum=1e6):
1714+
target_sum=1e6,
1715+
chunk_size=100000):
16991716
"""
17001717
This function groups all counts of an existing AnnData object as an array based on variable or observation groups.
17011718
The resulting pandas.DataFrame can be used to e.g. apply statistics or visualize the groups more easily.
@@ -1712,6 +1729,7 @@ def mergeby_from_counts(adata,
17121729
:param normalize_total: Normalize counts per cell.
17131730
:param log1p: Logarithmize the data matrix.
17141731
:param target_sum: After normalization, each observation (cell) has a total count equal to target_sum.
1732+
:param chunk_size: Number of chunks.
17151733
:return: List of three DataFrame. First DataFrame contains the grouped data (each cell contains a numpy.ndarray).
17161734
Second DataFrame contains original variable and observation assignment and groupings.
17171735
@@ -1727,6 +1745,7 @@ def mergeby_from_counts(adata,
17271745
:type normalize_total: bool
17281746
:type log1p: bool
17291747
:type target_sum: float
1748+
:type chunk_size: int
17301749
:rtype: list
17311750
17321751
Example
@@ -1926,6 +1945,7 @@ def get_e50(adata,
19261945
:param normalize_total:
19271946
:param log1p:
19281947
:param target_sum:
1948+
:param chunk_size: Number of chunks.
19291949
:param min_expr:
19301950
:param max_expr:
19311951
:return:

0 commit comments

Comments
 (0)