Commit 71b78dc: Clean up and comment out

daviswer committed Feb 25, 2025
1 parent 0b09fd4

Showing 2 changed files with 91 additions and 82 deletions.
21 changes: 6 additions & 15 deletions examples/ibm_rescaling/rescaling_demo.py
@@ -15,15 +15,15 @@
save_distributed_state_dict,
)

# This example script validates the rescaling behavior of the ibm rescalable distributed datasets.
# This example script validates the rescaling behavior of the ScalableReader.
# On first run, creates a dummy dataset and saves a distributed checkpoint at the desired location.
# On subsequent runs, loads the checkpoint (possibly on a different world size / num workers)
# and verifies that all remaining data is covered by the time the epoch finishes.

# Example usage:
# torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6

# Do not change the batch size or number of steps between the first and second runs!
# Do not change the number of steps between the first and second runs!
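# As an illustration (these exact invocations are not part of the original comments), a rescaling
# test might first run on 4 ranks and then resume on 6 ranks with a different number of workers,
# keeping --ckpt_path and --logical_shards fixed:
# torchrun --nproc-per-node=4 examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6
# torchrun --nproc-per-node=6 examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=3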

parser = argparse.ArgumentParser(description="Script to validate rescaling of dataloader checkpoints")
parser.add_argument("--ckpt_path", type=str, default="./rescale_test")
@@ -85,10 +85,6 @@
# Wrap in StatefulDataLoader
data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers)

# TODO: debug: can't change n_workers when reloading - keyerror
# TODO: debug: going from 4/2/2 gpu/bsize/workers to 6/1/2 causes epoch not to finish


# If checkpoint does not exist, create it
ckpt_path = os.path.join(args.ckpt_path, "loader_dcp_state")
if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0:
@@ -125,30 +121,25 @@
if rank == 0:
print("Checkpoint detected!")
load_distributed_state_dict(data, ckpt_path, mesh)
# print("FINAL")
# time.sleep(10000)
avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")).tolist()

# Finish out epoch (extra 2*ceil(n_items/n_shards) steps to account for worst-case uneven finishing times)
# Finish out epoch (extra 2*ceil(ceil(n_items/n_shards)/bsize) steps to account for worst-case uneven finishing times)
vals = []
n_steps = (
math.ceil((3000 - len(avoid)) / (world_size * args.b_size))
+ 2 * math.ceil(3000/args.logical_shards)
+ 2 * math.ceil(math.ceil(3000/args.logical_shards)/args.b_size)
)
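# Worked example of the step count above (illustrative numbers, not taken from the script):
# with 3000 items, world_size=4, b_size=2, logical_shards=48, and 600 items already seen
# before the checkpoint, this gives
#   ceil(2400 / (4*2)) + 2*ceil(ceil(3000/48)/2) = 300 + 2*32 = 364 steps.
# The extra 2*ceil(ceil(n_items/n_shards)/bsize) steps cover the worst case where one rank
# still has most of a logical shard left to read after the others have finished.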
for i, inp in enumerate(data):
vals.append(inp)
if i == n_steps:
break
vals.append(inp)
vals = torch.cat(vals)
# Get all vals onto each rank
vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor()

# Diag save
# Save final state dicts for diagnostic purposes
os.makedirs(os.path.join(args.ckpt_path, "diag"), exist_ok=True)
torch.save(data.state_dict(), os.path.join(args.ckpt_path, "diag", f"loader_state_{rank}.pth"))
# if rank == 0:
# torch.save(vals, os.path.join(args.ckpt_path, "diag", "vals.pth"))
# time.sleep(10)

# Perform data coverage check on rank 0 only
if rank == 0:
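The body of the rank-0 check is collapsed in this view. A minimal sketch of such a coverage check,
assuming the dummy dataset enumerates the values 0 through 2999 and `avoid` holds the values consumed
before the checkpoint, might look like the following (hypothetical code, not the collapsed original):

seen = set(avoid) | set(vals.flatten().tolist())
missing = set(range(3000)) - seen
assert len(missing) == 0, f"{len(missing)} values were never yielded: {sorted(missing)[:10]}"
print("Coverage check passed: all 3000 values appeared across the two runs.")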
152 changes: 85 additions & 67 deletions torchdata/stateful_dataloader/ibm_rescalable.py
@@ -15,37 +15,40 @@
from .stateful_dataloader import StatefulDataLoader

"""
TODO: UPDATE THIS FOR SCALABLEREADER
The following distributed dataloaders are designed around 3 main principles:
1. Efficient, asynchronous operation. Workers on different devices do not communicate.
2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator
loading from disk and additional layers adding levels of post-processing (shuffling,
packing, padding, rescaling, etc.).
3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal
state that can be written/read on disk via implemented recursive `state_dict()` and
`load_state_dict()` calls. Any values that should be saved to state can be designated
'state_params' and will be automatically included in the state dict. States must be
valid targets of torch.tensor().
4. Rescalability. Users can save and load checkpoints to/from different numbers of workers
without losing the global state. This is accomplished by splitting the global state over
a predefined large number of small partitions, each of which tracks its own individual
state. Rescaling is accomplished by re-distributing these shards over the physical workers.
Our loaders obey the following type hierarchy:
torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset.
`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a
single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times,
then applying some sort of post-processing and yielding the result. Users build data processing
pipelines by wrapping a base `_StatefulDataset` in any number of `_WrapperDataset` layers,
which is then passed to the torch DataLoader.
It is likely that this can be merged into the existing Nodes structure, but we leave this for
future work, for now.
This file borrows the StatefulDataset framework from the IBM fms-fsdp repo to implement rescalable data
loading. This framework is analogous to the existing torchdata nodes framework and will be converted
in the future.
Rescalability is implemented at the base level - you must use this layer to interface with a collection
of indexable files directly. The ScalableReader then yields data values like an iterator. These values
are not shuffled.
ScalableReader interfaces with indexable files via custom FileHandlers. These FileHandlers implement basic
file operations such as file type checking, opening, indexing, and slicing. By implementing these basic
operations, users can add support for arbitrary file types.
Rescalability is implemented by splitting data into a large number of logical shards, which are then
allocated over the set of dataloader workers. We assume that logical shards vastly outnumber workers,
such that when workers do not divide logical shards evenly, the off-by-one allocations don't matter and
workers still finish their epochs at roughly the same time. Files are assigned to logical shards
fractionally and based on file size, such that each shard contains roughly equal amounts of data, and as
few individual files as possible. This minimizes the number of file pulls.
ScalableReaders step through a single active logical shard at a time, to minimize overhead. This behavior
can be relaxed later.
When rescaling to a different number of workers, the logical shard progress counters are aggregated
globally onto each ScalableReader. Then, completed and incomplete logical shards are re-allocated
separately, to ensure that each worker receives roughly the same ratio of seen to unseen data in the
current epoch. This allows us to scale from any number of workers to any other number.
In the current code, state dicts must be saved using DCP, but this can be relaxed in the future for
cases where rescaling is not required. Rescaling itself will always require DCP.
"""


#### ------------------------- BORROWED FROM IBM FMS-FSDP ------------------------- ####

class _StatefulDataset(data.IterableDataset):
"""
Stub for stateful datasets, extends data.IterableDataset with state_dict methods.
@@ -177,7 +180,7 @@ def state_dict(self):
return out


#### ------------------------- FILE READERS ------------------------- ####
#### ------------------------- FILE HANDLERS ------------------------- ####


class _ShardFileHandler:
@@ -296,10 +299,18 @@ def __iter__(self):
yield self.aug_fn(out)


#### ------------------------- NEW CODE STARTS HERE ------------------------- ####


class ScalableReader(_StatefulDataset):
"""
Maintains shared logical shards but opens them one at a time. Completely repartitions
unseen shards only when rescaling.
Maintains an n x 5 state buffer, where n is the number of logical shards owned by this worker
and 5 is the number of relevant data fields per shard. Finishes shards with the lowest
visit count before continuing into a new epoch. When rescaling, re-allocates the visited and
unvisited shards of the current epoch separately, so that each new worker finishes the epoch at
around the same time.
Currently does not shuffle docs within shards/files, but this can be added later.
"""

def __init__(
@@ -318,46 +329,50 @@ def __init__(
verbose: bool = False,
):
super().__init__(datapath, rank, worldsize)
self.seed = seed
self.seed = seed # Currently unused
self.datapath = datapath
self.filehandler = filehandler()
self.min_length = min_length
self.min_length = min_length # Ignore any docs shorter than this
assert max_chunksize > 0, "max_chunksize must be a positive integer"
self.chunksize = max_chunksize
self.eos = delimiter_token
self.bos = bos_token
self.drop = strip_tokens
self.chunksize = max_chunksize # Yield chunks at a time if doc is longer than this
self.eos = delimiter_token # Inserted between each doc
self.bos = bos_token # Inserted before each doc (optional)
self.drop = strip_tokens # Tokens to drop from begin/end of doc (replaced by above delimiter/bos)
self.n_logical_shards = n_logical_shards
self.verbose = verbose
self.verbose = verbose # Currently unused

# Position
self.reader = None
self.cur_file = None

# Setup flags
self.is_setup = False
self.filesizes = None # [[filenames], [filesizes]] CONSTRUCTED PRE ITER IF NOT LOADED
self.shard_states = None # shardid, file pos, doc pos, chunk pos, epoch RESHARD
self.filesizes = None # [[filenames], [filesizes]] (constructed pre-iter if not loaded from ckp)
self.shard_states = None # shardid, file pos, doc pos, chunk pos, epoch (reshardable state buffer)

# TODO: add handling to prevent zero-length allocations

def _get_shard_breakdown(self, rank, nshards):
# Find highest fileid still smaller than start
"""
Retrieve the set of (fractional) files assigned to a given logical shard
"""
# Find the first file included in the current shard
sizelist = torch.tensor(self.filesizes[1])
sizelist = sizelist/sizelist.float().mean()
sizelist = sizelist/sizelist.float().sum()
cum_sizelist = sizelist.cumsum(0)
start_frac = rank/nshards*len(sizelist)
start_frac = rank/nshards
start_id = len(sizelist) - cum_sizelist.gt(start_frac).sum().item()
# For each file, assign the relevant fractional ownership
start = start_frac
end = (rank+1)/nshards*len(sizelist)
end = (rank+1)/nshards
my_files = [] # fileid, start%, end%
for i, (size, cumsize_incl) in enumerate(
zip(sizelist[start_id:].tolist(), cum_sizelist[start_id:].tolist())
):
id = start_id + i
cumsize = cumsize_incl - size
if cumsize > end:
# No more files to include, stop early
break
elif cumsize <= end and cumsize_incl >= start:
my_files.append([
@@ -368,6 +383,10 @@ def _get_shard_breakdown(self, rank, nshards):
return my_files
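# Illustrative example of the fractional ownership computed above (hypothetical sizes, not output
# from the real method): with file sizes [100, 200, 100] (i.e. fractions [0.25, 0.5, 0.25] of the
# data) split over nshards=2, shard 0 covers the global range [0.0, 0.5) and shard 1 covers
# [0.5, 1.0), so roughly:
#   shard 0 -> [[0, 0.0, 1.0], [1, 0.0, 0.5]]   (all of file 0, first half of file 1)
#   shard 1 -> [[1, 0.5, 1.0], [2, 0.0, 1.0]]   (second half of file 1, all of file 2)
# Each shard holds about half of the data while touching at most two files, which keeps the number
# of file pulls small.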

def setup(self):
"""
Perform any rank-dependent setup. This operation is deferred from __init__ to support
multiple workers in the dataloader.
"""
if not self.is_setup:
# Get your adjusted rank and worldsize
super().setup()
@@ -381,12 +400,16 @@ def _get_shard_breakdown(self, rank, nshards):
# Set up logical shard states (may be overwritten later by ckp load)
self.shard_states = torch.zeros(math.ceil(self.n_logical_shards / self.worldsize), 5, dtype=torch.int)
self.shard_states[:len(my_shards), 0] = torch.tensor(my_shards)

# Pad shard state if this worker is off by one. Id is -1 and visit count is inf.
self.shard_states[len(my_shards):, 0] = -1
self.shard_states[len(my_shards):, 4] = torch.iinfo(torch.int).max
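# Example of the resulting buffer (illustrative values): with n_logical_shards=10 and worldsize=4,
# each worker allocates ceil(10/4) = 3 rows. A worker that happens to own only logical shards 3 and 7
# would start with
#   [[ 3, 0, 0, 0, 0],
#    [ 7, 0, 0, 0, 0],
#    [-1, 0, 0, 0, INT_MAX]]        # INT_MAX = torch.iinfo(torch.int).max
# where the columns are shard id, file pos, doc pos, chunk pos, and visit (epoch) count. The padded
# row's INT_MAX visit count keeps it at the back of the lowest-visit-count-first ordering, so real
# shards are always preferred.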

def _pre_iter(self):
# Run after loading checkpoint, before iterating

"""
Construct index of data files and their filesizes.
This is saved/loaded in subsequent checkpoints to avoid re-indexing the entire dataset repeatedly.
"""
# Assemble set of available shard files, if nonexistent
if self.filesizes is None:
# Find all legal files
@@ -480,7 +503,7 @@ def __iter__(self):

def state_dict(self):
self.setup()
# Values to save: shard states (shard/repl), filesizes (single/repl)
# Values to save: shard states, filesizes
# Deepcopy required to prevent in-place modification from later prefetches
out = {self.statename("shard_states", rank=self.rank): self.shard_states}
if self.rank==0:
@@ -489,24 +512,18 @@

def load_state_dict(self, state_dict):
self.setup()
# Load back shard states (global), filesizes (all)
shard_states = state_dict[self.statename("shard_states")]
# Load back shard states and file sizes
shard_states = state_dict[self.statename("shard_states")] # list[tensor]
file_info = state_dict[self.statename("file_info")]
# print(shard_states.size(0), self.worldsize)
# if self.rank == 0:
# print(shard_states.shape, shard_states)
# torch.save(shard_states, "/gpfs/davis/test.pth")
if len(shard_states) == self.worldsize:
self.filesizes = file_info
self.shard_states = shard_states[self.rank]
else:
# shard_states = [s[0] for s in shard_states.split(1)] # [w] n 5
# shard_states = torch.cat(shard_states, dim=0) # wn 5
# Sort shards by epoch count
shard_states = torch.cat(shard_states, dim=0)
sorted, indices = torch.sort(shard_states[:,4], descending=True, stable=True)
shard_states = shard_states[indices]
# Strip out dummy shards
# Strip out dummy padding shards
n_dummies = sorted.eq(torch.iinfo(torch.int).max).sum()
shard_states = shard_states[n_dummies:] # n_logical 5
assert len(shard_states) == self.n_logical_shards, f"Number of shards {len(shard_states)} does not match specified {self.n_logical_shards}"
@@ -532,6 +549,7 @@ def load_state_dict(self, state_dict):
] for i in range(self.worldsize)
]
# Reverse sort incomplete shards by length
# Minimizes padding by assigning more incomplete shards to workers that received fewer complete shards
incomplete_shards.sort(key=len, reverse=True)

# Pull out shard allocation for this worker
@@ -555,8 +573,10 @@

def __pop_dstate(state, device_mesh, placements, create_dtensor=False):
"""
Removes worker states from the StatefulDataLoader state dict, and assembles them
into a separate list of dicts of dtensors for distributed checkpointing.
Removes worker states from the StatefulDataLoader state dict and fuses them into a single dict
(assuming no key overlap, which we currently guarantee by adding a rank to each worker's shard state).
Old dtensor logic is included but currently unused, since no state buffers are resharded via
dtensors directly; this will likely change in the future.
"""
dstate = state["_snapshot"]["_worker_snapshots"]
dstate = [dstate[f"worker_{i}"].pop("dataset_state") for i in range(len(dstate))]
@@ -586,8 +606,7 @@ def save_distributed_state_dict(
"""
Retrieves dataloader state dict, and separates worker states from loader state.
Loader state is not rescalable, and is discarded when rescaling.
Rescalable worker states are compiled into a dtensor across ranks, and saved
using pytorch distributed checkpointing.
Saves dict using DCP.
"""
state = deepcopy(loader.state_dict())
dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)], True)
@@ -609,11 +628,14 @@ def load_distributed_state_dict(
device_mesh: dist.DeviceMesh,
):
"""
Retrieves dataloader state dict, and separates worker states from loader state.
Retrieves dataloader state dict using DCP, and separates worker states from loader state.
If not rescaling, load saved dataloader state.
Rescalable worker states are retrieved using pytorch distributed checkpointing.
States are replicated over workers, and ScalableReader will handle
partitioning and re-assignment of available states into logical ranks.
Loading back to the same number of workers results in key overlap for 'state', so I suspect
that any rank-dependent dataloader state is being lost or overwritten in this case.
TODO: verify/fix
"""
base = loader.state_dict()
nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"]
Expand All @@ -625,12 +647,9 @@ def load_distributed_state_dict(
keys=set(["state", "dstate"]),
storage_reader = reader,
) # NOTE: assumes inp["state"] is same across all devices
# checkpoint.load_state_dict(
# inp,
# reader,
# )
dstate = inp["dstate"]
# Re-pack the set of rankX args
# NOTE: this is the step currently breaking the no-DCP path
keys = list(dstate.keys())
ranked_state = {k:dstate.pop(k) for k in keys if "rank" in k}
ranked_keylist = sorted(list(ranked_state.keys()))
@@ -646,7 +665,6 @@
# On mismatch, discard saved non-reshardable loader state and start fresh
state = base
# Repeat global tensor over all workers
# print(inp["dstate"]["ScalableReader.shard_states"][:,0])
dstate = [inp["dstate"],]*nworkers
# Re-insert worker states into loader state
for i in range(nworkers):
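The remaining lines of load_distributed_state_dict are collapsed in this view. As a usage recap, the
sketch below shows how the demo script above exercises the two helpers end to end. The gloo backend,
the CPU device mesh, and the build_reader() stand-in are assumptions for illustration; the save call
is assumed to take the same (loader, path, device_mesh) arguments as the load call shown in the demo.

import os
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.ibm_rescalable import (
    load_distributed_state_dict,
    save_distributed_state_dict,
)

def build_reader():
    # Placeholder for however the demo constructs its ScalableReader over the dummy dataset.
    raise NotImplementedError

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))
data = StatefulDataLoader(build_reader(), batch_size=2, num_workers=2)

ckpt_path = "./rescale_test/loader_dcp_state"
if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0:
    # First run: consume part of the epoch, then checkpoint the loader state with DCP.
    save_distributed_state_dict(data, ckpt_path, mesh)
else:
    # Second run (possibly on a different world size / num_workers): resume and finish the epoch.
    load_distributed_state_dict(data, ckpt_path, mesh)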
