1
- import numpy as np
2
- from numpy .testing import assert_array_equal
1
+ import torch
3
2
4
3
from outlines .text .generate .continuation import Continuation , continuation
5
4
class Tokenizer:
    """Tokenizer stub for the tests below.

    Provides only the attributes and methods that ``Continuation`` /
    ``continuation`` touch.
    """

    # Token id that marks the end of a generated sequence.
    eos_token_id = 0
    # Padding id, deliberately outside the normal vocabulary range.
    pad_token_id = -1

    def decode(self, token_ids):
        """Return one dummy string per sequence in the batch."""
        batch_size = token_ids.shape[0]
        return ["Test" for _ in range(batch_size)]
class Model:
    """Minimal model stand-in: exposes only the ``tokenizer`` and
    ``device`` attributes that the continuation wrapper reads."""

    device = "cpu"
    tokenizer = Tokenizer()
def test_continuation_eos_is_finished():
    """Sequences are flagged finished exactly when they contain the EOS id (0)."""
    generator = continuation(Model())
    assert isinstance(generator, Continuation)

    # (batched token ids, expected per-sequence finished flags)
    cases = [
        ([[3, 2]], [False]),
        ([[3, 2, 0]], [True]),
        ([[3, 2, 1], [3, 2, 0]], [False, True]),
        # A pad id (-1) after EOS must not re-mark a sequence as unfinished,
        # and trailing EOS on the first sequence flips it to finished.
        ([[3, 2, 1, 0], [3, 2, 0, -1]], [True, False]),
    ]
    for sequences, expected in cases:
        finished = generator.is_finished(torch.tensor(sequences))
        assert torch.equal(finished, torch.tensor(expected))
def test_continuation_postprocess():
    """The literal "<EOS>" marker is stripped from returned completions."""
    generator = continuation(Model())
    processed = generator.postprocess_completions(["Here<EOS>"])
    assert processed == ["Here"]
def test_continuation_stop_is_finished():
    """Decoded completions containing a stop string are flagged finished."""
    stub_tokenizer = Tokenizer()
    # NOTE(review): the stubbed decode always yields two completions no
    # matter what batch is passed in — presumably is_finished keys off the
    # decoded strings rather than the tensor shape; confirm against
    # Continuation.is_finished.
    stub_tokenizer.decode = lambda _token_ids: ["finished \n", "not_finished"]
    stub_model = Model()
    stub_model.tokenizer = stub_tokenizer

    generator = continuation(stub_model, stop=["\n"])

    finished = generator.is_finished(torch.tensor([[2, 3]]))
    assert torch.equal(finished, torch.tensor([True, False]))
def test_continuation_stop_postprocess():
    """Completions are truncated at the first stop sequence (and "<EOS>")."""
    # Single stop string.
    generator = Continuation(Model(), stop="\n")
    processed = generator.postprocess_completions(["Stop\n"])
    assert len(processed) == 1
    assert processed[0] == "Stop"

    # Multiple stop strings: truncation happens at whichever occurs first.
    generator = Continuation(Model(), stop=["\n", ","])
    for completion in ["Stop", "Stop\n", "Stop\naaa", "Stop,aa\naaa", "Stop\naa,a"]:
        processed = generator.postprocess_completions([completion])
        assert len(processed) == 1
        assert processed[0] == "Stop"

    # Batched input: each completion is post-processed independently.
    processed = generator.postprocess_completions(["Stop\n", "Nonstop"])
    assert len(processed) == 2
    assert processed == ["Stop", "Nonstop"]

    # Stop sequences compose with the "<EOS>" marker.
    processed = generator.postprocess_completions(["StopHere\nNoHere<EOS>"])
    assert len(processed) == 1
    assert processed[0] == "StopHere"