Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement GPU blocks property #2253

Merged
merged 8 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/dstack/_internal/cli/utils/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def get_fleets_table(
resources = instance.instance_type.resources.pretty_format(include_spot=True)

status = instance.status.value
total_blocks = instance.total_blocks
busy_blocks = instance.busy_blocks
if (
total_blocks is not None
and total_blocks > 1
and total_blocks > busy_blocks
and instance.status == InstanceStatus.BUSY
):
# 1/4 BUSY => 3/4 IDLE
idle_blocks = total_blocks - busy_blocks
status = f"{idle_blocks}/{total_blocks} {InstanceStatus.IDLE.value}"
if (
instance.status in [InstanceStatus.IDLE, InstanceStatus.BUSY]
and instance.unreachable
Expand Down
19 changes: 15 additions & 4 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ def th(s: str) -> str:
InstanceAvailability.BUSY,
}:
availability = offer.availability.value.replace("_", " ").lower()
instance = offer.instance.name
if offer.total_blocks > 1:
instance += f" ({offer.blocks}/{offer.total_blocks})"
offers.add_row(
f"{i}",
offer.backend.replace("remote", "ssh"),
offer.region,
offer.instance.name,
instance,
r.pretty_format(),
"yes" if r.spot else "no",
f"${offer.price:g}",
Expand Down Expand Up @@ -161,13 +164,21 @@ def get_runs_table(
"SUBMITTED": format_date(job.job_submissions[-1].submitted_at),
"ERROR": _get_job_error(job),
}
jpd = job.job_submissions[-1].job_provisioning_data
latest_job_submission = job.job_submissions[-1]
jpd = latest_job_submission.job_provisioning_data
if jpd is not None:
resources = jpd.instance_type.resources
instance = jpd.instance_type.name
jrd = latest_job_submission.job_runtime_data
if jrd is not None and jrd.offer is not None:
resources = jrd.offer.instance.resources
if jrd.offer.total_blocks > 1:
instance += f" ({jrd.offer.blocks}/{jrd.offer.total_blocks})"
job_row.update(
{
"BACKEND": f"{jpd.backend.value.replace('remote', 'ssh')} ({jpd.region})",
"INSTANCE": jpd.instance_type.name,
"RESOURCES": jpd.instance_type.resources.pretty_format(include_spot=True),
"INSTANCE": instance,
"RESOURCES": resources.pretty_format(include_spot=True),
"RESERVATION": jpd.reservation,
"PRICE": f"${jpd.price:.4}",
}
Expand Down
26 changes: 26 additions & 0 deletions src/dstack/_internal/core/models/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ class SSHHostParams(CoreModel):
] = None
ssh_key: Optional[SSHKey] = None

blocks: Annotated[
Union[Literal["auto"], int],
Field(
description=(
"The amount of blocks to split the instance into, a number or `auto`."
" `auto` means as many as possible."
" The number of GPUs and CPUs must be divisible by the number of blocks."
" Defaults to `1`, i.e. do not split"
),
ge=1,
),
] = 1

@validator("internal_ip")
def validate_internal_ip(cls, value):
if value is None:
Expand Down Expand Up @@ -142,6 +155,19 @@ class InstanceGroupParams(CoreModel):
Field(description="The resources requirements"),
] = ResourcesSpec()

blocks: Annotated[
Union[Literal["auto"], int],
Field(
description=(
"The amount of blocks to split the instance into, a number or `auto`."
" `auto` means as many as possible."
" The number of GPUs and CPUs must be divisible by the number of blocks."
" Defaults to `1`, i.e. do not split"
),
ge=1,
),
] = 1

backends: Annotated[
Optional[List[BackendType]],
Field(description="The backends to consider for provisioning (e.g., `[aws, gcp]`)"),
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ class InstanceOffer(CoreModel):
# An `InstanceOffer` extended with availability, runtime, and block-split fields.
class InstanceOfferWithAvailability(InstanceOffer):
availability: InstanceAvailability
instance_runtime: InstanceRuntime = InstanceRuntime.SHIM
# `blocks` / `total_blocks` both default to 1. The CLI tables append
# "(blocks/total_blocks)" to the instance name only when total_blocks > 1.
# NOTE(review): presumably `blocks` is the share covered by this offer and
# `total_blocks` is how many blocks the whole instance is split into — confirm.
blocks: int = 1
total_blocks: int = 1


class InstanceStatus(str, Enum):
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/models/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ class Instance(CoreModel):
fleet_name: Optional[str] = None
instance_num: int
pool_name: Optional[str] = None
job_name: Optional[str] = None
job_name: Optional[str] = None # deprecated, always None (instance can have more than one job)
hostname: Optional[str] = None
status: InstanceStatus
unreachable: bool = False
termination_reason: Optional[str] = None
created: datetime.datetime
region: Optional[str] = None
price: Optional[float] = None
total_blocks: Optional[int] = None
busy_blocks: int = 0


class PoolInstances(CoreModel):
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,17 @@ def get_base_backend(self) -> BackendType:


class JobRuntimeData(CoreModel):
"""
Holds various information only available after the job is submitted, such as:
* offer (depends on the instance)
* volumes used by the job
* resource constraints for container (depend on the instance)
* port mapping (reported by the shim only after the container is started)

Some fields are mutable; for example, `ports` is only available after the shim starts
the container.
"""

network_mode: NetworkMode
# GPU, CPU, memory resource shares. None means all available (no limit)
gpu: Optional[int] = None
Expand All @@ -239,6 +250,10 @@ class JobRuntimeData(CoreModel):
# None if data is not yet available (on vm-based backends and ssh instances)
# or not applicable (container-based backends)
ports: Optional[dict[int, int]] = None
# List of volumes used by the job
volume_names: Optional[list[str]] = None # None for backward compatibility
# Virtual shared offer
offer: Optional[InstanceOfferWithAvailability] = None # None for backward compatibility


class ClusterInfo(CoreModel):
Expand Down
37 changes: 29 additions & 8 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
get_create_instance_offers,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.offers import is_divisible_into_blocks
from dstack._internal.server.services.placement import (
get_fleet_placement_groups,
placement_group_model_to_placement_group,
Expand Down Expand Up @@ -133,7 +134,7 @@ async def _process_next_instance():
),
InstanceModel.id.not_in(lockset),
)
.options(lazyload(InstanceModel.job))
.options(lazyload(InstanceModel.jobs))
.order_by(InstanceModel.last_processed_at.asc())
.limit(1)
.with_for_update(skip_locked=True)
Expand All @@ -156,15 +157,15 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
select(InstanceModel)
.where(InstanceModel.id == instance.id)
.options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
.options(joinedload(InstanceModel.job))
.options(joinedload(InstanceModel.jobs))
.options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
.execution_options(populate_existing=True)
)
instance = res.unique().scalar_one()
if (
instance.status == InstanceStatus.IDLE
and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
and instance.job_id is None
and not instance.jobs
):
await _mark_terminating_if_idle_duration_expired(instance)
if instance.status == InstanceStatus.PENDING:
Expand Down Expand Up @@ -322,6 +323,26 @@ async def _add_remote(instance: InstanceModel) -> None:
)
return

divisible, blocks = is_divisible_into_blocks(
cpu_count=instance_type.resources.cpus,
gpu_count=len(instance_type.resources.gpus),
blocks="auto" if instance.total_blocks is None else instance.total_blocks,
)
if divisible:
instance.total_blocks = blocks
else:
instance.status = InstanceStatus.TERMINATED
instance.termination_reason = "Cannot split into blocks"
logger.warning(
"Failed to add instance %s: cannot split into blocks",
instance.name,
extra={
"instance_name": instance.name,
"instance_status": InstanceStatus.TERMINATED.value,
},
)
return

region = instance.region
jpd = JobProvisioningData(
backend=BackendType.REMOTE,
Expand Down Expand Up @@ -479,6 +500,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
requirements=requirements,
exclude_not_available=True,
fleet_model=instance.fleet,
blocks="auto" if instance.total_blocks is None else instance.total_blocks,
)

if not offers and should_retry:
Expand Down Expand Up @@ -554,6 +576,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
instance.instance_configuration = instance_configuration.json()
instance.job_provisioning_data = job_provisioning_data.json()
instance.offer = instance_offer.json()
instance.total_blocks = instance_offer.total_blocks
instance.started_at = get_current_datetime()
instance.last_retry_at = get_current_datetime()

Expand Down Expand Up @@ -585,8 +608,8 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
async def _check_instance(instance: InstanceModel) -> None:
if (
instance.status == InstanceStatus.BUSY
and instance.job is not None
and instance.job.status.is_finished()
and instance.jobs
and all(job.status.is_finished() for job in instance.jobs)
):
# A busy instance could have no active jobs due to this bug: https://github.com/dstackai/dstack/issues/2068
instance.status = InstanceStatus.TERMINATING
Expand Down Expand Up @@ -648,9 +671,7 @@ async def _check_instance(instance: InstanceModel) -> None:
instance.unreachable = False

if instance.status == InstanceStatus.PROVISIONING:
instance.status = (
InstanceStatus.IDLE if instance.job_id is None else InstanceStatus.BUSY
)
instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
logger.info(
"Instance %s has switched to %s status",
instance.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def _process_next_running_job():
.limit(1)
.with_for_update(skip_locked=True)
)
job_model = res.scalar()
job_model = res.unique().scalar()
if job_model is None:
return
lockset.add(job_model.id)
Expand All @@ -102,7 +102,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
.options(joinedload(JobModel.instance))
.execution_options(populate_existing=True)
)
job_model = res.scalar_one()
job_model = res.unique().scalar_one()
res = await session.execute(
select(RunModel)
.where(RunModel.id == job_model.run_id)
Expand Down
Loading