-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtest_base_flow.py
155 lines (119 loc) · 3.77 KB
/
test_base_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import copy
import pytest
from minds.client import Client
import logging
logging.basicConfig(level=logging.DEBUG)
from minds.datasources.examples import example_ds
from minds.exceptions import ObjectNotFound
def get_client():
    """Return a ``Client`` configured from the environment.

    The API key is read from ``API_KEY``; the server URL from ``BASE_URL``,
    falling back to the dev endpoint when that variable is unset.
    """
    return Client(
        os.getenv('API_KEY'),
        base_url=os.getenv('BASE_URL', 'https://dev.mindsdb.com'),
    )
def test_wrong_api_key():
    """A client built with an invalid API key must fail on its first request.

    Uses the same ``BASE_URL`` fallback as ``get_client`` so this check runs
    against the same server as the rest of the suite.
    """
    base_url = os.getenv('BASE_URL', 'https://dev.mindsdb.com')
    client = Client('api_key', base_url=base_url)
    # The concrete exception type raised for a rejected key is not visible
    # here, so any exception is accepted.
    with pytest.raises(Exception):
        client.datasources.get('example_db')
def test_datasources():
    """Exercise the datasource lifecycle: create, replace, get, list, drop."""
    client = get_client()

    # remove previous object (leftover from an earlier run, if any)
    try:
        client.datasources.drop(example_ds.name, force=True)
    except ObjectNotFound:
        ...

    # create, then create again with replace=True to exercise replacement
    ds = client.datasources.create(example_ds)
    ds = client.datasources.create(example_ds, replace=True)
    assert ds.name == example_ds.name

    # get: verify the fetched object is the one just created
    ds = client.datasources.get(example_ds.name)
    assert ds.name == example_ds.name

    # list
    ds_list = client.datasources.list()
    assert len(ds_list) > 0

    # drop
    client.datasources.drop(ds.name)
def test_minds():
    """End-to-end flow for minds.

    Covers create/replace, get, list, rename + update, adding and removing
    datasources, plain and streamed completions, and final cleanup.
    """
    client = get_client()

    ds_name = 'test_datasource_'
    ds_name2 = 'test_datasource2_'
    mind_name = 'int_test_mind_'
    mind_name2 = 'int_test_mind2_'
    prompt1 = 'answer in german'
    prompt2 = 'answer in spanish'

    def normalize(text):
        # Drop spaces and thousands separators so e.g. "5,602" matches "5602".
        return text.replace(' ', '').replace(',', '')

    # remove previous objects (leftovers from an earlier, interrupted run)
    for name in (mind_name, mind_name2):
        try:
            client.minds.drop(name)
        except ObjectNotFound:
            ...
    for name in (ds_name, ds_name2):
        try:
            client.datasources.drop(name, force=True)
        except ObjectNotFound:
            ...

    # prepare datasources
    ds_cfg = copy.copy(example_ds)
    ds_cfg.name = ds_name
    # Create from the renamed copy, not example_ds itself, so this test does
    # not create (and later drop) the datasource used by test_datasources.
    ds = client.datasources.create(ds_cfg, replace=True)

    # second datasource, restricted to one table; it is passed as a config
    # object to minds.create below (presumably created server-side there)
    ds2_cfg = copy.copy(example_ds)
    ds2_cfg.name = ds_name2
    ds2_cfg.tables = ['home_rentals']

    # create, then replace with two datasources and a prompt template
    mind = client.minds.create(
        mind_name,
        datasources=[ds],
        provider='openai'
    )
    mind = client.minds.create(
        mind_name,
        replace=True,
        datasources=[ds.name, ds2_cfg],
        prompt_template=prompt1
    )

    # get
    mind = client.minds.get(mind_name)
    assert len(mind.datasources) == 2
    assert mind.prompt_template == prompt1

    # list
    mind_list = client.minds.list()
    assert len(mind_list) > 0

    # rename & update
    mind.update(
        name=mind_name2,
        datasources=[ds.name],
        prompt_template=prompt2
    )
    with pytest.raises(ObjectNotFound):
        # the old name no longer exists after the rename
        client.minds.get(mind_name)

    mind = client.minds.get(mind_name2)
    assert len(mind.datasources) == 1
    assert mind.prompt_template == prompt2

    # add datasource
    mind.add_datasource(ds2_cfg)
    assert len(mind.datasources) == 2

    # del datasource
    mind.del_datasource(ds2_cfg.name)
    assert len(mind.datasources) == 1

    # completion (prompt2 asks for Spanish answers)
    answer = mind.completion('say hello')
    assert 'hola' in answer.lower()

    # ask about data
    answer = mind.completion('what is max rental price in home rental?')
    assert '5602' in normalize(answer)

    # limit tables: swap to the datasource restricted to home_rentals
    mind.del_datasource(ds.name)
    mind.add_datasource(ds_name2)
    assert len(mind.datasources) == 1

    answer = mind.completion('what is max rental price in home rental?')
    assert '5602' in normalize(answer)

    # not accessible table: car sales is excluded from ds2_cfg.tables
    answer = mind.completion('what is max price in car sales?')
    assert '145000' not in normalize(answer)

    # stream completion; the whole stream is consumed before asserting
    success = False
    for chunk in mind.completion('say hello', stream=True):
        if 'hola' in chunk.content.lower():
            success = True
    assert success is True

    # drop
    client.minds.drop(mind_name2)
    client.datasources.drop(ds.name)
    client.datasources.drop(ds2_cfg.name)