Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow running services without a gateway #1869

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from dstack._internal.core.models.repos.virtual import VirtualRepo
from dstack._internal.core.models.resources import Range, ResourcesSpec
from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeMountPoint
from dstack._internal.settings import FeatureFlags

CommandsList = List[str]
# Port numbers are 16-bit unsigned values, so the highest valid port is 65535.
# Port 0 is excluded (`gt=0`) because it cannot be used as a concrete port.
ValidPort = conint(gt=0, le=65535)
# Default value of the `https` property of service configurations.
SERVICE_HTTPS_DEFAULT = True


class RunConfigurationType(str, Enum):
Expand Down Expand Up @@ -203,7 +205,14 @@ class ServiceConfigurationParams(CoreModel):
Optional[AnyModel],
Field(description="Mapping of the model for the OpenAI-compatible endpoint"),
] = None
https: Annotated[bool, Field(description="Enable HTTPS")] = True
https: Annotated[
bool,
Field(
description="Enable HTTPS"
if not FeatureFlags.PROXY
else "Enable HTTPS if running with a gateway"
),
] = SERVICE_HTTPS_DEFAULT
auth: Annotated[bool, Field(description="Enable the authorization")] = True
replicas: Annotated[
Union[conint(ge=1), constr(regex=r"^[0-9]+..[1-9][0-9]*$"), Range[int]],
Expand Down
17 changes: 15 additions & 2 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from urllib.parse import urlparse

from pydantic import UUID4, Field, root_validator
from typing_extensions import Annotated
Expand Down Expand Up @@ -29,7 +30,7 @@
from dstack._internal.core.models.repos import AnyRunRepoData
from dstack._internal.core.models.resources import ResourcesSpec
from dstack._internal.utils import common as common_utils
from dstack._internal.utils.common import format_pretty_duration
from dstack._internal.utils.common import concat_url_path, format_pretty_duration


class AppSpec(CoreModel):
Expand Down Expand Up @@ -304,10 +305,22 @@ class ServiceModelSpec(CoreModel):


class ServiceSpec(CoreModel):
url: str
url: Annotated[str, Field(description="Full URL or path relative to dstack-server's base URL")]
model: Optional[ServiceModelSpec] = None
options: Dict[str, Any] = {}

def full_url(self, server_base_url: str) -> str:
    """Return the absolute URL of the service.

    If `self.url` already has both a scheme and a network location, it is
    returned unchanged. Otherwise `self.url` is treated as relative to
    `server_base_url`: the scheme (falling back to "http") and the network
    location are taken from the server URL, and the server path is joined
    with the service path.
    """
    parsed_service = urlparse(self.url)
    is_absolute = bool(parsed_service.scheme and parsed_service.netloc)
    if is_absolute:
        return self.url

    parsed_server = urlparse(server_base_url)
    scheme = parsed_server.scheme or "http"
    joined_path = concat_url_path(parsed_server.path, parsed_service.path)
    absolute = parsed_service._replace(
        scheme=scheme,
        netloc=parsed_server.netloc,
        path=joined_path,
    )
    return absolute.geturl()


class RunStatus(str, Enum):
PENDING = "pending"
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/core/services/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, List, Optional

from dstack._internal.core.models.runs import AppSpec
from dstack._internal.utils.common import concat_url_path


class URLReplacer:
Expand All @@ -12,12 +13,14 @@ def __init__(
ports: Dict[int, int],
hostname: str,
secure: bool,
path_prefix: str = "",
ip_address: Optional[str] = None,
):
self.app_specs = {app_spec.port: app_spec for app_spec in app_specs}
self.ports = ports
self.hostname = hostname
self.secure = secure
self.path_prefix = path_prefix.encode()

hosts = ["localhost", "0.0.0.0", "127.0.0.1"]
if ip_address and ip_address not in hosts:
Expand All @@ -43,6 +46,7 @@ def _replace_url(self, match: re.Match) -> bytes:
url = url._replace(
scheme=("https" if self.secure else "http").encode(),
netloc=(self.hostname if omit_port else f"{self.hostname}:{local_port}").encode(),
path=concat_url_path(self.path_prefix, url.path),
query=urllib.parse.urlencode(qs).encode(),
)
return url.geturl()
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ class ProjectModel(BaseModel):
default_gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column(
ForeignKey("gateways.id", use_alter=True, ondelete="SET NULL"), nullable=True
)
default_gateway: Mapped["GatewayModel"] = relationship(foreign_keys=[default_gateway_id])
default_gateway: Mapped[Optional["GatewayModel"]] = relationship(
foreign_keys=[default_gateway_id]
)

default_pool_id: Mapped[Optional[UUIDType]] = mapped_column(
ForeignKey("pools.id", use_alter=True, ondelete="SET NULL"), nullable=True
Expand Down
69 changes: 47 additions & 22 deletions src/dstack/_internal/server/services/gateways/__init__.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The services/gateways directory needs rearranging - with the new terms, only some of this logic is gateway-specific, while some is proxy-specific. Will do in another PR.

Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
SSHError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT
from dstack._internal.core.models.gateways import (
Gateway,
GatewayComputeConfiguration,
Expand Down Expand Up @@ -68,6 +69,7 @@
gather_map_async,
run_async,
)
from dstack._internal.settings import FeatureFlags
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -352,20 +354,23 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->


async def register_service(session: AsyncSession, run_model: RunModel):
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)

# TODO(egor-s): allow to configure gateway name
gateway_name: Optional[str] = None
if gateway_name is None:
gateway = run_model.project.default_gateway
if gateway is None:
raise ResourceNotExistsError("Default gateway is not set")
gateway = run_model.project.default_gateway

if gateway is not None:
service_spec = await _register_service_in_gateway(session, run_model, gateway)
run_model.gateway = gateway
elif FeatureFlags.PROXY:
service_spec = _register_service_in_server(run_model)
else:
gateway = await get_project_gateway_model_by_name(
session=session, project=run_model.project, name=gateway_name
)
if gateway is None:
raise ResourceNotExistsError("Gateway does not exist")
raise ResourceNotExistsError("Default gateway is not set")
run_model.service_spec = service_spec.json()


async def _register_service_in_gateway(
session: AsyncSession, run_model: RunModel, gateway: GatewayModel
) -> ServiceSpec:
run_spec: RunSpec = RunSpec.__response__.parse_raw(run_model.run_spec)

if gateway.gateway_compute is None:
raise ServerClientError("Gateway has no instance associated with it")

Expand Down Expand Up @@ -396,9 +401,6 @@ async def register_service(session: AsyncSession, run_model: RunModel):
)
service_spec.options = get_service_options(run_spec.configuration)

run_model.gateway = gateway
run_model.service_spec = service_spec.json()

conn = await get_or_add_gateway_connection(session, gateway.id)
try:
logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url)
Expand All @@ -420,10 +422,36 @@ async def register_service(session: AsyncSession, run_model: RunModel):
logger.debug("Gateway request failed", exc_info=True)
raise GatewayError(f"Gateway is not working: {e!r}")

return service_spec


def _register_service_in_server(run_model: RunModel) -> ServiceSpec:
    """Build the service spec for a service that runs without a gateway.

    The run spec is parsed from `run_model.run_spec` and checked for
    properties that are rejected in this mode.

    Raises:
        ServerClientError: if `https` differs from its default, if `model`
            is set, or if `replicas` is not a fixed value (min != max).

    Returns:
        A `ServiceSpec` whose `url` is a path built from the project name
        and the run name.
    """
    run_spec: RunSpec = RunSpec.__response__.parse_raw(run_model.run_spec)
    configuration = run_spec.configuration

    https_was_changed = configuration.https != SERVICE_HTTPS_DEFAULT
    if https_was_changed:
        # Note: if the user sets `https: <default-value>`, it will be ignored silently
        # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted
        raise ServerClientError(
            "The `https` configuration property is not applicable when running services without a gateway. "
            "Please configure a gateway or remove the `https` property from the service configuration"
        )

    if configuration.model is not None:
        raise ServerClientError(
            "Model mappings are not yet supported when running services without a gateway. "
            "Please configure a gateway or remove the `model` property from the service configuration"
        )

    replicas = configuration.replicas
    if replicas.min != replicas.max:
        raise ServerClientError(
            "Auto-scaling is not yet supported when running services without a gateway. "
            "Please configure a gateway or set `replicas` to a fixed value in the service configuration"
        )

    service_path = f"/services/{run_model.project.name}/{run_model.run_name}/"
    return ServiceSpec(url=service_path)


async def register_replica(
session: AsyncSession, gateway_id: uuid.UUID, run: Run, job_model: JobModel
session: AsyncSession, gateway_id: Optional[uuid.UUID], run: Run, job_model: JobModel
):
if gateway_id is None: # in-server proxy
return
conn = await get_or_add_gateway_connection(session, gateway_id)
job_submission = jobs_services.job_model_to_job_submission(job_model)
try:
Expand All @@ -440,10 +468,7 @@ async def register_replica(


async def unregister_service(session: AsyncSession, run_model: RunModel):
if run_model.gateway_id is None:
logger.error(
"Failed to unregister service. run_model.gateway_id is None for %s", run_model.run_name
)
if run_model.gateway_id is None: # in-server proxy
return
conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
res = await session.execute(
Expand Down Expand Up @@ -474,7 +499,7 @@ async def unregister_replica(session: AsyncSession, job_model: JobModel):
)
run_model = res.unique().scalar_one()
if run_model.gateway_id is None:
# The run is not a service
# not a service or served by in-server proxy
return

conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobProvisioningData, JobStatus, RunSpec
from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Replica, Service
from dstack._internal.server.models import JobModel, ProjectModel
from dstack._internal.server.models import JobModel, ProjectModel, RunModel


class DBProxyRepo(BaseProxyRepo):
Expand All @@ -27,8 +27,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
res = await self.session.execute(
select(JobModel)
.join(JobModel.project)
.join(JobModel.run)
.where(
ProjectModel.name == project_name,
RunModel.gateway_id.is_(None),
JobModel.run_name == run_name,
JobModel.status == JobStatus.RUNNING,
JobModel.job_num == 0,
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ async def process_terminating_run(session: AsyncSession, run: RunModel):
job.last_processed_at = common_utils.get_current_datetime()

if unfinished_jobs_count == 0:
if run.gateway_id is not None:
if run.service_spec is not None:
try:
await gateways.unregister_service(session, run)
except Exception as e:
Expand Down
20 changes: 20 additions & 0 deletions src/dstack/_internal/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,23 @@ def get_or_error(v: Optional[T]) -> T:
def batched(seq: Iterable[T], n: int) -> Iterable[List[T]]:
it = iter(seq)
return iter(lambda: list(itertools.islice(it, n)), [])


StrT = TypeVar("StrT", str, bytes)


def lstrip_one(string: StrT, substring: StrT) -> StrT:
    """Drop a single leading `substring` from `string`, if there is one.

    `string` is returned unchanged when `substring` is empty or when
    `string` does not start with it. Only one occurrence is removed even
    if the prefix repeats.
    """
    if not substring or not string.startswith(substring):
        return string
    prefix_len = len(substring)
    return string[prefix_len:]


def rstrip_one(string: StrT, substring: StrT) -> StrT:
    """Drop a single trailing `substring` from `string`, if there is one.

    `string` is returned unchanged when `substring` is empty or when
    `string` does not end with it. Only one occurrence is removed even
    if the suffix repeats.
    """
    if substring and string.endswith(substring):
        # `endswith` guarantees len(string) >= len(substring), so the
        # computed cut position is never negative.
        cut = len(string) - len(substring)
        return string[:cut]
    return string


def concat_url_path(a: StrT, b: StrT) -> StrT:
    """Join two URL path fragments with a `/` separator.

    If `b` is empty, `a` is returned unchanged. Otherwise at most one
    trailing slash of `a` and at most one leading slash of `b` are dropped
    and the two parts are joined with a single `/`. Works for both `str`
    and `bytes`; the separator type is chosen based on the type of `a`.
    """
    if b:
        slash = b"/" if not isinstance(a, str) else "/"
        head = rstrip_one(a, slash)
        tail = lstrip_one(b, slash)
        return head + slash + tail
    return a
15 changes: 9 additions & 6 deletions src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def hostname(self) -> str:
def service_url(self) -> str:
if self._run.run_spec.configuration.type != "service":
raise ValueError("The run is not a service")
return self._run.service.url
return self._run.service.full_url(server_base_url=self._api_client.base_url)

def _attached_logs(
self,
Expand All @@ -124,25 +124,28 @@ def ws_thread():
)
threading.Thread(target=ws_thread).start()

ports = self.ports
hostname = "127.0.0.1"
secure = False
ports = self.ports
path_prefix = ""
if self._run.service is not None:
url = urlparse(self._run.service.url)
url = urlparse(self.service_url)
hostname = url.hostname
secure = url.scheme == "https"
service_port = url.port
if service_port is None:
service_port = 443 if self._run.run_spec.configuration.https else 80
service_port = 443 if secure else 80
ports = {
**ports,
self._run.run_spec.configuration.port.container_port: service_port,
}
hostname = url.hostname
secure = url.scheme == "https"
path_prefix = url.path
replace_urls = URLReplacer(
ports=ports,
app_specs=self._run.jobs[0].job_spec.app_specs,
hostname=hostname,
secure=secure,
path_prefix=path_prefix,
ip_address=self.hostname,
)

Expand Down
4 changes: 4 additions & 0 deletions src/dstack/api/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def __init__(self, base_url: str, token: str):
if client_api_version is not None:
self._s.headers.update({"X-API-VERSION": client_api_version})

@property
def base_url(self) -> str:
    """Return the base URL stored on this client (`self._base_url`)."""
    stored_base_url = self._base_url
    return stored_base_url

@property
def users(self) -> UsersAPIClient:
return UsersAPIClient(self._request)
Expand Down
35 changes: 35 additions & 0 deletions src/tests/_internal/core/models/test_runs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest

from dstack._internal.core.models.runs import (
JobStatus,
JobTerminationReason,
RunStatus,
RunTerminationReason,
ServiceSpec,
)


Expand All @@ -22,3 +25,35 @@ def test_job_termination_reason_to_status_works_with_all_enum_varians():
for job_termination_reason in JobTerminationReason:
job_status = job_termination_reason.to_status()
assert isinstance(job_status, JobStatus)


def test_service_spec_full_url_from_full_url():
    """A service URL with scheme and host is returned as is, whatever the server URL."""
    absolute_url = "https://service.gateway.dstack.example"
    spec = ServiceSpec(url=absolute_url)
    resolved = spec.full_url(server_base_url="http://localhost:3000")
    assert resolved == absolute_url


@pytest.mark.parametrize(
    ("server_url", "service_path", "service_url"),
    [
        # Server URL without a trailing slash
        (
            "http://localhost:3000",
            "/services/main/service/",
            "http://localhost:3000/services/main/service/",
        ),
        # Server URL with a trailing slash: no double slash in the result
        (
            "http://localhost:3000/",
            "/services/main/service/",
            "http://localhost:3000/services/main/service/",
        ),
        # Server URL with a path prefix: the prefix is kept
        (
            "http://localhost:3000/prefix",
            "/services/main/service/",
            "http://localhost:3000/prefix/services/main/service/",
        ),
    ],
)
def test_service_spec_full_url_from_path(server_url, service_path, service_url):
    """A path-only service URL is resolved against the server base URL."""
    spec = ServiceSpec(url=service_path)
    resolved = spec.full_url(server_base_url=server_url)
    assert resolved == service_url
Loading