Skip to content

Commit 5577c1b

Browse files
authored
Merge branch 'main' into main
2 parents 32ff7ea + 47e1916 commit 5577c1b

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

src/axolotl/monkeypatch/unsloth_.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def get_forward_code() -> str:
8080
return forward
8181

8282

83-
def test_cel_is_patchable() -> bool:
83+
def check_cel_is_patchable() -> bool:
8484
forward = get_forward_code()
85+
forward, _ = detab_code(forward)
8586
return ORIGINAL_CEL_CODE in forward
8687

8788

@@ -90,9 +91,10 @@ def get_self_attn_code() -> str:
9091
return forward
9192

9293

93-
def test_self_attn_is_patchable() -> bool:
94+
def check_self_attn_is_patchable() -> bool:
9495
qkv = get_self_attn_code()
95-
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
96+
qkv, _ = detab_code(qkv)
97+
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
9698

9799

98100
def integrate_cross_entropy_loss_patch():
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
2+
import unittest
3+
4+
from axolotl.monkeypatch.unsloth_ import (
5+
check_cel_is_patchable,
6+
check_self_attn_is_patchable,
7+
)
8+
9+
10+
class TestUnslothIntegration(unittest.TestCase):
11+
"""Unsloth monkeypatch integration tests."""
12+
13+
def test_is_cel_patchable(self):
14+
# ensures the current version of transformers has loss code that matches our patching code
15+
self.assertTrue(
16+
check_cel_is_patchable(),
17+
"HF transformers loss code has changed and isn't patchable",
18+
)
19+
20+
def test_is_self_attn_patchable(self):
21+
# ensures the current version of transformers has loss code that matches our patching code
22+
self.assertTrue(
23+
check_self_attn_is_patchable(),
24+
"HF transformers self attention code has changed and isn't patchable",
25+
)

0 commit comments

Comments
 (0)