Skip to content

Commit 3260675

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7789d64 commit 3260675

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

examples/fabric/tensor_parallel/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4+
from data import RandomTokenDataset
45
from lightning.fabric.strategies import ModelParallelStrategy
56
from model import ModelArgs, Transformer
67
from parallelism import parallelize
78
from torch.distributed.tensor.parallel import loss_parallel
89
from torch.utils.data import DataLoader
910

10-
from data import RandomTokenDataset
11-
1211

1312
def train():
1413
strategy = ModelParallelStrategy(

examples/pytorch/tensor_parallel/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import lightning as L
22
import torch
33
import torch.nn.functional as F
4+
from data import RandomTokenDataset
45
from lightning.pytorch.strategies import ModelParallelStrategy
56
from model import ModelArgs, Transformer
67
from parallelism import parallelize
78
from torch.distributed.tensor.parallel import loss_parallel
89
from torch.utils.data import DataLoader
910

10-
from data import RandomTokenDataset
11-
1211

1312
class Llama3(L.LightningModule):
1413
def __init__(self):

0 commit comments

Comments
 (0)