@@ -3,6 +3,7 @@
 import shutil
 import tarfile
 import traceback
+import multiprocessing as mp
 from abc import ABC, abstractmethod
 from functools import lru_cache
 from typing import Dict, List
@@ -15,7 +16,7 @@
 from prefect.futures import as_completed
 from prefect_dask.task_runners import DaskTaskRunner
 
-from jupyter_scheduler.models import DescribeJob, JobFeature, Status
+from jupyter_scheduler.models import CreateJob, DescribeJob, JobFeature, Status
 from jupyter_scheduler.orm import Job, Workflow, create_session
 from jupyter_scheduler.parameterize import add_parameters
 from jupyter_scheduler.utils import get_utc_timestamp

@@ -187,35 +188,58 @@ class DefaultExecutionManager(ExecutionManager):
     """Default execution manager that executes notebooks"""
 
-    @task(task_run_name="{task_id}")
-    def execute_task(self, task_id: str):
-        print(f"Task {task_id} executed")
-        return task_id
+    @task(task_run_name="{job.job_id}")
+    def execute_task(self, job: Job):
+
+        # Use a "spawn" multiprocessing context so new processes are not
+        # forked on Linux: `asyncio.get_event_loop()` is bugged in forked
+        # processes on Python versions below 3.12, and `jupyter_core` calls
+        # it when `nbconvert` runs in the default executor.
+        #
+        # See: https://github.com/python/cpython/issues/66285
+        # See also: https://github.com/jupyter/jupyter_core/pull/362
+        mp_ctx = mp.get_context("spawn")
+        p = mp_ctx.Process(
+            target=self.execution_manager_class(
+                job_id=job.job_id,
+                staging_paths=self.staging_paths,
+                root_dir=self.root_dir,
+                db_url=self.db_url,
+            ).process
+        )
+        p.start()
+
+        return job.job_id
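
The spawn-vs-fork point above can be seen in isolation. Below is a minimal, self-contained sketch of the same pattern, where `run_job` is a hypothetical stand-in for the execution manager's `process` method (not part of this change):

    import asyncio
    import multiprocessing as mp

    def run_job(job_id: str) -> None:
        # A spawned child starts a fresh interpreter, so asyncio state is
        # clean; a forked child on Python < 3.12 can inherit a broken loop.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(asyncio.sleep(0))
            print(f"executed {job_id}")
        finally:
            loop.close()

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")  # never fork, even on Linux
        p = ctx.Process(target=run_job, args=("job-123",))
        p.start()
        p.join()  # the change above fires and forgets; joining keeps the sketch tidy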
 
     @task
-    def get_task_data(self, task_ids: List[str] = []):
+    def get_task_data(self, task_ids: List[str]) -> List[Job]:
         # TODO: get orm objects from Task table of the db, create DescribeTask for each
-        tasks_data_obj = [
-            {"id": "task0", "dependsOn": ["task3"]},
-            {"id": "task4", "dependsOn": ["task0", "task1", "task2", "task3"]},
-            {"id": "task1", "dependsOn": []},
-            {"id": "task2", "dependsOn": ["task1"]},
-            {"id": "task3", "dependsOn": ["task1", "task2"]},
-        ]
+        # tasks_data_obj = [
+        #     {"id": "task0", "dependsOn": ["task3"]},
+        #     {"id": "task4", "dependsOn": ["task0", "task1", "task2", "task3"]},
+        #     {"id": "task1", "dependsOn": []},
+        #     {"id": "task2", "dependsOn": ["task1"]},
+        #     {"id": "task3", "dependsOn": ["task1", "task2"]},
+        # ]
+        tasks = []
+        with self.db_session() as session:
+            for task_id in task_ids:
+                job = session.query(Job).filter(Job.job_id == task_id).first()
+                tasks.append(job)
 
-        return tasks_data_obj
+        return tasks
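
The lookup above issues one query per task id. A single-query variant is possible with SQLAlchemy's `in_`; this is only a sketch of an alternative, assuming the same `Job` model and `db_session` factory as the diff, and it reorders the rows afterwards because IN does not preserve input order:

    from typing import List

    def get_task_data(self, task_ids: List[str]) -> List[Job]:
        with self.db_session() as session:
            # One round trip: fetch all matching rows at once.
            rows = session.query(Job).filter(Job.job_id.in_(task_ids)).all()
        # Reorder to match task_ids before returning.
        by_id = {job.job_id: job for job in rows}
        return [by_id[task_id] for task_id in task_ids]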
 
     @flow
     def execute_workflow(self):
 
-        tasks_info = self.get_task_data()
-        tasks = {task["id"]: task for task in tasks_info}
+        tasks_info: List[Job] = self.get_task_data(self.model.tasks)
+        tasks = {task.job_id: task for task in tasks_info}
 
         # create Prefect tasks, use caching to ensure Prefect tasks are created before wait_for is called on them
         @lru_cache(maxsize=None)
         def make_task(task_id):
-            deps = tasks[task_id]["dependsOn"]
+            deps = tasks[task_id].depends_on
             return self.execute_task.submit(
-                task_id, wait_for=[make_task(dep_id) for dep_id in deps]
+                tasks[task_id], wait_for=[make_task(dep_id) for dep_id in deps]
             )
 
         final_tasks = [make_task(task_id) for task_id in tasks]
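
The `lru_cache` in `execute_workflow` is what turns the flat dependency lists into a correctly wired Prefect graph: memoizing `make_task` guarantees each task is submitted exactly once, and the recursion submits every dependency before its dependent, so the futures passed to `wait_for` already exist. Without the cache, a task with several dependents (like "task1" in the commented-out sample data) would be submitted once per dependent. A standalone toy flow showing the same pattern, with hypothetical task ids and dependencies:

    from functools import lru_cache
    from prefect import flow, task

    @task
    def run(task_id: str) -> str:
        return task_id

    @flow
    def toy_workflow():
        # Toy dependency map: "d" waits on "b" and "c", which both wait on "a".
        deps = {"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"]}

        @lru_cache(maxsize=None)
        def make_task(task_id: str):
            # Recursion submits dependencies first; memoization ensures a
            # shared dependency (like "a") is submitted only once.
            return run.submit(task_id, wait_for=[make_task(d) for d in deps[task_id]])

        return [make_task(t) for t in deps]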