Skip to content

Commit

Permalink
[Feature] capture_non_tensor_stack
Browse files Browse the repository at this point in the history
ghstack-source-id: 27805b68d4663d51f4ecd67f0495de8f83c90c41
Pull Request resolved: #1221
  • Loading branch information
vmoens committed Feb 19, 2025
1 parent 9865f4f commit 7357e0c
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 5 deletions.
8 changes: 5 additions & 3 deletions docs/source/reference/tensordict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,15 @@ Utils
utils.expand_right
utils.isin
utils.remove_duplicates
capture_non_tensor_stack
dense_stack_tds
is_batchedtensor
is_tensor_collection
lazy_legacy
make_tensordict
merge_tensordicts
pad
pad_sequence
dense_stack_tds
set_lazy_legacy
lazy_legacy
parse_tensor_dict_string
set_capture_non_tensor_stack
set_lazy_legacy
2 changes: 2 additions & 0 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@
from tensordict.utils import (
assert_allclose_td,
assert_close,
capture_non_tensor_stack,
is_batchedtensor,
is_non_tensor,
is_tensorclass,
lazy_legacy,
parse_tensor_dict_string,
set_capture_non_tensor_stack,
set_lazy_legacy,
unravel_key,
unravel_key_list,
Expand Down
15 changes: 13 additions & 2 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
_TENSORCLASS_MEMO,
_unravel_key_to_tuple,
_zip_strict,
capture_non_tensor_stack,
DeviceType,
IndexType,
is_tensorclass,
Expand Down Expand Up @@ -3219,7 +3220,7 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False)

ids = set()
firstdata = NO_DEFAULT
return_stack = False
return_stack = capture_non_tensor_stack(allow_none=True)
for data in list_of_non_tensor:
if not isinstance(data, NonTensorData):
if raise_if_non_unique:
Expand All @@ -3242,8 +3243,18 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False)
return_stack = True
break
else:
return_stack = False
return_stack = capture_non_tensor_stack(allow_none=True)
if not return_stack:
if return_stack is None:
warnings.warn(
"The default behavior of stacking non-tensor data will change in "
"version v0.9 and switch from True to False (current default). "
"To prepare for this change, use set_capture_non_tensor_stack(val: bool) as a decorator or context "
"manager, or set the environment variable CAPTURE_NONTENSOR_STACK "
"to 'False'.",
FutureWarning,
stacklevel=2,
)
batch_size = list(first.batch_size)
batch_size.insert(dim, len(list_of_non_tensor))
return NonTensorData(
Expand Down
83 changes: 83 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,89 @@ def _legacy_lazy(func):
return func


# non tensor stack control
# Default used when neither set_capture_non_tensor_stack nor the env var
# provides a value (see capture_non_tensor_stack()).
_DEFAULT_CAPTURE_NONTENSOR_STACK = False
# Raw value from the environment: a string if the variable is set, else None.
# capture_non_tensor_stack() normalizes it lazily; set_capture_non_tensor_stack
# overwrites it with a bool.
_CAPTURE_NONTENSOR_STACK = os.environ.get("CAPTURE_NONTENSOR_STACK")


class set_capture_non_tensor_stack(_DecoratorContextManager):
    """A context manager or decorator to control whether identical non-tensor data should be stacked into a single NonTensorData object or a NonTensorStack.

    Args:
        mode (bool): Whether to capture non-tensor stacks. If ``True``, identical
            non-tensor data will be stacked into a :class:`~tensordict.NonTensorStack`. If ``False``,
            a single NonTensorData object will contain the unique value, but with the desired batch-size.
            Defaults to ``False``.

    .. note:: Until v0.9, this will raise a warning if the same value is encountered and the value is not set
        explicitly. You can set the value of :func:`~tensordict.capture_non_tensor_stack` through:

        - The ``CAPTURE_NONTENSOR_STACK`` environment variable;
        - By setting ``set_capture_non_tensor_stack(val: bool).set()`` at the beginning of your script;
        - By using ``set_capture_non_tensor_stack(val: bool)`` as a context manager or a decorator.

    .. seealso:: :class:`~tensordict.capture_non_tensor_stack`

    Examples:
        >>> with set_capture_non_tensor_stack(False):
        ...     torch.stack([NonTensorData("a"), NonTensorData("a")])
        NonTensorData("a", batch_size=[2])
        >>> @set_capture_non_tensor_stack(True)
        ... def my_function():
        ...     return torch.stack([NonTensorData("a"), NonTensorData("a")])
        >>> my_function()
        NonTensorStack(["a", "a"], stack_dim=0)
    """

    def __init__(self, mode: bool) -> None:
        super().__init__()
        self.mode = mode

    def clone(self) -> set_capture_non_tensor_stack:
        # override this method if your children class takes __init__ parameters
        return type(self)(self.mode)

    def __enter__(self) -> None:
        self.set()

    def set(self) -> None:
        """Apply the capture mode globally (reverted by ``__exit__`` if used as a context manager)."""
        global _CAPTURE_NONTENSOR_STACK
        # Keep the previous value verbatim: it may be None (unset), a raw
        # string from the environment, or a bool from a prior set() call.
        # __exit__ must be able to restore any of these exactly.
        self._old_mode = _CAPTURE_NONTENSOR_STACK
        _CAPTURE_NONTENSOR_STACK = bool(self.mode)
        # we do this such that sub-processes see the same lazy op than the main one
        os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _CAPTURE_NONTENSOR_STACK
        # Restore the previous value as-is. Do NOT coerce with bool():
        # bool(None) would collapse the "unset" state to False (silencing the
        # v0.9 FutureWarning emitted when no explicit choice was made), and
        # bool("False") is True since any non-empty string is truthy, which
        # would invert a user-provided environment setting.
        _CAPTURE_NONTENSOR_STACK = self._old_mode
        if _CAPTURE_NONTENSOR_STACK is None:
            # The setting was unset before entering: remove the env var so
            # sub-processes also see it as unset.
            os.environ.pop("CAPTURE_NONTENSOR_STACK", None)
        else:
            os.environ["CAPTURE_NONTENSOR_STACK"] = str(_CAPTURE_NONTENSOR_STACK)


def capture_non_tensor_stack(allow_none=False):
    """Get the current setting for capturing non-tensor stacks.

    Args:
        allow_none (bool, optional): If ``True``, returns ``None`` if no setting has been
            specified. Otherwise, returns the default setting. Defaults to ``False``.

    seealso: :func:`~tensordict.set_capture_non_tensor_stack`

    Returns:
        bool or None: The current setting for capturing non-tensor stacks.
    """
    setting = _CAPTURE_NONTENSOR_STACK
    if setting is None:
        # Nothing was chosen explicitly: either signal that (None) or fall
        # back to the library default.
        return None if allow_none else _DEFAULT_CAPTURE_NONTENSOR_STACK
    # The env-var path stores a raw string; parse it. A prior
    # set_capture_non_tensor_stack stores a bool, returned as-is.
    if isinstance(setting, str):
        return strtobool(setting)
    return setting


# Process initializer for map
def _proc_init(base_seed, queue, num_threads):
worker_id = queue.get(timeout=120)
Expand Down

0 comments on commit 7357e0c

Please sign in to comment.