Skip to content

Commit

Permalink
graders: add types to CommunicationGrader
Browse files Browse the repository at this point in the history
  • Loading branch information
hieplpvip committed Dec 29, 2023
1 parent 7a18153 commit d3c8854
Showing 1 changed file with 42 additions and 20 deletions.
62 changes: 42 additions & 20 deletions dmoj/graders/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,32 @@
import subprocess
import tempfile
import uuid
from typing import List, TYPE_CHECKING

from dmoj.checkers import CheckerOutput
from dmoj.contrib import contrib_modules
from dmoj.cptbox import TracedPopen
from dmoj.cptbox.filesystem_policies import RecursiveDir
from dmoj.error import InternalError
from dmoj.executors import executors
from dmoj.executors.base_executor import BaseExecutor
from dmoj.graders.standard import StandardGrader
from dmoj.judgeenv import env, get_problem_root
from dmoj.problem import Problem, TestCase
from dmoj.result import Result
from dmoj.utils.helper_files import compile_with_auxiliary_files
from dmoj.utils.unicode import utf8bytes, utf8text

if TYPE_CHECKING:
from dmoj.judge import JudgeWorker


STDIN_FD_FLAGS = os.O_RDONLY
STDOUT_FD_FLAGS = os.O_WRONLY | os.O_TRUNC | os.O_CREAT
STDOUT_FD_MODE = stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH | stat.S_IWUSR


def merge_results(first_result, second_result):
def merge_results(first_result: Result, second_result: Result):
"""
Merge second_result into first_result and return first_result
Expand All @@ -41,7 +50,14 @@ def merge_results(first_result, second_result):


class CommunicationGrader(StandardGrader):
def __init__(self, judge, problem, language, source):
_fifo_dir: List[str]
_fifo_user_to_manager: List[str]
_fifo_manager_to_user: List[str]
_manager_proc: TracedPopen
_user_procs: List[TracedPopen]
_user_results: List[Result]

def __init__(self, judge: 'JudgeWorker', problem: Problem, language: str, source: bytes) -> None:
super().__init__(judge, problem, language, source)

self.handler_data = self.problem.config.communication
Expand All @@ -52,15 +68,16 @@ def __init__(self, judge, problem, language, source):
if self.contrib_type not in contrib_modules:
raise InternalError('%s is not a valid contrib module' % self.contrib_type)

self.num_processes = self.handler_data.get('num_processes', 1)
self.num_processes = int(self.handler_data.get('num_processes', 1))
if self.num_processes < 1:
raise InternalError('num_processes must be positive')

self.manager_binary = self._generate_manager_binary()

def populate_result(self, error, result, process):
def populate_result(self, error: bytes, result: Result, process: TracedPopen) -> None:
for i in range(self.num_processes):
_user_proc, _user_result = self._user_procs[i], self._user_results[i]
assert _user_proc.stderr is not None
self.binary.populate_result(_user_proc.stderr.read(), _user_result, _user_proc)
result = merge_results(result, _user_result)

Expand All @@ -70,7 +87,7 @@ def populate_result(self, error, result, process):
if result.execution_time > self.problem.time_limit:
result.result_flag |= Result.TLE

def check_result(self, case, result):
def check_result(self, case: TestCase, result: Result) -> CheckerOutput:
if (case.config['checker'] or 'standard') != 'standard':
return super().check_result(case, result)

Expand All @@ -89,7 +106,7 @@ def check_result(self, case, result):
stderr=self._manager_stderr,
)

def _launch_process(self, case):
def _launch_process(self, case: TestCase) -> None:
# Indices for the objects related to each user process
indices = range(self.num_processes)

Expand Down Expand Up @@ -126,28 +143,31 @@ def _launch_process(self, case):
)

# Create user processes
self._user_procs = [None for i in indices]
self._user_results = [Result(case) for i in indices]
self._user_procs = []
self._user_results = []
for i in indices:
# Setup std*** redirection
stdin_fd = os.open(self._fifo_manager_to_user[i], STDIN_FD_FLAGS)
stdout_fd = os.open(self._fifo_user_to_manager[i], STDOUT_FD_FLAGS, STDOUT_FD_MODE)

self._user_procs[i] = self.binary.launch(
time=self.problem.time_limit,
memory=self.problem.memory_limit,
symlinks=case.config.symlinks,
stdin=stdin_fd,
stdout=stdout_fd,
stderr=subprocess.PIPE,
wall_time=case.config.wall_time_factor * self.problem.time_limit,
self._user_procs.append(
self.binary.launch(
time=self.problem.time_limit,
memory=self.problem.memory_limit,
symlinks=case.config.symlinks,
stdin=stdin_fd,
stdout=stdout_fd,
stderr=subprocess.PIPE,
wall_time=case.config.wall_time_factor * self.problem.time_limit,
)
)
self._user_results.append(Result(case))

# Close file descriptors passed to the process
os.close(stdin_fd)
os.close(stdout_fd)

def _interact_with_process(self, case, result, input):
def _interact_with_process(self, case: TestCase, result: Result, input: bytes) -> bytes:
result.proc_output, self._manager_stderr = self._manager_proc.communicate(input)

self._manager_proc.wait()
Expand All @@ -160,7 +180,7 @@ def _interact_with_process(self, case, result, input):

return self._manager_stderr

def _generate_binary(self):
def _generate_binary(self) -> BaseExecutor:
if 'signature' not in self.problem.config.communication:
return super()._generate_binary()

Expand Down Expand Up @@ -202,13 +222,15 @@ def _generate_binary(self):
else:
raise InternalError('no valid runtime for signature grading %s found' % self.language)

def _generate_manager_binary(self):
def _generate_manager_binary(self) -> BaseExecutor:
files = self.handler_data.manager.files
if isinstance(files, str):
filenames = [files]
elif isinstance(files.unwrap(), list):
filenames = list(files.unwrap())
filenames = [os.path.join(get_problem_root(self.problem.id), f) for f in filenames]
problem_root = get_problem_root(self.problem.id)
assert problem_root is not None
filenames = [os.path.join(problem_root, f) for f in filenames]
flags = self.handler_data.manager.get('flags', [])
unbuffered = self.handler_data.manager.get('unbuffered', True)
lang = self.handler_data.manager.lang
Expand Down

0 comments on commit d3c8854

Please sign in to comment.