Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 19, 2025
1 parent 383488c commit a0a9486
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,6 @@
import pytest
import tensordict.utils
import torch
from tensordict import TensorClass
from tensordict.tensorclass import from_dataclass

try:
import torchsnapshot

_has_torchsnapshot = True
TORCHSNAPSHOT_ERR = ""
except ImportError as err:
_has_torchsnapshot = False
TORCHSNAPSHOT_ERR = str(err)

from _utils_internal import get_available_devices

Expand All @@ -44,14 +33,26 @@
lazy_legacy,
LazyStackedTensorDict,
MemoryMappedTensor,
set_capture_non_tensor_stack,
tensorclass,
TensorClass,
TensorDict,
TensorDictBase,
)
from tensordict._lazy import _PermutedTensorDict, _ViewedTensorDict
from tensordict.base import _GENERIC_NESTED_ERR
from tensordict.tensorclass import from_dataclass
from torch import Tensor

try:
import torchsnapshot

_has_torchsnapshot = True
TORCHSNAPSHOT_ERR = ""
except ImportError as err:
_has_torchsnapshot = False
TORCHSNAPSHOT_ERR = str(err)

# Capture all warnings
pytestmark = [
pytest.mark.filterwarnings("error"),
Expand Down Expand Up @@ -381,7 +382,8 @@ class MyData:
data3 = MyData(D, B, A, C=C, E=E, batch_size=[3, 4])
data4 = MyData(D, B, A, C, E=E, batch_size=[3, 4])
data5 = MyData(D, B, A, C, E, batch_size=[3, 4])
data = torch.stack([data1, data2, data3, data4, data5], 0)
with set_capture_non_tensor_stack(True):
data = torch.stack([data1, data2, data3, data4, data5], 0)
assert (data.A == A).all()
assert (data.B == B).all()
assert (data.C == C).all()
Expand Down Expand Up @@ -1857,7 +1859,8 @@ class MyDataNested:
if lazy:
stacked_tc = LazyStackedTensorDict.lazy_stack([data1, data2], 0)
else:
stacked_tc = torch.stack([data1, data2], 0)
with set_capture_non_tensor_stack(True):
stacked_tc = torch.stack([data1, data2], 0)
assert type(stacked_tc) is type(data1)
assert isinstance(stacked_tc.y, type(data1.y))
assert stacked_tc.X.shape == torch.Size([2, 3, 4, 5])
Expand Down Expand Up @@ -2145,7 +2148,8 @@ def z(self) -> torch.Tensor:
y1 = Y(weakref.ref(obj), batch_size=[1])
y = torch.cat([y0, y1])
assert y.z.shape == torch.Size(())
y = torch.stack([y0, y1])
with set_capture_non_tensor_stack(True):
y = torch.stack([y0, y1])
assert y.z.shape == torch.Size(())


Expand Down Expand Up @@ -2253,9 +2257,13 @@ class TensorClass:
def get_nested(self):
c = self.TensorClass(torch.ones(1), ("a", "b", "c"), "Hello", batch_size=[])

td = torch.stack(
[TensorDict({"t": torch.ones(1), "c": c}, batch_size=[]) for _ in range(3)]
)
with set_capture_non_tensor_stack(True):
td = torch.stack(
[
TensorDict({"t": torch.ones(1), "c": c}, batch_size=[])
for _ in range(3)
]
)
return td

def test_apply(self):
Expand Down

0 comments on commit a0a9486

Please sign in to comment.