Commit 4aad8e3

Reduce memory overhead when capturing tensors
1 parent abc2879 commit 4aad8e3

1 file changed, +2 -1 lines changed

llm_steer.py

Lines changed: 2 additions & 1 deletion
@@ -227,7 +227,8 @@ def _add_steer_vector(self, layer_idx: int, steerElem: SteerElement):
 
     def _capture_tensor(self, layer_idx: int, tokens: Tensor):
         self._set_forward_fn(ActivationMode.CAPTURE, layer_idx)
-        self.model(tokens)
+        with torch.inference_mode():
+            self.model(tokens)
         result = self.captured_tensor
         print(f"captured tensor: {result}")
         return result
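
The memory saving in this change presumably comes from `torch.inference_mode()` skipping autograd bookkeeping during the capture forward pass. The sketch below illustrates that effect in isolation. The `nn.Linear` layer and the random input are hypothetical stand-ins, not the project's model or tokens.

```python
import torch
from torch import nn

# Hypothetical stand-in for the steered model (assumption, not the real one).
model = nn.Linear(4, 2)
tokens = torch.randn(1, 4)

# Default mode: the output is attached to the autograd graph, so the
# tensors needed for a backward pass are kept alive.
tracked = model(tokens)

# Inference mode: no graph is recorded, so nothing is retained for backward.
with torch.inference_mode():
    captured = model(tokens)

print(tracked.requires_grad, captured.requires_grad)
print(captured.is_inference())
```

One trade-off to keep in mind: a tensor created under inference mode cannot later be saved for backward by autograd, so this approach fits only if the captured tensor is never used in a computation that needs gradients.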
