Skip to content

Commit e2257a0

Browse files
bkmartinjrryan-williams
authored andcommitted
fix RNG state bug in shuffle; add multi-worker notebook
1 parent 4a08504 commit e2257a0

File tree

4 files changed

+464
-144
lines changed

4 files changed

+464
-144
lines changed

notebooks/tutorial_lightning.ipynb

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,25 @@
2626
},
2727
{
2828
"cell_type": "code",
29-
"execution_count": 4,
29+
"execution_count": 1,
3030
"metadata": {},
31-
"outputs": [],
31+
"outputs": [
32+
{
33+
"name": "stderr",
34+
"output_type": "stream",
35+
"text": [
36+
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
37+
"################################################################################\n",
38+
"WARNING!\n",
39+
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
40+
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
41+
"to learn more and leave feedback.\n",
42+
"################################################################################\n",
43+
"\n",
44+
" deprecation_warning()\n"
45+
]
46+
}
47+
],
3248
"source": [
3349
"import pytorch_lightning as pl\n",
3450
"import torch\n",
@@ -58,7 +74,6 @@
5874
" obs_column_names=[\"cell_type\"],\n",
5975
" batch_size=128,\n",
6076
" shuffle=True,\n",
61-
" seed=12345,\n",
6277
" )"
6378
]
6479
},
@@ -71,7 +86,7 @@
7186
},
7287
{
7388
"cell_type": "code",
74-
"execution_count": 5,
89+
"execution_count": 2,
7590
"metadata": {},
7691
"outputs": [],
7792
"source": [
@@ -100,7 +115,6 @@
100115
" predictions = torch.argmax(probabilities, axis=1)\n",
101116
"\n",
102117
" # Compute loss\n",
103-
" # y_batch = y_batch.flatten()\n",
104118
" y_batch = torch.from_numpy(\n",
105119
" self.cell_type_encoder.transform(y_batch[\"cell_type\"])\n",
106120
" ).to(self.device)\n",
@@ -130,7 +144,7 @@
130144
},
131145
{
132146
"cell_type": "code",
133-
"execution_count": 6,
147+
"execution_count": 3,
134148
"metadata": {},
135149
"outputs": [
136150
{
@@ -140,6 +154,7 @@
140154
"GPU available: True (cuda), used: True\n",
141155
"TPU available: False, using: 0 TPU cores\n",
142156
"HPU available: False, using: 0 HPUs\n",
157+
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
143158
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
144159
"\n",
145160
" | Name | Type | Params | Mode \n",
@@ -153,32 +168,15 @@
153168
"2.905 Total estimated model params size (MB)\n",
154169
"2 Modules in train mode\n",
155170
"0 Modules in eval mode\n",
156-
"/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n",
157-
"/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
158-
"################################################################################\n",
159-
"WARNING!\n",
160-
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
161-
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
162-
"to learn more and leave feedback.\n",
163-
"################################################################################\n",
164-
"\n",
165-
" deprecation_warning()\n",
166-
"/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
167-
"################################################################################\n",
168-
"WARNING!\n",
169-
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
170-
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
171-
"to learn more and leave feedback.\n",
172-
"################################################################################\n",
173-
"\n",
174-
" deprecation_warning()\n"
171+
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
172+
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n"
175173
]
176174
},
177175
{
178176
"name": "stdout",
179177
"output_type": "stream",
180178
"text": [
181-
"Epoch 19: 100%|██████████| 118/118 [00:17<00:00, 6.87it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]"
179+
"Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]"
182180
]
183181
},
184182
{
@@ -192,14 +190,12 @@
192190
"name": "stdout",
193191
"output_type": "stream",
194192
"text": [
195-
"Epoch 19: 100%|██████████| 118/118 [00:17<00:00, 6.86it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]\n"
193+
"Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n"
196194
]
197195
}
198196
],
199197
"source": [
200-
"dataloader = soma_ml.experiment_dataloader(\n",
201-
" experiment_dataset, num_workers=2, persistent_workers=True\n",
202-
")\n",
198+
"dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n",
203199
"\n",
204200
"# The size of the input dimension is the number of genes\n",
205201
"input_dim = experiment_dataset.shape[1]\n",
@@ -213,11 +209,7 @@
213209
")\n",
214210
"\n",
215211
"# Define the PyTorch Lightning Trainer\n",
216-
"trainer = pl.Trainer(\n",
217-
" max_epochs=20,\n",
218-
" # accelerator=args.accelerator,\n",
219-
" # strategy=\"ddp\",\n",
220-
")\n",
212+
"trainer = pl.Trainer(max_epochs=20)\n",
221213
"\n",
222214
"# set precision\n",
223215
"torch.set_float32_matmul_precision(\"high\")\n",

notebooks/tutorial_multiworker.ipynb

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Multi-process training\n",
8+
"\n",
9+
"Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` includes both:\n",
10+
"* using the `torch.utils.data.DataLoader` with one or more workers (i.e., with an argument of `num_workers=1` or greater)\n",
11+
"* using a multi-process training configuration, such as `DistributedDataParallel`\n",
12+
"\n",
13+
"In these configurations, `ExperimentAxisQueryIterableDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are two things to keep in mind:\n",
14+
"\n",
15+
"1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n",
16+
"2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. This is identical to the behavior of `torch.utils.data.distributed.DistributedSampler`.\n",
17+
"\n",
18+
"\n"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": 1,
24+
"metadata": {},
25+
"outputs": [
26+
{
27+
"name": "stderr",
28+
"output_type": "stream",
29+
"text": [
30+
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
31+
"################################################################################\n",
32+
"WARNING!\n",
33+
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
34+
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
35+
"to learn more and leave feedback.\n",
36+
"################################################################################\n",
37+
"\n",
38+
" deprecation_warning()\n"
39+
]
40+
}
41+
],
42+
"source": [
43+
"import tiledbsoma_ml as soma_ml\n",
44+
"import torch\n",
45+
"from sklearn.preprocessing import LabelEncoder\n",
46+
"\n",
47+
"import tiledbsoma as soma\n",
48+
"\n",
49+
"CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
50+
"\n",
51+
"experiment = soma.open(\n",
52+
" CZI_Census_Homo_Sapiens_URL,\n",
53+
" context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n",
54+
")\n",
55+
"obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
56+
"\n",
57+
"with experiment.axis_query(\n",
58+
" measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
59+
") as query:\n",
60+
" obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
61+
" cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
62+
"\n",
63+
" experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n",
64+
" query,\n",
65+
" X_name=\"raw\",\n",
66+
" obs_column_names=[\"cell_type\"],\n",
67+
" batch_size=128,\n",
68+
" shuffle=True,\n",
69+
" )\n",
70+
" "
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 2,
76+
"metadata": {},
77+
"outputs": [],
78+
"source": [
79+
"\n",
80+
"class LogisticRegression(torch.nn.Module):\n",
81+
" def __init__(self, input_dim, output_dim):\n",
82+
" super(LogisticRegression, self).__init__() # noqa: UP008\n",
83+
" self.linear = torch.nn.Linear(input_dim, output_dim)\n",
84+
"\n",
85+
" def forward(self, x):\n",
86+
" outputs = torch.sigmoid(self.linear(x))\n",
87+
" return outputs\n",
88+
" \n",
89+
"\n",
90+
"def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n",
91+
" model.train()\n",
92+
" train_loss = 0\n",
93+
" train_correct = 0\n",
94+
" train_total = 0\n",
95+
"\n",
96+
" for X_batch, y_batch in train_dataloader:\n",
97+
" optimizer.zero_grad()\n",
98+
"\n",
99+
" X_batch = torch.from_numpy(X_batch).float().to(device)\n",
100+
"\n",
101+
" # Perform prediction\n",
102+
" outputs = model(X_batch)\n",
103+
"\n",
104+
" # Determine the predicted label\n",
105+
" probabilities = torch.nn.functional.softmax(outputs, 1)\n",
106+
" predictions = torch.argmax(probabilities, axis=1)\n",
107+
"\n",
108+
" # Compute the loss and perform back propagation\n",
109+
" y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n",
110+
" train_correct += (predictions == y_batch).sum().item()\n",
111+
" train_total += len(predictions)\n",
112+
"\n",
113+
" loss = loss_fn(outputs, y_batch.long())\n",
114+
" train_loss += loss.item()\n",
115+
" loss.backward()\n",
116+
" optimizer.step()\n",
117+
"\n",
118+
" train_loss /= train_total\n",
119+
" train_accuracy = train_correct / train_total\n",
120+
" return train_loss, train_accuracy"
121+
]
122+
},
123+
{
124+
"cell_type": "markdown",
125+
"metadata": {},
126+
"source": [
127+
"## Multi-worker DataLoader\n",
128+
"\n",
129+
"If you use a multi-worker data loader (i.e., `num_workers` with a value other than `0`), and `shuffle=True`, remember to call `set_epoch` at the start of each epoch, _before_ the iterator is created.\n",
130+
"\n",
131+
"The same approach should be taken for parallel training, e.g., when using DDP or DP.\n",
132+
"\n",
133+
"*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentAxisQueryIterableDataset` will automatically increment the epoch count each time the iterator completes."
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": 3,
139+
"metadata": {},
140+
"outputs": [
141+
{
142+
"name": "stderr",
143+
"output_type": "stream",
144+
"text": [
145+
"switching torch multiprocessing start method from \"fork\" to \"spawn\"\n",
146+
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
147+
"################################################################################\n",
148+
"WARNING!\n",
149+
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
150+
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
151+
"to learn more and leave feedback.\n",
152+
"################################################################################\n",
153+
"\n",
154+
" deprecation_warning()\n",
155+
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
156+
"################################################################################\n",
157+
"WARNING!\n",
158+
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
159+
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
160+
"to learn more and leave feedback.\n",
161+
"################################################################################\n",
162+
"\n",
163+
" deprecation_warning()\n"
164+
]
165+
},
166+
{
167+
"name": "stdout",
168+
"output_type": "stream",
169+
"text": [
170+
"Epoch 1: Train Loss: 0.0169229 Accuracy 0.3124\n",
171+
"Epoch 2: Train Loss: 0.0148674 Accuracy 0.4272\n",
172+
"Epoch 3: Train Loss: 0.0144468 Accuracy 0.4509\n",
173+
"Epoch 4: Train Loss: 0.0141778 Accuracy 0.4999\n",
174+
"Epoch 5: Train Loss: 0.0139660 Accuracy 0.5619\n",
175+
"Epoch 6: Train Loss: 0.0137670 Accuracy 0.6971\n",
176+
"Epoch 7: Train Loss: 0.0136089 Accuracy 0.8670\n",
177+
"Epoch 8: Train Loss: 0.0135203 Accuracy 0.9099\n",
178+
"Epoch 9: Train Loss: 0.0134427 Accuracy 0.9262\n",
179+
"Epoch 10: Train Loss: 0.0133607 Accuracy 0.9300\n",
180+
"Epoch 11: Train Loss: 0.0133110 Accuracy 0.9348\n",
181+
"Epoch 12: Train Loss: 0.0132749 Accuracy 0.9378\n",
182+
"Epoch 13: Train Loss: 0.0132431 Accuracy 0.9413\n",
183+
"Epoch 14: Train Loss: 0.0132194 Accuracy 0.9444\n",
184+
"Epoch 15: Train Loss: 0.0131942 Accuracy 0.9465\n",
185+
"Epoch 16: Train Loss: 0.0131739 Accuracy 0.9499\n",
186+
"Epoch 17: Train Loss: 0.0131527 Accuracy 0.9526\n",
187+
"Epoch 18: Train Loss: 0.0131369 Accuracy 0.9551\n",
188+
"Epoch 19: Train Loss: 0.0131214 Accuracy 0.9563\n",
189+
"Epoch 20: Train Loss: 0.0131061 Accuracy 0.9578\n"
190+
]
191+
}
192+
],
193+
"source": [
194+
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
195+
"\n",
196+
"# The size of the input dimension is the number of genes\n",
197+
"input_dim = experiment_dataset.shape[1]\n",
198+
"\n",
199+
"# The size of the output dimension is the number of distinct cell_type values\n",
200+
"output_dim = len(cell_type_encoder.classes_)\n",
201+
"\n",
202+
"model = LogisticRegression(input_dim, output_dim).to(device)\n",
203+
"loss_fn = torch.nn.CrossEntropyLoss()\n",
204+
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n",
205+
"\n",
206+
"\n",
207+
"# define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n",
208+
"# that a different shuffle is applied on each epoch.\n",
209+
"experiment_dataloader = soma_ml.experiment_dataloader(\n",
210+
" experiment_dataset, num_workers=2, persistent_workers=True\n",
211+
")\n",
212+
"\n",
213+
"for epoch in range(20):\n",
214+
" experiment_dataset.set_epoch(epoch)\n",
215+
" train_loss, train_accuracy = train_epoch(\n",
216+
" model, experiment_dataloader, loss_fn, optimizer, device\n",
217+
" )\n",
218+
" print(\n",
219+
" f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\"\n",
220+
" )"
221+
]
222+
}
223+
],
224+
"metadata": {
225+
"kernelspec": {
226+
"display_name": "toymodel",
227+
"language": "python",
228+
"name": "python3"
229+
},
230+
"language_info": {
231+
"codemirror_mode": {
232+
"name": "ipython",
233+
"version": 3
234+
},
235+
"file_extension": ".py",
236+
"mimetype": "text/x-python",
237+
"name": "python",
238+
"nbconvert_exporter": "python",
239+
"pygments_lexer": "ipython3",
240+
"version": "3.11.9"
241+
}
242+
},
243+
"nbformat": 4,
244+
"nbformat_minor": 2
245+
}

0 commit comments

Comments
 (0)