Skip to content

Commit daf9537

Browse files
authored
feat: Change the typing for the AIConfig (#16)
This commit inlines the `AIConfigData` class into the existing `AIConfig`. It also introduces a new type, `ModelConfig` to replace the much looser dictionary config that was originally in play. This new model contains specific, typed properties for the model id, temperature, and max tokens. Additional model-specific attributes can be accessed as well.
1 parent c752739 commit daf9537

File tree

2 files changed

+132
-42
lines changed

2 files changed

+132
-42
lines changed

ldai/client.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,69 @@ class LDMessage:
1616
content: str
1717

1818

19-
@dataclass
20-
class AIConfigData:
21-
model: Optional[dict]
22-
prompt: Optional[List[LDMessage]]
class ModelConfig:
    """
    Configuration related to the model.
    """

    def __init__(self, id: str, temperature: Optional[float] = None,
                 max_tokens: Optional[int] = None, attributes: Optional[dict] = None):
        """
        :param id: The ID of the model.
        :param temperature: Tuning parameter for randomness versus determinism. Exact effect will be determined by the model.
        :param max_tokens: The maximum number of tokens.
        :param attributes: Additional model-specific attributes.
        """
        self._id = id
        self._temperature = temperature
        self._max_tokens = max_tokens
        # Copy into a fresh dict: a `dict = {}` default would be shared across
        # every instance (mutable-default pitfall), and copying also guards
        # against callers mutating the dict after construction.
        self._attributes = dict(attributes) if attributes is not None else {}

    @property
    def id(self) -> str:
        """
        The ID of the model.
        """
        return self._id

    @property
    def temperature(self) -> Optional[float]:
        """
        Tuning parameter for randomness versus determinism. Exact effect will be determined by the model.
        """
        return self._temperature

    @property
    def max_tokens(self) -> Optional[int]:
        """
        The maximum number of tokens.
        """
        return self._max_tokens

    def get_attribute(self, key: str) -> Any:
        """
        Retrieve model-specific attributes.

        Accessing a named, typed attribute (e.g. id) will result in the call
        being delegated to the appropriate property.
        """
        if key == 'id':
            return self.id
        if key == 'temperature':
            return self.temperature
        if key == 'maxTokens':
            return self.max_tokens

        return self._attributes.get(key)
2374

2475

2576
class AIConfig:
    """
    An AI Config: its enabled state, model configuration, prompt messages,
    and the tracker used to record events for it.
    """

    def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]):
        """
        :param tracker: Tracker associated with this config.
        :param enabled: Whether the config is enabled.
        :param model: Optional configuration for the model.
        :param prompt: Optional list of prompt messages.
        """
        self.enabled = enabled
        self.tracker = tracker
        self.model = model
        self.prompt = prompt
3082

3183

3284
class LDAIClient:
@@ -71,16 +123,26 @@ def model_config(
71123
for entry in variation['prompt']
72124
]
73125

126+
model = None
127+
if 'model' in variation:
128+
model = ModelConfig(
129+
id=variation['model']['modelId'],
130+
temperature=variation['model'].get('temperature'),
131+
max_tokens=variation['model'].get('maxTokens'),
132+
attributes=variation['model'],
133+
)
134+
74135
enabled = variation.get('_ldMeta', {}).get('enabled', False)
75136
return AIConfig(
76-
config=AIConfigData(model=variation['model'], prompt=prompt),
77137
tracker=LDAIConfigTracker(
78138
self.client,
79139
variation.get('_ldMeta', {}).get('versionKey', ''),
80140
key,
81141
context,
82142
),
83143
enabled=bool(enabled),
144+
model=model,
145+
prompt=prompt
84146
)
85147

86148
def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:

ldai/testing/test_model_config.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import pytest
22
from ldclient import Config, Context, LDClient
33
from ldclient.integrations.test_data import TestData
4-
from ldclient.testing.builders import *
54

6-
from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage
5+
from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig
76
from ldai.tracker import LDAIConfigTracker
87

98

@@ -14,7 +13,7 @@ def td() -> TestData:
1413
td.flag('model-config')
1514
.variations(
1615
{
17-
'model': {'modelId': 'fakeModel'},
16+
'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096},
1817
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
1918
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
2019
},
@@ -27,7 +26,7 @@ def td() -> TestData:
2726
td.flag('multiple-prompt')
2827
.variations(
2928
{
30-
'model': {'modelId': 'fakeModel'},
29+
'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192},
3130
'prompt': [
3231
{'role': 'system', 'content': 'Hello, {{name}}!'},
3332
{'role': 'user', 'content': 'The day is, {{day}}!'},
@@ -43,7 +42,7 @@ def td() -> TestData:
4342
td.flag('ctx-interpolation')
4443
.variations(
4544
{
46-
'model': {'modelId': 'fakeModel'},
45+
'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'},
4746
'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}],
4847
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
4948
}
@@ -55,7 +54,7 @@ def td() -> TestData:
5554
td.flag('off-config')
5655
.variations(
5756
{
58-
'model': {'modelId': 'fakeModel'},
57+
'model': {'modelId': 'fakeModel', 'temperature': 0.1},
5958
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
6059
'_ldMeta': {'enabled': False, 'versionKey': 'abcd'},
6160
}
@@ -82,81 +81,110 @@ def ldai_client(client: LDClient) -> LDAIClient:
8281
return LDAIClient(client)
8382

8483

84+
def test_model_config_delegates_to_properties():
    """Typed properties and get_attribute must report identical values."""
    model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'})

    # Typed accessors.
    assert model.id == 'fakeModel'
    assert model.temperature == 0.5
    assert model.max_tokens == 4096

    # Untyped accessor falls back to the attribute dict.
    assert model.get_attribute('extra-attribute') == 'value'
    assert model.get_attribute('non-existent') is None

    # Named, typed attributes are delegated to the matching property;
    # only the camelCase spelling is recognized for maxTokens.
    assert model.get_attribute('id') == model.id
    assert model.get_attribute('temperature') == model.temperature
    assert model.get_attribute('maxTokens') == model.max_tokens
    assert model.get_attribute('max_tokens') != model.max_tokens
8598
def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
    """Prompt variables are interpolated and the flag's model config is surfaced."""
    context = Context.create('user-key')
    fallback = AIConfig(
        tracker=tracker,
        enabled=True,
        model=ModelConfig('fakeModel'),
        prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
    )

    config = ldai_client.model_config('model-config', context, fallback, {'name': 'World'})

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, World!'
    assert config.enabled is True

    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.5
    assert config.model.max_tokens == 4096

105121
def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
    """Missing variables interpolate to empty strings rather than erroring."""
    context = Context.create('user-key')
    fallback = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])

    config = ldai_client.model_config('model-config', context, fallback, {})

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, !'
    assert config.enabled is True

    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.5
    assert config.model.max_tokens == 4096

119138
def test_context_interpolation(ldai_client: LDAIClient, tracker):
    """Context attributes (ldctx.*) win over supplied variables during interpolation."""
    context = Context.builder('user-key').name("Sandy").build()
    fallback = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])

    config = ldai_client.model_config(
        'ctx-interpolation', context, fallback, {'name': 'World'}
    )

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, Sandy!'
    assert config.enabled is True

    # Model has only the id plus a custom attribute; typed fields are unset.
    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature is None
    assert config.model.max_tokens is None
    assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to'
135158

136159
def test_model_config_multiple(ldai_client: LDAIClient, tracker):
    """Every message in a multi-message prompt is interpolated independently."""
    context = Context.create('user-key')
    fallback = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])

    config = ldai_client.model_config(
        'multiple-prompt', context, fallback, {'name': 'World', 'day': 'Monday'}
    )

    assert config.prompt is not None
    assert len(config.prompt) > 0
    assert config.prompt[0].content == 'Hello, World!'
    assert config.prompt[1].content == 'The day is, Monday!'
    assert config.enabled is True

    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.7
    assert config.model.max_tokens == 8192
153179

154180
def test_model_config_disabled(ldai_client: LDAIClient, tracker):
    """A disabled config still carries its model configuration."""
    context = Context.create('user-key')
    fallback = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[])

    config = ldai_client.model_config('off-config', context, fallback, {})

    assert config.enabled is False
    assert config.model is not None
    assert config.model.id == 'fakeModel'
    assert config.model.temperature == 0.1
    assert config.model.max_tokens is None

0 commit comments

Comments
 (0)