Skip to content

Log notebooks with MLFlow #493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions jupyter_scheduler/executors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import io
import os
import shutil
Expand All @@ -7,13 +8,15 @@
from typing import Dict

import fsspec
import mlflow
import nbconvert
import nbformat
from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor

from jupyter_scheduler.models import DescribeJob, JobFeature, Status
from jupyter_scheduler.orm import Job, create_session
from jupyter_scheduler.parameterize import add_parameters
from jupyter_scheduler.scheduler import MLFLOW_SERVER_URI
from jupyter_scheduler.utils import get_utc_timestamp


Expand Down Expand Up @@ -143,6 +146,8 @@ def execute(self):
finally:
self.add_side_effects_files(staging_dir)
self.create_output_files(job, nb)
if getattr(job, "mlflow_logging", False):
self.log_to_mlflow(job, nb)

def add_side_effects_files(self, staging_dir: str):
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
Expand Down Expand Up @@ -173,6 +178,103 @@ def create_output_files(self, job: DescribeJob, notebook_node):
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
f.write(output)

def log_to_mlflow(self, job, nb):
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
with mlflow.start_run(run_id=job.mlflow_run_id):
if job.parameters:
mlflow.log_params(job.parameters)

for cell_idx, cell in enumerate(nb.cells):
if "tags" in cell.metadata:
if "mlflow_log" in cell.metadata["tags"]:
self.mlflow_log(cell, cell_idx)
elif "mlflow_log_input" in cell.metadata["tags"]:
self.mlflow_log_input(cell, cell_idx)
elif "mlflow_log_output" in cell.metadata["tags"]:
self.mlflow_log_output(cell, cell_idx)

for output_format in job.output_formats:
output_path = self.staging_paths[output_format]
directory, file_name_with_extension = os.path.split(output_path)
file_name, file_extension = os.path.splitext(file_name_with_extension)
file_name_parts = file_name.split("-")
file_name_without_timestamp = "-".join(file_name_parts[:-7])
file_name_final = f"{file_name_without_timestamp}{file_extension}"
new_output_path = os.path.join(directory, file_name_final)
shutil.copy(output_path, new_output_path)
timestamp = "-".join(file_name_parts[-7:]).split(".")[0]
mlflow.log_param("job_created", timestamp)
mlflow.log_artifact(new_output_path, "")
os.remove(new_output_path)

def mlflow_log(self, cell, cell_idx):
self.mlflow_log_input(cell, cell_idx)
self.mlflow_log_output(cell, cell_idx)

def mlflow_log_input(self, cell, cell_idx):
mlflow.log_text(cell.source, f"cell_{cell_idx}_input.txt")

def mlflow_log_output(self, cell, cell_idx):
if cell.cell_type == "code" and hasattr(cell, "outputs"):
self._log_code_output(cell_idx, cell.outputs)
elif cell.cell_type == "markdown":
self._log_markdown_output(cell, cell_idx)

def _log_code_output(self, cell_idx, outputs):
for output_idx, output in enumerate(outputs):
if output.output_type == "stream":
self._log_stream_output(cell_idx, output_idx, output)
elif hasattr(output, "data"):
for output_data_idx, output_data in enumerate(output.data):
if output_data == "text/plain":
mlflow.log_text(
output.data[output_data],
f"cell_{cell_idx}_output_{output_data_idx}.txt",
)
elif output_data == "text/html":
self._log_html_output(output, cell_idx, output_data_idx)
elif output_data == "application/pdf":
self._log_pdf_output(output, cell_idx, output_data_idx)
elif output_data.startswith("image"):
self._log_image_output(output, cell_idx, output_data_idx, output_data)

def _log_stream_output(self, cell_idx, output_idx, output):
mlflow.log_text("".join(output.text), f"cell_{cell_idx}_output_{output_idx}.txt")

def _log_html_output(self, output, cell_idx, output_idx):
if "text/html" in output.data:
html_content = output.data["text/html"]
if isinstance(html_content, list):
html_content = "".join(html_content)
mlflow.log_text(html_content, f"cell_{cell_idx}_output_{output_idx}.html")

def _log_pdf_output(self, output, cell_idx, output_idx):
pdf_data = base64.b64decode(output.data["application/pdf"].split(",")[1])
with open(f"cell_{cell_idx}_output_{output_idx}.pdf", "wb") as pdf_file:
pdf_file.write(pdf_data)
mlflow.log_artifact(f"cell_{cell_idx}_output_{output_idx}.pdf")

def _log_image_output(self, output, cell_idx, output_idx, mime_type):
image_data_str = output.data[mime_type]
if "," in image_data_str:
image_data_base64 = image_data_str.split(",")[1]
else:
image_data_base64 = image_data_str

try:
image_data = base64.b64decode(image_data_base64)
image_extension = mime_type.split("/")[1]
filename = f"cell_{cell_idx}_output_{output_idx}.{image_extension}"
with open(filename, "wb") as image_file:
image_file.write(image_data)
mlflow.log_artifact(filename)
os.remove(filename)
except Exception as e:
print(f"Error logging image output in cell {cell_idx}, output {output_idx}: {e}")

def _log_markdown_output(self, cell, cell_idx):
mlflow.log_text(cell.source, f"cell_{cell_idx}_output_0.md")

def supported_features(cls) -> Dict[JobFeature, bool]:
return {
JobFeature.job_name: True,
Expand Down
10 changes: 10 additions & 0 deletions jupyter_scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class CreateJob(BaseModel):
output_filename_template: Optional[str] = OUTPUT_FILENAME_TEMPLATE
compute_type: Optional[str] = None
package_input_folder: Optional[bool] = None
mlflow_logging: Optional[bool] = None
mlflow_experiment_id: Optional[str] = None
mlflow_run_id: Optional[str] = None

@root_validator
def compute_input_filename(cls, values) -> Dict:
Expand Down Expand Up @@ -148,6 +151,9 @@ class DescribeJob(BaseModel):
downloaded: bool = False
package_input_folder: Optional[bool] = None
packaged_files: Optional[List[str]] = []
mlflow_logging: Optional[bool] = None
mlflow_experiment_id: Optional[str] = None
mlflow_run_id: Optional[str] = None

class Config:
orm_mode = True
Expand Down Expand Up @@ -213,6 +219,8 @@ class CreateJobDefinition(BaseModel):
schedule: Optional[str] = None
timezone: Optional[str] = None
package_input_folder: Optional[bool] = None
mlflow_logging: Optional[bool] = None
mlflow_experiment_id: Optional[str] = None

@root_validator
def compute_input_filename(cls, values) -> Dict:
Expand Down Expand Up @@ -240,6 +248,8 @@ class DescribeJobDefinition(BaseModel):
active: bool
package_input_folder: Optional[bool] = None
packaged_files: Optional[List[str]] = []
mlflow_logging: Optional[bool] = None
mlflow_experiment_id: Optional[str] = None

class Config:
orm_mode = True
Expand Down
3 changes: 3 additions & 0 deletions jupyter_scheduler/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class CommonColumns:
# Any default values specified for new columns will be ignored during the migration process.
package_input_folder = Column(Boolean)
packaged_files = Column(JsonType, default=[])
mlflow_logging = Column(Boolean)
mlflow_experiment_id = Column(String(256), nullable=True)


class Job(CommonColumns, Base):
Expand All @@ -105,6 +107,7 @@ class Job(CommonColumns, Base):
idempotency_token = Column(String(256))
# All new columns added to this table must be nullable to ensure compatibility during database migrations.
# Any default values specified for new columns will be ignored during the migration process.
mlflow_run_id = Column(String(256), nullable=True)


class JobDefinition(CommonColumns, Base):
Expand Down
61 changes: 61 additions & 0 deletions jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
import os
import random
import shutil
import signal
import subprocess
import sys
from typing import Dict, List, Optional, Type, Union
from uuid import uuid4

import fsspec
import mlflow
import psutil
from jupyter_core.paths import jupyter_data_dir
from jupyter_server.transutils import _i18n
Expand Down Expand Up @@ -45,6 +50,10 @@
create_output_filename,
)

MLFLOW_SERVER_HOST = "127.0.0.1"
MLFLOW_SERVER_PORT = "5000"
MLFLOW_SERVER_URI = f"http://{MLFLOW_SERVER_HOST}:{MLFLOW_SERVER_PORT}"


class BaseScheduler(LoggingConfigurable):
"""Base class for schedulers. A default implementation
Expand Down Expand Up @@ -399,6 +408,33 @@ class Scheduler(BaseScheduler):

task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner")

def start_mlflow_server(self):
mlflow_process = subprocess.Popen(
[
"mlflow",
"server",
"--host",
MLFLOW_SERVER_HOST,
"--port",
MLFLOW_SERVER_PORT,
],
preexec_fn=os.setsid,
)
mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
return mlflow_process

def stop_mlflow_server(self):
if self.mlflow_process is not None:
os.killpg(os.getpgid(self.mlflow_process.pid), signal.SIGTERM)
self.mlflow_process.wait()
self.mlflow_process = None
print("MLFlow server stopped")

    def mlflow_signal_handler(self, signum, frame):
        """Signal handler (SIGINT/SIGTERM): stop the MLflow server, then exit.

        Registered in ``__init__`` so the child MLflow server does not outlive
        the Jupyter server process. ``signum``/``frame`` follow the standard
        ``signal.signal`` handler signature and are intentionally unused.
        """
        print("Shutting down MLFlow server")
        self.stop_mlflow_server()
        sys.exit(0)

def __init__(
self,
root_dir: str,
Expand All @@ -414,6 +450,10 @@ def __init__(
if self.task_runner_class:
self.task_runner = self.task_runner_class(scheduler=self, config=config)

self.mlflow_process = self.start_mlflow_server()
signal.signal(signal.SIGINT, self.mlflow_signal_handler)
signal.signal(signal.SIGTERM, self.mlflow_signal_handler)

@property
def db_session(self):
if not self._db_session:
Expand Down Expand Up @@ -462,6 +502,21 @@ def create_job(self, model: CreateJob) -> str:
if not model.output_formats:
model.output_formats = []

mlflow_client = mlflow.MlflowClient()

if model.job_definition_id and model.mlflow_experiment_id:
experiment_id = model.mlflow_experiment_id
else:
experiment_id = mlflow_client.create_experiment(f"{model.input_filename}-{uuid4()}")
model.mlflow_experiment_id = experiment_id
input_file_path = os.path.join(self.root_dir, model.input_uri)
mlflow.log_artifact(input_file_path, "input")

mlflow_run = mlflow_client.create_run(
experiment_id=experiment_id, run_name=f"{model.input_filename}-{uuid4()}"
)
model.mlflow_run_id = mlflow_run.info.run_id

job = Job(**model.dict(exclude_none=True, exclude={"input_uri"}))

session.add(job)
Expand Down Expand Up @@ -609,6 +664,12 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:
if not self.file_exists(model.input_uri):
raise InputUriError(model.input_uri)

mlflow_client = mlflow.MlflowClient()
experiment_id = mlflow_client.create_experiment(f"{model.input_filename}-{uuid4()}")
model.mlflow_experiment_id = experiment_id
input_file_path = os.path.join(self.root_dir, model.input_uri)
mlflow.log_artifact(input_file_path, "input")

job_definition = JobDefinition(**model.dict(exclude_none=True, exclude={"input_uri"}))
session.add(job_definition)
session.commit()
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ dependencies = [
"croniter~=1.4",
"pytz==2023.3",
"fsspec==2023.6.0",
"psutil~=5.9"
"psutil~=5.9",
"mlflow"
]

[project.optional-dependencies]
Expand Down
16 changes: 16 additions & 0 deletions src/components/mlflow-checkbox.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import React, { ChangeEvent } from 'react';

import { Checkbox, FormControlLabel, FormGroup } from '@mui/material';

export function MLFlowLoggingControl(props: {
onChange: (event: ChangeEvent<HTMLInputElement>) => void;
}): JSX.Element {
return (
<FormGroup>
<FormControlLabel
control={<Checkbox onChange={props.onChange} name={'mlflowLogging'} />}
label="Log with MLFlow"
/>
</FormGroup>
);
}
10 changes: 10 additions & 0 deletions src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ export namespace Scheduler {
schedule?: string;
timezone?: string;
package_input_folder?: boolean;
mlflow_logging?: boolean;
mlflow_experiment_id?: string;
}

export interface IUpdateJobDefinition {
Expand Down Expand Up @@ -391,6 +393,8 @@ export namespace Scheduler {
update_time: number;
active: boolean;
package_input_folder?: boolean;
mlflow_logging: boolean;
mlflow_experiment_id?: string;
}

export interface IEmailNotifications {
Expand Down Expand Up @@ -418,6 +422,9 @@ export namespace Scheduler {
output_formats?: string[];
compute_type?: string;
package_input_folder?: boolean;
mlflow_logging?: boolean;
mlflow_experiment_id?: string;
mlflow_run_id?: string;
}

export interface ICreateJobFromDefinition {
Expand Down Expand Up @@ -467,6 +474,9 @@ export namespace Scheduler {
end_time?: number;
downloaded: boolean;
package_input_folder?: boolean;
mlflow_logging?: boolean;
mlflow_experiment_id?: string;
mlflow_run_id?: string;
}

export interface ICreateJobResponse {
Expand Down
Loading
Loading