File tree 2 files changed +30
-3
lines changed
2 files changed +30
-3
lines changed Original file line number Diff line number Diff line change @@ -80,8 +80,9 @@ def get_forward_code() -> str:
80
80
return forward
81
81
82
82
83
- def test_cel_is_patchable () -> bool :
83
+ def check_cel_is_patchable () -> bool :
84
84
forward = get_forward_code ()
85
+ forward , _ = detab_code (forward )
85
86
return ORIGINAL_CEL_CODE in forward
86
87
87
88
@@ -90,9 +91,10 @@ def get_self_attn_code() -> str:
90
91
return forward
91
92
92
93
93
- def test_self_attn_is_patchable () -> bool :
94
+ def check_self_attn_is_patchable () -> bool :
94
95
qkv = get_self_attn_code ()
95
- return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
96
+ qkv , _ = detab_code (qkv )
97
+ return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
96
98
97
99
98
100
def integrate_cross_entropy_loss_patch ():
Original file line number Diff line number Diff line change
1
+ """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
2
+ import unittest
3
+
4
+ from axolotl .monkeypatch .unsloth_ import (
5
+ check_cel_is_patchable ,
6
+ check_self_attn_is_patchable ,
7
+ )
8
+
9
+
10
+ class TestUnslothIntegration (unittest .TestCase ):
11
+ """Unsloth monkeypatch integration tests."""
12
+
13
+ def test_is_cel_patchable (self ):
14
+ # ensures the current version of transformers has loss code that matches our patching code
15
+ self .assertTrue (
16
+ check_cel_is_patchable (),
17
+ "HF transformers loss code has changed and isn't patchable" ,
18
+ )
19
+
20
+ def test_is_self_attn_patchable (self ):
21
+ # ensures the current version of transformers has loss code that matches our patching code
22
+ self .assertTrue (
23
+ check_self_attn_is_patchable (),
24
+ "HF transformers self attention code has changed and isn't patchable" ,
25
+ )
You can’t perform that action at this time.
0 commit comments