Skip to content

Commit 1f05452

Browse files
committed
Max tokens parameter
1 parent 06cb73a commit 1f05452

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

src/compressa/perf/cli/__main__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def run_experiment_args(args):
2020
num_runners=args.num_runners,
2121
generate_prompts=args.generate_prompts,
2222
num_prompts=args.num_prompts,
23-
prompt_length=args.prompt_length
23+
prompt_length=args.prompt_length,
24+
max_tokens=args.max_tokens
2425
)
2526

2627

@@ -123,6 +124,9 @@ def main():
123124
parser_run.add_argument(
124125
"--prompt_length", type=int, default=100, help="Length of each generated prompt (if --generate_prompts is used)"
125126
)
127+
parser_run.add_argument(
128+
"--max_tokens", type=int, default=1000, help="Maximum number of tokens for the model to generate"
129+
)
126130
parser_run.set_defaults(func=run_experiment_args)
127131

128132
parser_report = subparsers.add_parser(

src/compressa/perf/cli/tools.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def generate_random_text(length):
5151
word = ''.join(random.choice(string.ascii_lowercase) for _ in range(word_length))
5252
words.append(word)
5353
current_length += len(word) + 1
54+
55+
words.append(". Repeat this text at least 10 times. Number the repetitions.")
5456
return ' '.join(words)[:length]
5557

5658
def generate_prompts_list(num_prompts, prompt_length):
@@ -73,7 +75,8 @@ def run_experiment(
7375
num_runners: int = 10,
7476
generate_prompts: bool = False,
7577
num_prompts: int = 100,
76-
prompt_length: int = 100
78+
prompt_length: int = 100,
79+
max_tokens: int = 1000,
7780
):
7881
if not openai_api_key:
7982
raise ValueError("OPENAI_API_KEY is not set")
@@ -85,7 +88,7 @@ def run_experiment(
8588
openai_api_key=openai_api_key,
8689
openai_url=openai_url,
8790
model_name=model_name,
88-
num_runners=num_runners
91+
num_runners=num_runners,
8992
)
9093

9194
experiment = Experiment(
@@ -106,7 +109,8 @@ def run_experiment(
106109
experiment_runner.run_experiment(
107110
experiment_id=experiment.id,
108111
prompts=prompts,
109-
num_tasks=num_tasks
112+
num_tasks=num_tasks,
113+
max_tokens=max_tokens,
110114
)
111115

112116
# Run analysis after the experiment

src/compressa/perf/experiment/inference.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def run_inference(
3838
self,
3939
experiment_id: int,
4040
prompt: str,
41-
max_tokens: int = 1000,
41+
max_tokens: int,
4242
):
4343
start_time = time.time()
4444

@@ -107,6 +107,7 @@ def store_experiment_parameters(
107107
self,
108108
experiment_id: int,
109109
num_tasks: int,
110+
max_tokens: int,
110111
):
111112
parameters = [
112113
Parameter(
@@ -127,6 +128,12 @@ def store_experiment_parameters(
127128
key="openai_url",
128129
value=self.openai_url,
129130
),
131+
Parameter(
132+
id=None,
133+
experiment_id=experiment_id,
134+
key="max_tokens",
135+
value=str(max_tokens),
136+
),
130137
]
131138
for param in parameters:
132139
insert_parameter(self.conn, param)
@@ -136,6 +143,7 @@ def run_experiment(
136143
experiment_id: int,
137144
prompts: List[str],
138145
num_tasks: int = 100,
146+
max_tokens: int = 1000,
139147
):
140148
all_measurements = []
141149
with ThreadPoolExecutor(max_workers=self.num_runners) as executor:
@@ -149,7 +157,7 @@ def run_experiment(
149157
for _ in range(self.num_runners)
150158
]
151159
futures = [
152-
executor.submit(runners[i % self.num_runners].run_inference, experiment_id, random.choice(prompts))
160+
executor.submit(runners[i % self.num_runners].run_inference, experiment_id, random.choice(prompts), max_tokens)
153161
for i in range(num_tasks)
154162
]
155163
for future in as_completed(futures):
@@ -160,6 +168,10 @@ def run_experiment(
160168
except Exception as e:
161169
logger.error(f"Task failed: {e}")
162170

163-
self.store_experiment_parameters(experiment_id, num_tasks)
171+
self.store_experiment_parameters(
172+
experiment_id,
173+
num_tasks,
174+
max_tokens,
175+
)
164176
for measurement in all_measurements:
165177
insert_measurement(self.conn, measurement)

0 commit comments

Comments (0)