Skip to content

Commit a067b1b

Browse files
Don't restart the container for every execution
1 parent ab2eb0e commit a067b1b

File tree

2 files changed

+142
-148
lines changed

2 files changed

+142
-148
lines changed

src/genesys/verifier/code_test_verifier.py

+131-145
Original file line numberDiff line numberDiff line change
@@ -1,180 +1,159 @@
11
import docker
2-
from pydantic import BaseModel, Field
3-
from typing import List, Dict
42
import io
53
import tarfile
64
import uuid
75
import re
86

9-
class CodeTestsVerification(BaseModel):
10-
type: str = Field("code_tests")
11-
language: str
12-
test_cases: List[Dict]
7+
# We keep a global dictionary for our containers
8+
CONTAINERS = {}
9+
10+
def init_containers():
    """Start one long-lived sandbox container per supported language.

    Every container is started detached with ``sleep infinity`` as its
    command and stored in the module-level ``CONTAINERS`` dict under its
    language name ("Python", "Rust", "C++", "Javascript").

    A language that already has an entry in ``CONTAINERS`` is skipped, so
    calling this function again does not overwrite (and thereby orphan) a
    container that is still registered.  If starting a container raises, the
    containers started before it stay registered in ``CONTAINERS`` so that
    ``close_containers()`` can still remove them.
    """
    # Language name -> image used to run code written in that language.
    images = {
        "Python": "python:3.9",
        "Rust": "rust:latest",
        "C++": "gcc:latest",
        "Javascript": "node:latest",
    }

    docker_client = docker.from_env()

    for language, image in images.items():
        if language in CONTAINERS:
            # Already running from an earlier call; keep using it.
            continue
        CONTAINERS[language] = docker_client.containers.run(
            image,
            command="sleep infinity",
            detach=True,
        )
2136

22-
def copy_to_container(container, dst, content):
37+
def close_containers():
    """Remove every container registered in ``CONTAINERS`` and empty the dict.

    Removal is best-effort per container: a Docker API error for one
    container is reported and the remaining containers are still removed.
    ``CONTAINERS`` is cleared even if an unexpected error escapes the loop,
    so stale handles are never reused by a later ``init_containers()``.
    """
    try:
        for language, container in CONTAINERS.items():
            try:
                # force=True removes a container that is still running in a
                # single call, replacing the separate stop() + remove().
                container.remove(force=True)
            except docker.errors.APIError as err:
                print(f"Failed to remove container for {language}: {err}")
    finally:
        CONTAINERS.clear()
42+
43+
def copy_to_container(container, dst, content: str):
    """Write ``content`` to the file ``dst`` inside ``container``.

    The text is UTF-8 encoded, wrapped in an in-memory tar archive holding a
    single member named ``dst``, and handed to ``container.put_archive`` with
    ``"/"`` as the extraction directory.

    Args:
        container: Object exposing ``put_archive(path, data)``.
        dst: Member name (path relative to ``/``) of the file to create.
        content: Text to store in the file.
    """
    payload = content.encode('utf-8')

    data = io.BytesIO()
    with tarfile.TarFile(fileobj=data, mode='w') as tf:
        tar_info = tarfile.TarInfo(name=dst)
        # The tar header stores the size in *bytes*.  Using len(content)
        # (characters) would truncate any text containing non-ASCII symbols.
        tar_info.size = len(payload)
        tf.addfile(tar_info, io.BytesIO(payload))

    # Rewind so put_archive reads the archive from its first byte.
    data.seek(0)
    container.put_archive("/", data)
32-
33-
def execute_python(code, inputs, docker_client):
34-
35-
code_filename = f"code_{uuid.uuid4().hex}.py"
36-
input_filename = f"input_{uuid.uuid4().hex}.txt"
37-
38-
container = docker_client.containers.create(
39-
image="python:3.9",
40-
command="sleep infinity",
41-
tty=False,
42-
stdin_open=False
43-
)
44-
container.start()
45-
46-
copy_to_container(container, code_filename, code)
47-
copy_to_container(container, input_filename, inputs)
48-
49-
run_cmd = ["sh", "-c", f"python {code_filename} < {input_filename}"]
50-
run_result = container.exec_run(cmd=run_cmd, stdout=True, stderr=True)
51-
output = run_result.output.decode()
52-
53-
container.stop()
54-
container.remove()
55-
56-
return output
5752

58-
def execute_rust(code, inputs, docker_client):
59-
code_filename = f"main_{uuid.uuid4().hex}.rs"
60-
input_filename = f"input_{uuid.uuid4().hex}.txt"
61-
62-
container = docker_client.containers.create(
63-
image="rust:latest",
64-
command="sleep infinity",
65-
tty=False,
66-
stdin_open=False
67-
)
68-
69-
container.start()
53+
def extract_code(response: str) -> "str | None":
    """Return the body of the last fenced code block in ``response``.

    A block opens with three backticks, an optional info string on the same
    line (for example ``python``, ``c++`` or ``c#``) and a newline, and runs
    up to the next three backticks.  The info string may contain any
    characters except a backtick or a newline, so language tags with
    punctuation are recognised.

    Args:
        response: Text that may contain one or more fenced code blocks.

    Returns:
        The last block's content with surrounding whitespace stripped, or
        ``None`` when ``response`` contains no fenced block.
    """
    # [^\n`]* rather than \w+: a tag such as "c++" must not make the opening
    # fence unmatchable, otherwise its closing fence gets treated as an
    # opening one and the prose after it is returned as "code".
    code_blocks = re.findall(r'```[^\n`]*\n(.*?)```', response, re.DOTALL)
    return code_blocks[-1].strip() if code_blocks else None
7059

71-
copy_to_container(container, code_filename, code)
72-
copy_to_container(container, input_filename, inputs)
60+
def _verify_compiled_code(container, code, test_cases, language):
61+
if language == "C++":
62+
source_filename = f"main_{uuid.uuid4().hex}.cpp"
63+
compile_cmd = f"g++ {source_filename} -o main"
64+
run_binary = "./main"
65+
elif language == "Rust":
66+
source_filename = f"main_{uuid.uuid4().hex}.rs"
67+
compile_cmd = f"rustc {source_filename} -o main"
68+
run_binary = "./main"
69+
else:
70+
# Shouldn't happen if we call this only for Rust/C++
71+
return 0.0
7372

74-
compile_cmd = f"rustc {code_filename} -o main"
73+
# Copy source
74+
copy_to_container(container, source_filename, code)
75+
# Compile
7576
compile_result = container.exec_run(cmd=compile_cmd, stdout=True, stderr=True)
7677
if compile_result.exit_code != 0:
7778
error_output = compile_result.output.decode()
78-
container.stop()
79-
container.remove()
80-
return f"Compilation Error:\n{error_output}"
79+
print("Compilation Error:\n", error_output)
80+
return 0.0
8181

82-
run_cmd = ["sh", "-c", f"./main < {input_filename}"]
83-
run_result = container.exec_run(cmd=run_cmd, stdout=True, stderr=True)
84-
output = run_result.output.decode()
82+
passed_tests = 0
83+
total_tests = len(test_cases)
8584

86-
container.stop()
87-
container.remove()
85+
for test in test_cases:
86+
input_filename = f"input_{uuid.uuid4().hex}.txt"
87+
copy_to_container(container, input_filename, test["input"])
8888

89-
return output
89+
run_cmd = ["sh", "-c", f"{run_binary} < {input_filename}"]
90+
run_result = container.exec_run(cmd=run_cmd, stdout=True, stderr=True)
91+
output = run_result.output.decode()
9092

91-
def execute_cpp(code, inputs, docker_client):
92-
code_filename = f"main_{uuid.uuid4().hex}.cpp"
93-
input_filename = f"input_{uuid.uuid4().hex}.txt"
93+
if output.strip() == test["output"].strip():
94+
passed_tests += 1
9495

95-
container = docker_client.containers.create(
96-
image="gcc:latest",
97-
command="sleep infinity",
98-
tty=False,
99-
stdin_open=False
100-
)
101-
container.start()
96+
return passed_tests / total_tests
97+
98+
def _verify_interpreted_code(container, code, test_cases, language):
99+
"""
100+
Copy code once, then run multiple times with different inputs.
101+
"""
102+
if language == "Python":
103+
code_filename = f"code_{uuid.uuid4().hex}.py"
104+
run_cmd_template = "python {code_file} < {input_file}"
105+
elif language == "Javascript":
106+
code_filename = f"main_{uuid.uuid4().hex}.js"
107+
run_cmd_template = "node {code_file} < {input_file}"
108+
else:
109+
return 0.0
102110

111+
# Copy code once
103112
copy_to_container(container, code_filename, code)
104-
copy_to_container(container, input_filename, inputs)
105-
106-
compile_cmd = f"g++ {code_filename} -o main"
107-
compile_result = container.exec_run(cmd=compile_cmd, stdout=True, stderr=True)
108-
if compile_result.exit_code != 0:
109-
error_output = compile_result.output.decode()
110-
container.stop()
111-
container.remove()
112-
return f"Compilation Error:\n{error_output}"
113-
114-
run_cmd = ["sh", "-c", f"./main < {input_filename}"]
115-
run_result = container.exec_run(cmd=run_cmd, stdout=True, stderr=True)
116-
output = run_result.output.decode()
117-
118-
container.stop()
119-
container.remove()
120113

121-
return output
122-
123-
def execute_javascript(code, inputs, docker_client):
124-
code_filename = f"main_{uuid.uuid4().hex}.js"
125-
input_filename = f"input_{uuid.uuid4().hex}.txt"
114+
passed_tests = 0
115+
total_tests = len(test_cases)
126116

127-
container = docker_client.containers.create(
128-
image="node:latest",
129-
command="sleep infinity",
130-
tty=False,
131-
stdin_open=False
132-
)
133-
container.start()
117+
for test in test_cases:
118+
input_filename = f"input_{uuid.uuid4().hex}.txt"
119+
copy_to_container(container, input_filename, test["input"])
134120

135-
copy_to_container(container, code_filename, code)
136-
copy_to_container(container, input_filename, inputs)
121+
run_cmd_str = run_cmd_template.format(
122+
code_file=code_filename,
123+
input_file=input_filename
124+
)
125+
run_cmd = ["sh", "-c", run_cmd_str]
126+
run_result = container.exec_run(cmd=run_cmd, stdout=True, stderr=True)
127+
output = run_result.output.decode()
137128

138-
run_cmd = ["sh", "-c", f"node {code_filename} < {input_filename}"]
139-
run_result = container.exec_run(cmd=run_cmd, stdout=True, stderr=True)
140-
output = run_result.output.decode()
129+
if output.strip() == test["output"].strip():
130+
passed_tests += 1
141131

142-
container.stop()
143-
container.remove()
132+
return passed_tests / total_tests
144133

145-
return output
146-
147134
def verify_code(response: str, test_cases, language):
    """Score the code contained in ``response`` against ``test_cases``.

    The code is taken from ``response`` with ``extract_code`` and executed in
    the container registered for ``language`` in ``CONTAINERS``.  C++ and
    Rust go through ``_verify_compiled_code``; Python and Javascript go
    through ``_verify_interpreted_code``.

    Returns:
        The score produced by the language-specific verifier, or ``0.0`` when
        no code could be extracted, no container is registered for
        ``language``, or the language is handled by neither verifier.
    """
    code = extract_code(response)
    if code is None:
        return 0.0

    try:
        container = CONTAINERS[language]
    except KeyError:
        print(f"No container found for language: {language}")
        return 0.0

    if language in ("C++", "Rust"):
        run_verifier = _verify_compiled_code
    elif language in ("Python", "Javascript"):
        run_verifier = _verify_interpreted_code
    else:
        print("Unsupported language:", language)
        return 0.0

    return run_verifier(container, code, test_cases, language)
152+
153+
177154
if __name__ == '__main__':
155+
init_containers()
156+
178157
code_samples = [
179158
"""
180159
Here's a Python solution to the problem:
@@ -384,21 +363,28 @@ def verify_code(response: str, test_cases, language):
384363
385364
This JavaScript implementation should work correctly for the given problem.
386365
"""
387-
]
366+
]*10
388367

389368
test_cases = [
390-
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}],
391-
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}],
392-
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}],
393-
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}]
394-
]
369+
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}]*10,
370+
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}]*10,
371+
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}]*10,
372+
[{"input": "3\n2 2 3\n4 3 7\n10 1 9\n", "output": "1\n6\n-1\n"}]*10
373+
]*10
395374

396-
languages = ["Python", "C++", "Rust", "Javascript"]
375+
languages = ["Python", "C++", "Rust", "Javascript"]*25
397376

377+
import time
378+
379+
start = time.time()
398380
for c, t, l in zip(code_samples, test_cases, languages):
399381
print(f"\n\n### Testing for {l} ###")
400382
score = verify_code(c, t, l)
401383
print(score)
384+
end = time.time()
402385

386+
print("time", end-start)
387+
388+
close_containers()
403389

404-
390+

src/genesys/verifier/verifier.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from typing import List, Union
22
from pydantic import BaseModel, Field
3-
from .code_test_verifier import CodeTestsVerification, verify_code
4-
from .math_verifier import MathGroundTruthVerification, verify_math
3+
from genesys.verifier.code_test_verifier import CodeTestsVerification, verify_code, init_containers, close_containers
4+
from genesys.verifier.math_verifier import MathGroundTruthVerification, verify_math
55

66
VerificationInfo = Union[MathGroundTruthVerification, CodeTestsVerification]
77

88
def verify(instructions: List[str], responses: List[str], verification_data: List[VerificationInfo]):
9+
10+
has_code_tests = any(verification.type == "code_tests" for verification in verification_data)
11+
if has_code_tests:
12+
init_containers()
13+
914
scores = []
1015
for instruction, response, verification in zip(instructions, responses, verification_data):
1116
if verification.type == "code_tests":
@@ -18,5 +23,8 @@ def verify(instructions: List[str], responses: List[str], verification_data: Lis
1823
raise ValueError(f"Unknown verification type: {verification}")
1924

2025
scores.append(result)
21-
26+
27+
if has_code_tests:
28+
close_containers()
29+
2230
return scores

0 commit comments

Comments
 (0)