Skip to content

Commit dc304b7

Browse files
author
Orbax Authors
committed
Internal change
PiperOrigin-RevId: 738393752
1 parent 671e232 commit dc304b7

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

checkpoint/orbax/checkpoint/_src/arrays/fragments.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -178,18 +178,24 @@ def slice(
178178
stop = out.stop[:] = np.minimum(out.stop, slice_shape)
179179
if not (start < stop).all():
180180
return None
181-
if (value := self.value) is None:
182-
return out
183-
else:
184-
value_fragment = Fragment(
185-
np_index=np.stack([
186-
np.maximum(self.start, np_index[:, 0]),
187-
np.minimum(self.stop, np_index[:, 1]),
188-
np_index[:, 2],
189-
], axis=1)
190-
).offset_by(-self.start)
191-
out_value = value[value_fragment.index or ...]
192-
return dataclasses.replace(out, value=out_value)
181+
return dataclasses.replace(
182+
out, value=self.slice_of_value(np_index)
183+
) if self.value is not None else None
184+
185+
def slice_of_value(
186+
self,
187+
new_np_idx: NpIndex,
188+
) -> np.ndarray:
189+
"""Returns a slice of `value`."""
190+
# This is just a convenient way to construct the required tuple of slices.
191+
f = Fragment(
192+
np_index=np.stack([
193+
np.maximum(self.start, new_np_idx[:, 0]) - self.start,
194+
np.minimum(self.stop, new_np_idx[:, 1]) - self.start,
195+
new_np_idx[:, 2],
196+
], axis=1)
197+
)
198+
return self.value[f.index or ...]
193199

194200

195201
@dataclasses.dataclass(frozen=True)

0 commit comments

Comments
 (0)