Skip to content

Commit

Permalink
summaries: enforce a minimum payload length
Browse files Browse the repository at this point in the history
  • Loading branch information
quitrk committed Jan 17, 2024
1 parent bdb41fd commit 3c645b1
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 14 deletions.
2 changes: 2 additions & 0 deletions skynet/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
# jobs: maximum job duration, overridable via JOB_TIMEOUT (10 minutes default)
job_timeout = int(os.getenv('JOB_TIMEOUT', 60 * 10))

# summaries: minimum payload text length, overridable via
# SUMMARY_MINIMUM_PAYLOAD_LENGTH; shorter payloads are not sent for inference
summary_minimum_payload_length = int(os.getenv('SUMMARY_MINIMUM_PAYLOAD_LENGTH', 100))

# monitoring: metrics are on unless ENABLE_METRICS is set to something other than 'true' (case-insensitive)
enable_metrics = os.getenv('ENABLE_METRICS', 'true').lower() == 'true'
26 changes: 15 additions & 11 deletions skynet/modules/ttt/summaries/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import uuid

from skynet.env import job_timeout, modules, redis_exp_seconds
from skynet.env import job_timeout, modules, redis_exp_seconds, summary_minimum_payload_length
from skynet.logs import get_logger
from skynet.modules.monitoring import (
SUMMARY_DURATION_METRIC,
Expand Down Expand Up @@ -106,26 +106,30 @@ async def run_job(job: Job) -> None:

SUMMARY_TIME_IN_QUEUE_METRIC.observe(start - job.created)

log.info(f"Job {job.id} created at {job.created}")
log.info(f"Job {job.id} started at {start}")
log.info(f"Job queue time: {start - job.created} seconds")

await update_job(job_id=job.id, start=start, status=JobStatus.RUNNING, worker_id=worker_id)

# add to running jobs list if not already there (which may occur on multiple worker disconnects while running the same job)
if job.id not in await db.lrange(RUNNING_JOBS_KEY, 0, -1):
await db.rpush(RUNNING_JOBS_KEY, job.id)

exit_task = asyncio.create_task(exit_on_timeout())
if len(job.payload.text) < summary_minimum_payload_length:
log.warning(f"Job {job.id} failed because payload is too short: \"{job.payload.text}\"")

result = job.payload.text
else:
exit_task = asyncio.create_task(exit_on_timeout())

try:
result = await process(job)
except Exception as e:
log.warning(f"Job {job.id} failed: {e}")
try:
result = await process(job)
except Exception as e:
log.warning(f"Job {job.id} failed: {e}")

has_failed = True
result = str(e)
has_failed = True
result = str(e)

exit_task.cancel()
exit_task.cancel()

updated_job = await update_job(
expires=redis_exp_seconds if not has_failed else None,
Expand Down
49 changes: 46 additions & 3 deletions skynet/modules/ttt/summaries/jobs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def default_session_fixture() -> Iterator[None]:

class TestCreateJob:
@pytest.mark.asyncio
async def test_runs_job(self, mocker):
'''Test that a job is run.'''
async def test_creates_run_job(self, mocker):
'''Test that a job run task is created.'''

mocker.patch('skynet.modules.ttt.summaries.jobs.create_run_job_task'),
mocker.patch('skynet.modules.ttt.summaries.jobs.create_run_job_task')

from skynet.modules.ttt.summaries.jobs import create_job, create_run_job_task

Expand All @@ -33,6 +33,7 @@ async def test_runs_job(self, mocker):
async def test_queues_job(self, mocker):
'''Test that a job is queued and queue size metric is updated.'''

mocker.patch('skynet.modules.monitoring.SUMMARY_DURATION_METRIC.observe')
mocker.patch('skynet.modules.ttt.summaries.jobs.can_run_next_job', return_value=False)
mocker.patch('skynet.modules.ttt.summaries.jobs.update_summary_queue_metric')

Expand All @@ -44,6 +45,48 @@ async def test_queues_job(self, mocker):
update_summary_queue_metric.assert_called_once()


@pytest.fixture()
def run_job_fixture(mocker):
    '''Patch the collaborators of run_job so it can be exercised in isolation.

    Replaces the duration metric observer, update_job, process and the
    underlying db handle in the summaries jobs module with mocks, then
    yields a marker string.
    '''

    patch_targets = (
        'skynet.modules.ttt.summaries.jobs.SUMMARY_DURATION_METRIC.observe',
        'skynet.modules.ttt.summaries.jobs.update_job',
        'skynet.modules.ttt.summaries.jobs.process',
        'skynet.modules.ttt.summaries.jobs.db.db',
    )

    # patch in the same order as listed above
    for target in patch_targets:
        mocker.patch(target)

    yield 'run_job_fixture'


class TestRunJob:
    '''Checks which payloads run_job forwards to process for inference.'''

    @pytest.mark.asyncio
    async def test_does_not_run_job(self, run_job_fixture):
        '''Test that a job with a short payload is not sent for inference.'''

        from skynet.modules.ttt.summaries.jobs import process, run_job

        short_text = "Hello. It’s me . . . Where are you?"
        short_job = Job(payload=DocumentPayload(text=short_text), type=JobType.SUMMARY, id='job_id')

        await run_job(short_job)

        process.assert_not_called()

    @pytest.mark.asyncio
    async def test_run_job(self, run_job_fixture):
        '''Test that a job with a long enough payload is sent for inference.'''

        from skynet.modules.ttt.summaries.jobs import process, run_job

        long_text = "Andrew: Hello. Beatrix: Honey? It’s me . . . Andrew: Where are you? Beatrix: At the station. I missed my train."
        long_job = Job(
            payload=DocumentPayload(text=long_text),
            type=JobType.SUMMARY,
            id='job_id',
        )

        await run_job(long_job)

        process.assert_called_once()


class TestCanRunNextJob:
def test_returns_true_if_executor_enabled(self, mocker):
'''Test that it returns true if executor module is enabled.'''
Expand Down

0 comments on commit 3c645b1

Please sign in to comment.