Auto-download files from the staging directory to output #500

Draft · wants to merge 15 commits into main

Changes from all commits
7 changes: 5 additions & 2 deletions conftest.py
@@ -5,7 +5,7 @@

from jupyter_scheduler.orm import create_session, create_tables
from jupyter_scheduler.scheduler import Scheduler
from jupyter_scheduler.tests.mocks import MockEnvironmentManager
from jupyter_scheduler.tests.mocks import MockDownloadManager, MockEnvironmentManager

pytest_plugins = ("jupyter_server.pytest_plugin",)

@@ -48,5 +48,8 @@ def jp_scheduler_db():
@pytest.fixture
def jp_scheduler():
return Scheduler(
db_url=DB_URL, root_dir=str(TEST_ROOT_DIR), environments_manager=MockEnvironmentManager()
db_url=DB_URL,
root_dir=str(TEST_ROOT_DIR),
environments_manager=MockEnvironmentManager(),
download_manager=MockDownloadManager(DB_URL),
)
84 changes: 84 additions & 0 deletions jupyter_scheduler/download_manager.py
@@ -0,0 +1,84 @@
from multiprocessing import Queue
from typing import List, Optional

from jupyter_scheduler.models import DescribeDownload
from jupyter_scheduler.orm import Download, create_session, generate_uuid
from jupyter_scheduler.pydantic_v1 import BaseModel
from jupyter_scheduler.utils import get_utc_timestamp


def initiate_download_standalone(
job_id: str, download_queue: Queue, db_session, redownload: bool = False
):
"""
Initiates a download independently of the DownloadManager instance. This is suitable for use in multiprocessing environments, where a direct reference to a DownloadManager instance is not feasible.
"""
download_initiated_time = get_utc_timestamp()
download_id = generate_uuid()
download = DescribeDownload(
job_id=job_id,
download_id=download_id,
download_initiated_time=download_initiated_time,
redownload=redownload,
)
download_record = Download(**download.dict())
db_session.add(download_record)
db_session.commit()
download_queue.put(download)
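
A minimal usage sketch (hypothetical, mirroring how `executors.py` below calls this helper from a separately spawned worker process):

```python
# Hypothetical sketch: record a job's download from a worker process without
# holding a DownloadManager reference. `queue` is the multiprocessing.Queue
# shared with the parent; `create_session` is the session factory from
# jupyter_scheduler.orm.
from multiprocessing import Queue

from jupyter_scheduler.download_manager import initiate_download_standalone
from jupyter_scheduler.orm import create_session


def record_job_download(job_id: str, queue: Queue, db_url: str) -> None:
    session_factory = create_session(db_url)
    with session_factory() as session:
        initiate_download_standalone(
            job_id=job_id, download_queue=queue, db_session=session, redownload=True
        )
```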


class DownloadRecordManager:
def __init__(self, db_url):
self.session = create_session(db_url)

def put(self, download: DescribeDownload):
with self.session() as session:
download = Download(**download.dict())
session.add(download)
session.commit()

def get(self, download_id: str) -> Optional[DescribeDownload]:
with self.session() as session:
download = session.query(Download).filter(Download.download_id == download_id).first()

if download:
return DescribeDownload.from_orm(download)
else:
return None

def get_downloads(self) -> List[DescribeDownload]:
with self.session() as session:
downloads = session.query(Download).order_by(Download.download_initiated_time).all()
return [DescribeDownload.from_orm(download) for download in downloads]

def delete_download(self, download_id: str):
with self.session() as session:
session.query(Download).filter(Download.download_id == download_id).delete()
session.commit()

def delete_job_downloads(self, job_id: str):
with self.session() as session:
session.query(Download).filter(Download.job_id == job_id).delete()
session.commit()


class DownloadManager:
Collaborator:

Can we merge this class with DownloadRecordManager above? I don't see the benefit of splitting the logic here into two separate classes if they are only used together anyways.
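
A rough sketch of what that merge might look like (hypothetical; method bodies are kept as in this PR):

```python
# Hypothetical sketch of the suggestion above: one class owning both the
# DB-backed download records and the in-memory queue.
from multiprocessing import Queue
from typing import List, Optional

from jupyter_scheduler.models import DescribeDownload
from jupyter_scheduler.orm import Download, create_session


class DownloadManager:
    def __init__(self, db_url: str):
        self.session = create_session(db_url)
        self.queue = Queue()

    def initiate_download(self, job_id: str, redownload: bool):
        with self.session() as session:
            initiate_download_standalone(
                job_id=job_id,
                download_queue=self.queue,
                db_session=session,
                redownload=redownload,
            )

    def get_downloads(self) -> List[DescribeDownload]:
        with self.session() as session:
            downloads = session.query(Download).order_by(Download.download_initiated_time).all()
            return [DescribeDownload.from_orm(download) for download in downloads]

    # get(), delete_download(), delete_job_downloads(), and populate_queue()
    # would move over unchanged, using self.session in place of
    # self.record_manager.session.
```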

def __init__(self, db_url: str):
self.record_manager = DownloadRecordManager(db_url=db_url)
self.queue = Queue()

def initiate_download(self, job_id: str, redownload: bool):
with self.record_manager.session() as session:
initiate_download_standalone(
job_id=job_id, download_queue=self.queue, db_session=session, redownload=redownload
)

def delete_download(self, download_id: str):
self.record_manager.delete_download(download_id)

def delete_job_downloads(self, job_id: str):
self.record_manager.delete_job_downloads(job_id)

def populate_queue(self):
downloads = self.record_manager.get_downloads()
for download in downloads:
self.queue.put(download)
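
`populate_queue` re-enqueues download records that survived a server restart; a minimal startup sketch (hypothetical `db_url`, mirroring the wiring in `extension.py` below):

```python
# Hypothetical startup sketch: rebuild the in-memory queue from the DB so
# downloads initiated before a restart are not lost.
from jupyter_scheduler.download_manager import DownloadManager

manager = DownloadManager(db_url="sqlite:///scheduler.sqlite")  # hypothetical URL
manager.populate_queue()  # re-enqueue persisted download records
```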
57 changes: 57 additions & 0 deletions jupyter_scheduler/download_runner.py
@@ -0,0 +1,57 @@
import asyncio

import traitlets
from jupyter_server.transutils import _i18n
from traitlets.config import LoggingConfigurable

from jupyter_scheduler.download_manager import DownloadManager
from jupyter_scheduler.job_files_manager import JobFilesManager


class BaseDownloadRunner(LoggingConfigurable):
"""Base download runner, this class's start method is called
at the start of jupyter server, and is responsible for
polling for downloads to download.
"""

def __init__(self, config=None, **kwargs):
super().__init__(config=config)

downloads_poll_interval = traitlets.Integer(
default_value=3,
config=True,
help=_i18n(
"The interval in seconds that the download runner polls for downloads to download."
),
)

def start(self):
raise NotImplementedError("Must be implemented by subclass")


class DownloadRunner(BaseDownloadRunner):
"""Default download runner that maintains a record and a queue of initiated downloads , and polls the queue every `poll_interval` seconds
for downloads to download.
"""

def __init__(
self, download_manager: DownloadManager, job_files_manager: JobFilesManager, config=None
):
super().__init__(config=config)
self.download_manager = download_manager
self.job_files_manager = job_files_manager

async def process_download_queue(self):
while not self.download_manager.queue.empty():
download = self.download_manager.queue.get()
download_record = self.download_manager.record_manager.get(download.download_id)
if not download_record:
continue
await self.job_files_manager.copy_from_staging(download.job_id, download.redownload)
self.download_manager.delete_download(download.download_id)
Comment on lines +44 to +51 (Collaborator):

I think we can avoid using multiprocessing.Queue if we are already writing pending downloads to a DB. Can we read directly from self.download_manager.record_manager instead?

I believe that this may fix the process bug previously raised on the E2E tests in this branch. This is the corresponding error message:

RuntimeError: A SemLock created in a fork context is being shared with a process in a spawn context. This is not supported. Please use the same context to create multiprocessing objects and Process.

If we remove the need for multiprocessing objects, we may be able to fix this bug without relying on multiprocessing.set_start_method(). Can you give this a try?
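
A minimal sketch of that alternative (assuming `get_downloads()` yields pending records in initiation order):

```python
# Hypothetical sketch of the suggestion above: poll the DB-backed record
# manager directly, removing the multiprocessing.Queue from the runner.
# This would replace DownloadRunner.process_download_queue.
async def process_download_queue(self):
    for download in self.download_manager.record_manager.get_downloads():
        await self.job_files_manager.copy_from_staging(download.job_id, download.redownload)
        self.download_manager.delete_download(download.download_id)
```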


async def start(self):
self.download_manager.populate_queue()
while True:
await self.process_download_queue()
await asyncio.sleep(self.downloads_poll_interval)
18 changes: 17 additions & 1 deletion jupyter_scheduler/executors.py
@@ -11,6 +11,7 @@
import nbformat
from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor

from jupyter_scheduler.download_manager import initiate_download_standalone
from jupyter_scheduler.models import DescribeJob, JobFeature, Status
from jupyter_scheduler.orm import Job, create_session
from jupyter_scheduler.parameterize import add_parameters
@@ -29,11 +30,19 @@ class ExecutionManager(ABC):
_model = None
_db_session = None

def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]):
def __init__(
self,
job_id: str,
root_dir: str,
db_url: str,
staging_paths: Dict[str, str],
download_queue,
):
self.job_id = job_id
self.staging_paths = staging_paths
self.root_dir = root_dir
self.db_url = db_url
self.download_queue = download_queue

@property
def model(self):
@@ -143,6 +152,13 @@ def execute(self):
finally:
self.add_side_effects_files(staging_dir)
self.create_output_files(job, nb)
with self.db_session() as session:
initiate_download_standalone(
job_id=job.job_id,
download_queue=self.download_queue,
db_session=session,
redownload=True,
)

def add_side_effects_files(self, staging_dir: str):
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
24 changes: 24 additions & 0 deletions jupyter_scheduler/extension.py
@@ -1,10 +1,13 @@
import asyncio
import multiprocessing

from jupyter_core.paths import jupyter_data_dir
from jupyter_server.extension.application import ExtensionApp
from jupyter_server.transutils import _i18n
from traitlets import Bool, Type, Unicode, default

from jupyter_scheduler.download_manager import DownloadManager
from jupyter_scheduler.download_runner import DownloadRunner
from jupyter_scheduler.orm import create_tables

from .handlers import (
@@ -67,27 +70,48 @@ def _db_url_default(self):
)

def initialize_settings(self):
# Forces new processes to not be forked on Linux.
# This is necessary because `asyncio.get_event_loop()` is bugged in
# forked processes in Python versions below 3.12. That function is
# called by `jupyter_core`, which is used by `nbconvert` in the default executor.

# See: https://github.com/python/cpython/issues/66285
# See also: https://github.com/jupyter/jupyter_core/pull/362
multiprocessing.set_start_method("spawn", force=True)

Comment on lines +73 to +81 (Collaborator):

The problem with this is that this line affects the multiprocessing behavior globally for everything running on this main thread, i.e. the server and all server extensions running on a JupyterLab instance. We should really avoid doing this just to pass our GitHub workflows. Consider what happens if:

  • Another extension is also calling multiprocessing.set_start_method(..., force=True), or
  • Some part of the server / server extension breaks when the start method is changed during its lifetime by our extension's initialize_settings() method.

I don't have a solution for how this bug can be fixed. However, the error message is pretty specific about why an exception is being raised, so my intuition is that this bug can be fixed. I'm leaving some references here for us to review in the future.
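
For reference, one direction that keeps the spawn behavior process-local (roughly what `scheduler.py` did before this PR) is to create both the queue and the process from the same explicit context:

```python
# Hypothetical sketch: use an explicit spawn context instead of
# multiprocessing.set_start_method(..., force=True). The SemLock error above
# comes from mixing contexts, so creating the Queue and the Process from the
# same context avoids it without mutating global multiprocessing state.
import multiprocessing as mp


def run_job():
    ...  # stands in for ExecutionManager(...).process


if __name__ == "__main__":
    mp_ctx = mp.get_context("spawn")
    queue = mp_ctx.Queue()  # created in the same (spawn) context as the process
    p = mp_ctx.Process(target=run_job)
    p.start()
    p.join()
```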

super().initialize_settings()

create_tables(self.db_url, self.drop_tables)

environments_manager = self.environment_manager_class()

download_manager = DownloadManager(db_url=self.db_url)

scheduler = self.scheduler_class(
root_dir=self.serverapp.root_dir,
environments_manager=environments_manager,
db_url=self.db_url,
download_manager=download_manager,
config=self.config,
)

job_files_manager = self.job_files_manager_class(scheduler=scheduler)

download_runner = DownloadRunner(
download_manager=download_manager, job_files_manager=job_files_manager
)

self.settings.update(
environments_manager=environments_manager,
scheduler=scheduler,
job_files_manager=job_files_manager,
initiate_download=download_manager.initiate_download,
)

if scheduler.task_runner:
loop = asyncio.get_event_loop()
loop.create_task(scheduler.task_runner.start())

if download_runner:
loop = asyncio.get_event_loop()
loop.create_task(download_runner.start())
12 changes: 6 additions & 6 deletions jupyter_scheduler/handlers.py
@@ -395,20 +395,20 @@ def get(self):


class FilesDownloadHandler(ExtensionHandlerMixin, APIHandler):
_job_files_manager = None
_initiate_download = None

@property
def job_files_manager(self):
if not self._job_files_manager:
self._job_files_manager = self.settings.get("job_files_manager", None)
def initiate_download(self):
if not self._initiate_download:
self._initiate_download = self.settings.get("initiate_download", None)

return self._job_files_manager
return self._initiate_download

@authenticated
async def get(self, job_id):
redownload = self.get_query_argument("redownload", False)
try:
await self.job_files_manager.copy_from_staging(job_id=job_id, redownload=redownload)
self.initiate_download(job_id, redownload)
except Exception as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e
13 changes: 13 additions & 0 deletions jupyter_scheduler/models.py
@@ -295,3 +295,16 @@ class JobFeature(str, Enum):
output_filename_template = "output_filename_template"
stop_job = "stop_job"
delete_job = "delete_job"


class DescribeDownload(BaseModel):
job_id: str
download_id: str
download_initiated_time: int
redownload: bool

class Config:
orm_mode = True

def __str__(self) -> str:
return self.json()
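
With `orm_mode` enabled, this model hydrates directly from the SQLAlchemy `Download` row via `from_orm`; a small construction sketch (hypothetical IDs):

```python
# Hypothetical sketch: constructing the model directly. __str__ delegates to
# .json(), so printing yields the JSON payload.
from jupyter_scheduler.models import DescribeDownload
from jupyter_scheduler.utils import get_utc_timestamp

download = DescribeDownload(
    job_id="7c1f0ad2",       # hypothetical job ID
    download_id="1b2d3f4a",  # hypothetical download ID
    download_initiated_time=get_utc_timestamp(),
    redownload=False,
)
print(download)  # JSON representation via .json()
```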
9 changes: 8 additions & 1 deletion jupyter_scheduler/orm.py
@@ -1,5 +1,4 @@
import json
import os
from sqlite3 import OperationalError
from uuid import uuid4

@@ -112,6 +111,14 @@ class JobDefinition(CommonColumns, Base):
active = Column(Boolean, default=True)


class Download(Base):
__tablename__ = "downloads"
job_id = Column(String(36), primary_key=True)
download_id = Column(String(36), primary_key=True)
download_initiated_time = Column(Integer)
redownload = Column(Boolean, default=False)


def create_tables(db_url, drop_tables=False):
engine = create_engine(db_url)
try:
17 changes: 7 additions & 10 deletions jupyter_scheduler/scheduler.py
@@ -1,7 +1,7 @@
import multiprocessing as mp
import os
import random
import shutil
from multiprocessing import Process
from typing import Dict, List, Optional, Type, Union

import fsspec
@@ -15,6 +15,7 @@
from traitlets import Unicode, default
from traitlets.config import LoggingConfigurable

from jupyter_scheduler.download_manager import DownloadManager
from jupyter_scheduler.environments import EnvironmentManager
from jupyter_scheduler.exceptions import (
IdempotencyTokenError,
@@ -404,6 +405,7 @@ def __init__(
root_dir: str,
environments_manager: Type[EnvironmentManager],
db_url: str,
download_manager: DownloadManager,
config=None,
**kwargs,
):
@@ -413,6 +415,7 @@
self.db_url = db_url
if self.task_runner_class:
self.task_runner = self.task_runner_class(scheduler=self, config=config)
self.download_manager = download_manager

@property
def db_session(self):
@@ -478,20 +481,13 @@ def create_job(self, model: CreateJob) -> str:
else:
self.copy_input_file(model.input_uri, staging_paths["input"])

# The MP context forces new processes to not be forked on Linux.
# This is necessary because `asyncio.get_event_loop()` is bugged in
# forked processes in Python versions below 3.12. This method is
# called by `jupyter_core` by `nbconvert` in the default executor.
#
# See: https://github.com/python/cpython/issues/66285
# See also: https://github.com/jupyter/jupyter_core/pull/362
mp_ctx = mp.get_context("spawn")
p = mp_ctx.Process(
p = Process(
target=self.execution_manager_class(
job_id=job.job_id,
staging_paths=staging_paths,
root_dir=self.root_dir,
db_url=self.db_url,
download_queue=self.download_manager.queue,
).process
)
p.start()
@@ -583,6 +579,7 @@ def delete_job(self, job_id: str):

session.query(Job).filter(Job.job_id == job_id).delete()
session.commit()
self.download_manager.delete_job_downloads(job_id)

def stop_job(self, job_id):
with self.db_session() as session: