@@ -1,5 +1,4 @@
-import numpy as np
-from numpy.testing import assert_array_equal
+import torch

 from outlines.text.generate.continuation import Continuation, continuation

@@ -9,35 +8,55 @@ class Tokenizer:
     eos_token_id = 0
     pad_token_id = -1

+    def decode(self, token_ids):
+        return ["Test"] * token_ids.shape[0]
+

 class Model:
     tokenizer = Tokenizer()
     device = "cpu"


-def test_continuation_is_finished():
-    model = continuation(Model(), 10)
+def test_continuation_eos_is_finished():
+    model = continuation(Model())
     assert isinstance(model, Continuation)

-    token_ids = np.array([[3, 2]])
+    token_ids = torch.tensor([[3, 2]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [False])
+    assert torch.equal(result, torch.tensor([False]))

-    token_ids = np.array([[3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 0]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True])
+    assert torch.equal(result, torch.tensor([True]))

-    token_ids = np.array([[3, 2, 1], [3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 1], [3, 2, 0]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [False, True])
+    assert torch.equal(result, torch.tensor([False, True]))

-    token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
+    token_ids = torch.tensor([[3, 2, 1, 0], [3, 2, 0, -1]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True, False])
+    assert torch.equal(result, torch.tensor([True, False]))


 def test_continuation_postprocess():
     model = continuation(Model())
     result = model.postprocess_completions(["Here<EOS>"])
     assert len(result) == 1
     assert result[0] == "Here"
+
+
+def test_continuation_stop_is_finished():
+    tokenizer = Tokenizer()
+    tokenizer.decode = lambda x: ["finished \n", "not_finished"]
+    model = Model()
+    model.tokenizer = tokenizer
+
+    model = continuation(model, stop=["\n"])
+
+    token_ids = torch.tensor([[2, 3]])
+    result = model.is_finished(token_ids)
+    assert torch.equal(result, torch.tensor([True, False]))
+
+
+def test_continuation_stop_postprocess():
+    assert False
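The last added test is still a placeholder (`assert False`). A minimal sketch of what it could eventually check, by analogy with `test_continuation_postprocess`, is shown below. It reuses the `Model` stub defined in this module and assumes that `postprocess_completions` truncates each completion at the first stop string when the generator is created with `stop=...`; the example strings and that truncation behaviour are assumptions, not part of this diff.

```python
# Sketch only: assumes postprocess_completions drops the stop string and
# everything after it, mirroring how "<EOS>" is stripped in the EOS test.
def test_continuation_stop_postprocess():
    model = continuation(Model(), stop=["\n"])

    # Hypothetical completion: "Here", then the stop string, then trailing text.
    result = model.postprocess_completions(["Here\nand more"])
    assert len(result) == 1
    assert result[0] == "Here"
```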