Skip to content

Commit 6cb5813

Browse files
authored
Fix the number of nodes not defined properly (#19482)
1 parent b28b673 commit 6cb5813

File tree

1 file changed

+6
-8
lines changed
  • src/lightning/data/utilities

1 file changed

+6
-8
lines changed

src/lightning/data/utilities/env.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,19 @@ def detect(cls) -> "_DistributedEnv":
3131
if torch.distributed.is_available() and torch.distributed.is_initialized():
3232
world_size = torch.distributed.get_world_size()
3333
global_rank = torch.distributed.get_rank()
34+
# Note: On multi node CPU, the number of nodes won't be correct.
35+
num_nodes = world_size // torch.cuda.device_count() if torch.cuda.is_available() else world_size
36+
if torch.cuda.is_available() and world_size % torch.cuda.device_count() != 0:
37+
raise RuntimeError("The world size should be divisible by the number of GPUs.")
3438
else:
3539
world_size = None
3640
global_rank = 0
41+
num_nodes = 1
3742

3843
if world_size is None or world_size == -1:
3944
world_size = 1
4045

41-
# TODO: Add support for other accelerators
42-
num_nodes = (world_size // torch.cuda.device_count()) if torch.cuda.is_available() else 1
43-
44-
if num_nodes > 1:
45-
# validate the world size is divisble by the number of GPUs
46-
assert world_size % torch.cuda.device_count() == 0
47-
48-
return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes))
46+
return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
4947

5048
def __repr__(self) -> str:
5149
return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"

0 commit comments

Comments
 (0)