Skip to content

Commit

Permalink
Support optional instance volumes (#2260)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvstme authored Feb 4, 2025
1 parent 8aada88 commit b625d3a
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 20 deletions.
23 changes: 21 additions & 2 deletions docs/docs/concepts/volumes.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,27 @@ volumes:
Since persistence isn't guaranteed (instances may be interrupted or runs may occur on different instances), use instance
volumes only for caching or with directories manually mounted to network storage.

> Instance volumes are currently supported for all backends except `runpod`, `vastai` and `kubernetes`,
> and can also be used with [SSH fleets](fleets.md#ssh).
!!! info "Backends"
Instance volumes are currently supported for all backends except `runpod`, `vastai` and `kubernetes`, and can also be used with [SSH fleets](fleets.md#ssh).

??? info "Optional volumes"
If the volume is not critical for your workload, you can mark it as `optional`.

<div editor-title=".dstack.yml">

```yaml
type: task

volumes:
- instance_path: /dstack-cache
path: /root/.cache/
optional: true
```

Configurations with optional volumes can run on any backend; the volume is mounted
only if the selected backend supports instance volumes.

</div>

### Use instance volumes for caching

Expand Down
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/models/volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ def parse(cls, v: str) -> Self:
class InstanceMountPoint(CoreModel):
instance_path: Annotated[str, Field(description="The absolute path on the instance (host)")]
path: Annotated[str, Field(description="The absolute path in the container")]
optional: Annotated[
bool,
Field(
description=(
"Allow running without this volume"
" in backends that do not support instance volumes"
),
),
] = False

_validate_instance_path = validator("instance_path", allow_reuse=True)(
_validate_mount_point_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
)
from dstack._internal.server.services.runs import (
check_can_attach_run_volumes,
check_run_spec_has_instance_mounts,
check_run_spec_requires_instance_mounts,
get_offer_volumes,
get_run_volume_models,
get_run_volumes,
Expand Down Expand Up @@ -418,7 +418,7 @@ async def _run_job_on_new_instance(
master_job_provisioning_data=master_job_provisioning_data,
volumes=volumes,
privileged=job.job_spec.privileged,
instance_mounts=check_run_spec_has_instance_mounts(run.run_spec),
instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec),
)
# Limit number of offers tried to prevent long-running processing
# in case all offers fail.
Expand Down
7 changes: 4 additions & 3 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ async def get_plan(
multinode=jobs[0].job_spec.jobs_per_replica > 1,
volumes=volumes,
privileged=jobs[0].job_spec.privileged,
instance_mounts=check_run_spec_has_instance_mounts(run_spec),
instance_mounts=check_run_spec_requires_instance_mounts(run_spec),
)

job_plans = []
Expand Down Expand Up @@ -897,9 +897,10 @@ def get_offer_mount_point_volume(
raise ServerClientError("Failed to find an eligible volume for the mount point")


def check_run_spec_requires_instance_mounts(run_spec: RunSpec) -> bool:
    """Return True if the run spec declares at least one required instance mount.

    Mount points marked ``optional`` do not count: a run with only optional
    instance volumes may still be scheduled on backends that do not support
    instance volumes.
    """
    for mount_point in run_spec.configuration.volumes:
        if is_core_model_instance(mount_point, InstanceMountPoint) and not mount_point.optional:
            return True
    return False


Expand Down
26 changes: 16 additions & 10 deletions src/dstack/api/server/_runs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Optional, Union
from typing import Any, List, Optional, Union
from uuid import UUID

from pydantic import parse_obj_as
Expand All @@ -19,6 +19,7 @@
RunPlan,
RunSpec,
)
from dstack._internal.core.models.volumes import InstanceMountPoint
from dstack._internal.server.schemas.runs import (
ApplyRunPlanRequest,
CreateInstanceRequest,
Expand Down Expand Up @@ -122,42 +123,47 @@ def create_instance(

def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[dict]:
spec_excludes: dict[str, set[str]] = {}
configuration_excludes: set[str] = set()
configuration_excludes: dict[str, Any] = {}
profile_excludes: set[str] = set()
configuration = run_spec.configuration
profile = run_spec.profile

# client >= 0.18.18 / server <= 0.18.17 compatibility tweak
if not configuration.privileged:
configuration_excludes.add("privileged")
configuration_excludes["privileged"] = True
# client >= 0.18.23 / server <= 0.18.22 compatibility tweak
if configuration.type == "service" and configuration.gateway is None:
configuration_excludes.add("gateway")
configuration_excludes["gateway"] = True
# client >= 0.18.30 / server <= 0.18.29 compatibility tweak
if run_spec.configuration.user is None:
configuration_excludes.add("user")
configuration_excludes["user"] = True
# client >= 0.18.30 / server <= 0.18.29 compatibility tweak
if configuration.reservation is None:
configuration_excludes.add("reservation")
configuration_excludes["reservation"] = True
if profile is not None and profile.reservation is None:
profile_excludes.add("reservation")
if configuration.idle_duration is None:
configuration_excludes.add("idle_duration")
configuration_excludes["idle_duration"] = True
if profile is not None and profile.idle_duration is None:
profile_excludes.add("idle_duration")
# client >= 0.18.38 / server <= 0.18.37 compatibility tweak
if configuration.stop_duration is None:
configuration_excludes.add("stop_duration")
configuration_excludes["stop_duration"] = True
if profile is not None and profile.stop_duration is None:
profile_excludes.add("stop_duration")
# client >= 0.18.40 / server <= 0.18.39 compatibility tweak
if (
is_core_model_instance(configuration, ServiceConfiguration)
and configuration.strip_prefix == STRIP_PREFIX_DEFAULT
):
configuration_excludes.add("strip_prefix")
configuration_excludes["strip_prefix"] = True
if configuration.single_branch is None:
configuration_excludes.add("single_branch")
configuration_excludes["single_branch"] = True
if all(
not is_core_model_instance(v, InstanceMountPoint) or not v.optional
for v in configuration.volumes
):
configuration_excludes["volumes"] = {"__all__": {"optional"}}

if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,78 @@ async def test_fails_job_when_instance_mounts_and_no_offers_with_create_instance
await session.refresh(pool)
assert not pool.instances

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_provisions_job_with_optional_instance_volume_not_attached(
    self,
    test_db,
    session: AsyncSession,
):
    """A job whose only volume is an *optional* instance mount provisions on a
    backend that does not support instance volumes (RunPod), with the volume
    simply left unmounted."""
    # Arrange: minimal project/user/pool/repo scaffolding for a run.
    project = await create_project(session=session)
    user = await create_user(session=session)
    pool = await create_pool(session=session, project=project)
    repo = await create_repo(
        session=session,
        project_id=project.id,
    )
    run_spec = get_run_spec(run_name="test-run", repo_id=repo.name)
    # The single volume is optional, so the run must not be restricted to
    # backends that support instance mounts.
    run_spec.configuration.volumes = [
        InstanceMountPoint(instance_path="/root/.cache", path="/cache", optional=True)
    ]
    run = await create_run(
        session=session,
        project=project,
        repo=repo,
        user=user,
        run_name="test-run",
        run_spec=run_spec,
    )
    job = await create_job(
        session=session,
        run=run,
        instance_assigned=True,
    )
    # RUNPOD is one of the backends without instance-volume support
    # (see docs/docs/concepts/volumes.md), which is the point of this test.
    offer = InstanceOfferWithAvailability(
        backend=BackendType.RUNPOD,
        instance=InstanceType(
            name="instance",
            resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]),
        ),
        region="us",
        price=1.0,
        availability=InstanceAvailability.AVAILABLE,
    )
    # Mock the backend so the only available offer is the RunPod one and
    # run_job reports successful provisioning.
    with patch("dstack._internal.server.services.backends.get_project_backends") as m:
        backend_mock = Mock()
        m.return_value = [backend_mock]
        backend_mock.TYPE = BackendType.RUNPOD
        backend_mock.compute.return_value.get_offers_cached.return_value = [offer]
        backend_mock.compute.return_value.run_job.return_value = JobProvisioningData(
            backend=offer.backend,
            instance_type=offer.instance,
            instance_id="instance_id",
            hostname="1.1.1.1",
            internal_ip=None,
            region=offer.region,
            price=offer.price,
            username="ubuntu",
            ssh_port=22,
            ssh_proxy=None,
            dockerized=False,
            backend_data=None,
        )
        await process_submitted_jobs()

    # The job must have been provisioned rather than failed for lack of
    # instance-mount-capable offers.
    await session.refresh(job)
    assert job is not None
    assert job.status == JobStatus.PROVISIONING

    # The pool instance must record the RunPod offer and the same
    # provisioning data as the job.
    await session.refresh(pool)
    instance_offer = InstanceOfferWithAvailability.parse_raw(pool.instances[0].offer)
    assert offer == instance_offer
    pool_job_provisioning_data = pool.instances[0].job_provisioning_data
    assert pool_job_provisioning_data == job.job_provisioning_data

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession):
Expand Down Expand Up @@ -412,7 +484,8 @@ async def test_assigns_job_to_instance_with_volumes(self, test_db, session: Asyn
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name)
run_spec.configuration.volumes = [
VolumeMountPoint(name=volume.name, path="/volume"),
InstanceMountPoint(instance_path="/root/.cache", path="/cache"),
InstanceMountPoint(instance_path="/root/.data", path="/data"),
InstanceMountPoint(instance_path="/root/.cache", path="/cache", optional=True),
]
run = await create_run(
session=session,
Expand Down
8 changes: 6 additions & 2 deletions src/tests/_internal/server/services/runner/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def test_submit(self, client: ShimClient, adapter: requests_mock.Adapter):
"device_name": "/dev/sdv",
}
],
"instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}],
"instance_mounts": [
{"instance_path": "/mnt/nfs/home", "path": "/home", "optional": False}
],
}
self.assert_request(adapter, 0, "POST", "/api/submit", expected_request)

Expand Down Expand Up @@ -341,7 +343,9 @@ def test_submit_task(self, client: ShimClient, adapter: requests_mock.Adapter):
}
],
"volume_mounts": [{"name": "vol", "path": "/vol"}],
"instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}],
"instance_mounts": [
{"instance_path": "/mnt/nfs/home", "path": "/home", "optional": False}
],
"host_ssh_user": "dstack",
"host_ssh_keys": ["host_key"],
"container_ssh_keys": ["project_key", "user_key"],
Expand Down

0 comments on commit b625d3a

Please sign in to comment.