Skip to content

Commit

Permalink
Implement GPU blocks property (#2253)
Browse files Browse the repository at this point in the history
* Add `blocks` property to cloud and SSH fleet
* Reverse instance->job relation, switch from 1-to-1 to many-to-1 (many
  jobs to one instance)
* Add InstanceModel.{total_blocks,busy_blocks} fields
* Generate virtual shared offers (fraction of a real offer) on the fly
* Update CLI and API to display shared resources and instances
* Keep track of volumes used by jobs

Part-of: #1780
Closes: #1780
  • Loading branch information
un-def authored Feb 6, 2025
1 parent d03ac25 commit c1206c5
Show file tree
Hide file tree
Showing 29 changed files with 1,018 additions and 183 deletions.
11 changes: 11 additions & 0 deletions src/dstack/_internal/cli/utils/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def get_fleets_table(
resources = instance.instance_type.resources.pretty_format(include_spot=True)

status = instance.status.value
total_blocks = instance.total_blocks
busy_blocks = instance.busy_blocks
if (
total_blocks is not None
and total_blocks > 1
and total_blocks > busy_blocks
and instance.status == InstanceStatus.BUSY
):
# 1/4 BUSY => 3/4 IDLE
idle_blocks = total_blocks - busy_blocks
status = f"{idle_blocks}/{total_blocks} {InstanceStatus.IDLE.value}"
if (
instance.status in [InstanceStatus.IDLE, InstanceStatus.BUSY]
and instance.unreachable
Expand Down
19 changes: 15 additions & 4 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ def th(s: str) -> str:
InstanceAvailability.BUSY,
}:
availability = offer.availability.value.replace("_", " ").lower()
instance = offer.instance.name
if offer.total_blocks > 1:
instance += f" ({offer.blocks}/{offer.total_blocks})"
offers.add_row(
f"{i}",
offer.backend.replace("remote", "ssh"),
offer.region,
offer.instance.name,
instance,
r.pretty_format(),
"yes" if r.spot else "no",
f"${offer.price:g}",
Expand Down Expand Up @@ -161,13 +164,21 @@ def get_runs_table(
"SUBMITTED": format_date(job.job_submissions[-1].submitted_at),
"ERROR": _get_job_error(job),
}
jpd = job.job_submissions[-1].job_provisioning_data
latest_job_submission = job.job_submissions[-1]
jpd = latest_job_submission.job_provisioning_data
if jpd is not None:
resources = jpd.instance_type.resources
instance = jpd.instance_type.name
jrd = latest_job_submission.job_runtime_data
if jrd is not None and jrd.offer is not None:
resources = jrd.offer.instance.resources
if jrd.offer.total_blocks > 1:
instance += f" ({jrd.offer.blocks}/{jrd.offer.total_blocks})"
job_row.update(
{
"BACKEND": f"{jpd.backend.value.replace('remote', 'ssh')} ({jpd.region})",
"INSTANCE": jpd.instance_type.name,
"RESOURCES": jpd.instance_type.resources.pretty_format(include_spot=True),
"INSTANCE": instance,
"RESOURCES": resources.pretty_format(include_spot=True),
"RESERVATION": jpd.reservation,
"PRICE": f"${jpd.price:.4}",
}
Expand Down
26 changes: 26 additions & 0 deletions src/dstack/_internal/core/models/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ class SSHHostParams(CoreModel):
] = None
ssh_key: Optional[SSHKey] = None

blocks: Annotated[
Union[Literal["auto"], int],
Field(
description=(
"The amount of blocks to split the instance into, a number or `auto`."
" `auto` means as many as possible."
" The number of GPUs and CPUs must be divisible by the number of blocks."
" Defaults to `1`, i.e. do not split"
),
ge=1,
),
] = 1

@validator("internal_ip")
def validate_internal_ip(cls, value):
if value is None:
Expand Down Expand Up @@ -142,6 +155,19 @@ class InstanceGroupParams(CoreModel):
Field(description="The resources requirements"),
] = ResourcesSpec()

blocks: Annotated[
Union[Literal["auto"], int],
Field(
description=(
"The amount of blocks to split the instance into, a number or `auto`."
" `auto` means as many as possible."
" The number of GPUs and CPUs must be divisible by the number of blocks."
" Defaults to `1`, i.e. do not split"
),
ge=1,
),
] = 1

backends: Annotated[
Optional[List[BackendType]],
Field(description="The backends to consider for provisioning (e.g., `[aws, gcp]`)"),
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ class InstanceOffer(CoreModel):
class InstanceOfferWithAvailability(InstanceOffer):
# An instance offer extended with availability, runtime, and block-split fields.
availability: InstanceAvailability
instance_runtime: InstanceRuntime = InstanceRuntime.SHIM
# `blocks` out of `total_blocks`: the CLI renders the pair as "(blocks/total_blocks)"
# next to the instance name when total_blocks > 1. Both default to 1.
# NOTE(review): presumably 1/1 means the instance is offered whole (not split) — confirm.
blocks: int = 1
total_blocks: int = 1


class InstanceStatus(str, Enum):
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/models/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ class Instance(CoreModel):
fleet_name: Optional[str] = None
instance_num: int
pool_name: Optional[str] = None
job_name: Optional[str] = None
job_name: Optional[str] = None # deprecated, always None (instance can have more than one job)
hostname: Optional[str] = None
status: InstanceStatus
unreachable: bool = False
termination_reason: Optional[str] = None
created: datetime.datetime
region: Optional[str] = None
price: Optional[float] = None
total_blocks: Optional[int] = None
busy_blocks: int = 0


class PoolInstances(CoreModel):
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ def get_base_backend(self) -> BackendType:


class JobRuntimeData(CoreModel):
"""
Holds various information only available after the job is submitted, such as:
* offer (depends on the instance)
* volumes used by the job
* resource constraints for container (depend on the instance)
* port mapping (reported by the shim only after the container is started)
Some fields are mutable; for example, `ports` only becomes available once the shim starts
the container.
"""

network_mode: NetworkMode
# GPU, CPU, memory resource shares. None means all available (no limit)
gpu: Optional[int] = None
Expand All @@ -240,6 +251,10 @@ class JobRuntimeData(CoreModel):
# None if data is not yet available (on vm-based backends and ssh instances)
# or not applicable (container-based backends)
ports: Optional[dict[int, int]] = None
# List of volumes used by the job
volume_names: Optional[list[str]] = None # None for backward compatibility
# Virtual shared offer
offer: Optional[InstanceOfferWithAvailability] = None # None for backward compatibility


class ClusterInfo(CoreModel):
Expand Down
37 changes: 29 additions & 8 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
get_create_instance_offers,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.offers import is_divisible_into_blocks
from dstack._internal.server.services.placement import (
get_fleet_placement_groups,
placement_group_model_to_placement_group,
Expand Down Expand Up @@ -133,7 +134,7 @@ async def _process_next_instance():
),
InstanceModel.id.not_in(lockset),
)
.options(lazyload(InstanceModel.job))
.options(lazyload(InstanceModel.jobs))
.order_by(InstanceModel.last_processed_at.asc())
.limit(1)
.with_for_update(skip_locked=True)
Expand All @@ -156,15 +157,15 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
select(InstanceModel)
.where(InstanceModel.id == instance.id)
.options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
.options(joinedload(InstanceModel.job))
.options(joinedload(InstanceModel.jobs))
.options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
.execution_options(populate_existing=True)
)
instance = res.unique().scalar_one()
if (
instance.status == InstanceStatus.IDLE
and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
and instance.job_id is None
and not instance.jobs
):
await _mark_terminating_if_idle_duration_expired(instance)
if instance.status == InstanceStatus.PENDING:
Expand Down Expand Up @@ -322,6 +323,26 @@ async def _add_remote(instance: InstanceModel) -> None:
)
return

divisible, blocks = is_divisible_into_blocks(
cpu_count=instance_type.resources.cpus,
gpu_count=len(instance_type.resources.gpus),
blocks="auto" if instance.total_blocks is None else instance.total_blocks,
)
if divisible:
instance.total_blocks = blocks
else:
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Cannot split into blocks"
logger.warning(
"Failed to add instance %s: cannot split into blocks",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

region = instance.region
jpd = JobProvisioningData(
backend=BackendType.REMOTE,
Expand Down Expand Up @@ -479,6 +500,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
requirements=requirements,
exclude_not_available=True,
fleet_model=instance.fleet,
blocks="auto" if instance.total_blocks is None else instance.total_blocks,
)

if not offers and should_retry:
Expand Down Expand Up @@ -554,6 +576,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
instance.instance_configuration = instance_configuration.json()
instance.job_provisioning_data = job_provisioning_data.json()
instance.offer = instance_offer.json()
instance.total_blocks = instance_offer.total_blocks
instance.started_at = get_current_datetime()
instance.last_retry_at = get_current_datetime()

Expand Down Expand Up @@ -585,8 +608,8 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
async def _check_instance(instance: InstanceModel) -> None:
if (
instance.status == InstanceStatus.BUSY
and instance.job is not None
and instance.job.status.is_finished()
and instance.jobs
and all(job.status.is_finished() for job in instance.jobs)
):
# A busy instance could have no active jobs due to this bug: https://github.com/dstackai/dstack/issues/2068
instance.status = InstanceStatus.TERMINATING
Expand Down Expand Up @@ -648,9 +671,7 @@ async def _check_instance(instance: InstanceModel) -> None:
instance.unreachable = False

if instance.status == InstanceStatus.PROVISIONING:
instance.status = (
InstanceStatus.IDLE if instance.job_id is None else InstanceStatus.BUSY
)
instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
logger.info(
"Instance %s has switched to %s status",
instance.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def _process_next_running_job():
.limit(1)
.with_for_update(skip_locked=True)
)
job_model = res.scalar()
job_model = res.unique().scalar()
if job_model is None:
return
lockset.add(job_model.id)
Expand All @@ -102,7 +102,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
.options(joinedload(JobModel.instance))
.execution_options(populate_existing=True)
)
job_model = res.scalar_one()
job_model = res.unique().scalar_one()
res = await session.execute(
select(RunModel)
.where(RunModel.id == job_model.run_id)
Expand Down
Loading

0 comments on commit c1206c5

Please sign in to comment.