Skip to content

Commit

Permalink
[BugFix] Consolidate lazy stacks of non-tensors
Browse files Browse the repository at this point in the history
ghstack-source-id: afb1480da5702ec582d4c8438ce16e569b819d9b
Pull Request resolved: #1222
  • Loading branch information
vmoens committed Feb 19, 2025
1 parent 7b0fd93 commit 0b901a7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5043,7 +5043,8 @@ def assign(
cls = type(value)
if issubclass(cls, torch.Tensor):
pass
elif _is_non_tensor(cls):
# We want to skip NonTensorStacks
elif _is_non_tensor(cls) and not issubclass(cls, TensorDictBase):
if requires_metadata:
metadata_dict["non_tensors"][key] = (
value.data,
Expand Down Expand Up @@ -5411,7 +5412,8 @@ def _view_and_pad(tensor):
if non_blocking and device.type != "cuda":
# sync if needed
self._sync_all()
torch.cat(items, out=storage)
if items:
torch.cat(items, out=storage)
for v, (k, oldv) in _zip_strict(
storage.split(flat_size), list(flat_dict.items())
):
Expand Down
11 changes: 11 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11404,6 +11404,17 @@ def test_stack(self, non_tensor_data):
LazyStackedTensorDict,
)

def test_stack_consolidate(self):
td = torch.stack(
[
TensorDict(a="a string", b="b string"),
TensorDict(a="another string", b="bnother string"),
]
)
tdc = td.consolidate()
assert (tdc == td).all()
assert tdc["a"] == ["a string", "another string"]

def test_assign_non_tensor(self):
data = TensorDict({}, [1, 10])

Expand Down

0 comments on commit 0b901a7

Please sign in to comment.