diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 4070524a0..7304d92e1 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -15,9 +15,11 @@ from dstack._internal.core.models.repos.virtual import VirtualRepo from dstack._internal.core.models.resources import Range, ResourcesSpec from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeMountPoint +from dstack._internal.settings import FeatureFlags CommandsList = List[str] ValidPort = conint(gt=0, le=65536) +SERVICE_HTTPS_DEFAULT = True class RunConfigurationType(str, Enum): @@ -203,7 +205,14 @@ class ServiceConfigurationParams(CoreModel): Optional[AnyModel], Field(description="Mapping of the model for the OpenAI-compatible endpoint"), ] = None - https: Annotated[bool, Field(description="Enable HTTPS")] = True + https: Annotated[ + bool, + Field( + description="Enable HTTPS" + if not FeatureFlags.PROXY + else "Enable HTTPS if running with a gateway" + ), + ] = SERVICE_HTTPS_DEFAULT auth: Annotated[bool, Field(description="Enable the authorization")] = True replicas: Annotated[ Union[conint(ge=1), constr(regex=r"^[0-9]+..[1-9][0-9]*$"), Range[int]], diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 115e4b0ba..af429528a 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta from enum import Enum from typing import Any, Dict, List, Optional, Type +from urllib.parse import urlparse from pydantic import UUID4, Field, root_validator from typing_extensions import Annotated @@ -29,7 +30,7 @@ from dstack._internal.core.models.repos import AnyRunRepoData from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.utils import common as common_utils -from dstack._internal.utils.common import format_pretty_duration +from dstack._internal.utils.common import concat_url_path, format_pretty_duration class AppSpec(CoreModel): @@ -304,10 +305,22 @@ class ServiceModelSpec(CoreModel): class ServiceSpec(CoreModel): - url: str + url: Annotated[str, Field(description="Full URL or path relative to dstack-server's base URL")] model: Optional[ServiceModelSpec] = None options: Dict[str, Any] = {} + def full_url(self, server_base_url: str) -> str: + service_url = urlparse(self.url) + if service_url.scheme and service_url.netloc: + return self.url + server_url = urlparse(server_base_url) + service_url = service_url._replace( + scheme=server_url.scheme or "http", + netloc=server_url.netloc, + path=concat_url_path(server_url.path, service_url.path), + ) + return service_url.geturl() + class RunStatus(str, Enum): PENDING = "pending" diff --git a/src/dstack/_internal/core/services/logs.py b/src/dstack/_internal/core/services/logs.py index 16f62cc25..cafcc24af 100644 --- a/src/dstack/_internal/core/services/logs.py +++ b/src/dstack/_internal/core/services/logs.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional from dstack._internal.core.models.runs import AppSpec +from dstack._internal.utils.common import concat_url_path class URLReplacer: @@ -12,12 +13,14 @@ def __init__( ports: Dict[int, int], hostname: str, secure: bool, + path_prefix: str = "", ip_address: Optional[str] = None, ): self.app_specs = {app_spec.port: app_spec for app_spec in app_specs} self.ports = ports self.hostname = hostname self.secure = secure + self.path_prefix = path_prefix.encode() hosts = ["localhost", "0.0.0.0", "127.0.0.1"] if ip_address and ip_address not in hosts: @@ -43,6 +46,7 @@ def _replace_url(self, match: re.Match) -> bytes: url = url._replace( scheme=("https" if self.secure else "http").encode(), netloc=(self.hostname if omit_port else f"{self.hostname}:{local_port}").encode(), + path=concat_url_path(self.path_prefix, url.path), query=urllib.parse.urlencode(qs).encode(), ) return url.geturl() diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index fa4536e81..705d1e5b9 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -191,7 +191,9 @@ class ProjectModel(BaseModel): default_gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("gateways.id", use_alter=True, ondelete="SET NULL"), nullable=True ) - default_gateway: Mapped["GatewayModel"] = relationship(foreign_keys=[default_gateway_id]) + default_gateway: Mapped[Optional["GatewayModel"]] = relationship( + foreign_keys=[default_gateway_id] + ) default_pool_id: Mapped[Optional[UUIDType]] = mapped_column( ForeignKey("pools.id", use_alter=True, ondelete="SET NULL"), nullable=True diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 37ce46b0c..729cd834c 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -28,6 +28,7 @@ SSHError, ) from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT from dstack._internal.core.models.gateways import ( Gateway, GatewayComputeConfiguration, @@ -68,6 +69,7 @@ gather_map_async, run_async, ) +from dstack._internal.settings import FeatureFlags from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger @@ -352,20 +354,23 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> async def register_service(session: AsyncSession, run_model: RunModel): - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) - - # TODO(egor-s): allow to configure gateway name - gateway_name: Optional[str] = None - if gateway_name is None: - gateway = run_model.project.default_gateway - if gateway is None: - raise ResourceNotExistsError("Default gateway is not set") + gateway = run_model.project.default_gateway + + if gateway is not None: + service_spec = await _register_service_in_gateway(session, run_model, gateway) + run_model.gateway = gateway + elif FeatureFlags.PROXY: + service_spec = _register_service_in_server(run_model) else: - gateway = await get_project_gateway_model_by_name( - session=session, project=run_model.project, name=gateway_name - ) - if gateway is None: - raise ResourceNotExistsError("Gateway does not exist") + raise ResourceNotExistsError("Default gateway is not set") + run_model.service_spec = service_spec.json() + + +async def _register_service_in_gateway( + session: AsyncSession, run_model: RunModel, gateway: GatewayModel +) -> ServiceSpec: + run_spec: RunSpec = RunSpec.__response__.parse_raw(run_model.run_spec) + if gateway.gateway_compute is None: raise ServerClientError("Gateway has no instance associated with it") @@ -396,9 +401,6 @@ async def register_service(session: AsyncSession, run_model: RunModel): ) service_spec.options = get_service_options(run_spec.configuration) - run_model.gateway = gateway - run_model.service_spec = service_spec.json() - conn = await get_or_add_gateway_connection(session, gateway.id) try: logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) @@ -420,10 +422,36 @@ async def register_service(session: AsyncSession, run_model: RunModel): logger.debug("Gateway request failed", exc_info=True) raise GatewayError(f"Gateway is not working: {e!r}") + return service_spec + + +def _register_service_in_server(run_model: RunModel) -> ServiceSpec: + run_spec: RunSpec = RunSpec.__response__.parse_raw(run_model.run_spec) + if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT: + # Note: if the user sets `https: `, it will be ignored silently + # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted + raise ServerClientError( + "The `https` configuration property is not applicable when running services without a gateway. " + "Please configure a gateway or remove the `https` property from the service configuration" + ) + if run_spec.configuration.model is not None: + raise ServerClientError( + "Model mappings are not yet supported when running services without a gateway. " + "Please configure a gateway or remove the `model` property from the service configuration" + ) + if run_spec.configuration.replicas.min != run_spec.configuration.replicas.max: + raise ServerClientError( + "Auto-scaling is not yet supported when running services without a gateway. " + "Please configure a gateway or set `replicas` to a fixed value in the service configuration" + ) + return ServiceSpec(url=f"/services/{run_model.project.name}/{run_model.run_name}/") + async def register_replica( - session: AsyncSession, gateway_id: uuid.UUID, run: Run, job_model: JobModel + session: AsyncSession, gateway_id: Optional[uuid.UUID], run: Run, job_model: JobModel ): + if gateway_id is None: # in-server proxy + return conn = await get_or_add_gateway_connection(session, gateway_id) job_submission = jobs_services.job_model_to_job_submission(job_model) try: @@ -440,10 +468,7 @@ async def register_replica( async def unregister_service(session: AsyncSession, run_model: RunModel): - if run_model.gateway_id is None: - logger.error( - "Failed to unregister service. run_model.gateway_id is None for %s", run_model.run_name - ) + if run_model.gateway_id is None: # in-server proxy return conn = await get_or_add_gateway_connection(session, run_model.gateway_id) res = await session.execute( @@ -474,7 +499,7 @@ async def unregister_replica(session: AsyncSession, job_model: JobModel): ) run_model = res.unique().scalar_one() if run_model.gateway_id is None: - # The run is not a service + # not a service or served by in-server proxy return conn = await get_or_add_gateway_connection(session, run_model.gateway_id) diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index aa863c8f5..e5d792855 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -9,7 +9,7 @@ from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import JobProvisioningData, JobStatus, RunSpec from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Replica, Service -from dstack._internal.server.models import JobModel, ProjectModel +from dstack._internal.server.models import JobModel, ProjectModel, RunModel class DBProxyRepo(BaseProxyRepo): @@ -27,8 +27,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic res = await self.session.execute( select(JobModel) .join(JobModel.project) + .join(JobModel.run) .where( ProjectModel.name == project_name, + RunModel.gateway_id.is_(None), JobModel.run_name == run_name, JobModel.status == JobStatus.RUNNING, JobModel.job_num == 0, diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 0a1e73d22..70d2d60bb 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -714,7 +714,7 @@ async def process_terminating_run(session: AsyncSession, run: RunModel): job.last_processed_at = common_utils.get_current_datetime() if unfinished_jobs_count == 0: - if run.gateway_id is not None: + if run.service_spec is not None: try: await gateways.unregister_service(session, run) except Exception as e: diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index f2e56a3b0..53cd8fbf5 100644 --- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -233,3 +233,23 @@ def get_or_error(v: Optional[T]) -> T: def batched(seq: Iterable[T], n: int) -> Iterable[List[T]]: it = iter(seq) return iter(lambda: list(itertools.islice(it, n)), []) + + +StrT = TypeVar("StrT", str, bytes) + + +def lstrip_one(string: StrT, substring: StrT) -> StrT: + """Remove at most one occurrence of `substring` at the start of `string`""" + return string[len(substring) :] if substring and string.startswith(substring) else string + + +def rstrip_one(string: StrT, substring: StrT) -> StrT: + """Remove at most one occurrence of `substring` at the end of `string`""" + return string[: -len(substring)] if substring and string.endswith(substring) else string + + +def concat_url_path(a: StrT, b: StrT) -> StrT: + if not b: + return a + sep = "/" if isinstance(a, str) else b"/" + return rstrip_one(a, sep) + sep + lstrip_one(b, sep) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 84ec63c0e..a27b1948d 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -100,7 +100,7 @@ def hostname(self) -> str: def service_url(self) -> str: if self._run.run_spec.configuration.type != "service": raise ValueError("The run is not a service") - return self._run.service.url + return self._run.service.full_url(server_base_url=self._api_client.base_url) def _attached_logs( self, @@ -124,25 +124,28 @@ def ws_thread(): ) threading.Thread(target=ws_thread).start() - ports = self.ports hostname = "127.0.0.1" secure = False + ports = self.ports + path_prefix = "" if self._run.service is not None: - url = urlparse(self._run.service.url) + url = urlparse(self.service_url) + hostname = url.hostname + secure = url.scheme == "https" service_port = url.port if service_port is None: - service_port = 443 if self._run.run_spec.configuration.https else 80 + service_port = 443 if secure else 80 ports = { **ports, self._run.run_spec.configuration.port.container_port: service_port, } - hostname = url.hostname - secure = url.scheme == "https" + path_prefix = url.path replace_urls = URLReplacer( ports=ports, app_specs=self._run.jobs[0].job_spec.app_specs, hostname=hostname, secure=secure, + path_prefix=path_prefix, ip_address=self.hostname, ) diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 49262aaae..e142f7f9b 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -56,6 +56,10 @@ def __init__(self, base_url: str, token: str): if client_api_version is not None: self._s.headers.update({"X-API-VERSION": client_api_version}) + @property + def base_url(self) -> str: + return self._base_url + @property def users(self) -> UsersAPIClient: return UsersAPIClient(self._request) diff --git a/src/tests/_internal/core/models/test_runs.py b/src/tests/_internal/core/models/test_runs.py index 576eba61a..51b206bcf 100644 --- a/src/tests/_internal/core/models/test_runs.py +++ b/src/tests/_internal/core/models/test_runs.py @@ -1,8 +1,11 @@ +import pytest + from dstack._internal.core.models.runs import ( JobStatus, JobTerminationReason, RunStatus, RunTerminationReason, + ServiceSpec, ) @@ -22,3 +25,35 @@ def test_job_termination_reason_to_status_works_with_all_enum_varians(): for job_termination_reason in JobTerminationReason: job_status = job_termination_reason.to_status() assert isinstance(job_status, JobStatus) + + +def test_service_spec_full_url_from_full_url(): + spec = ServiceSpec(url="https://service.gateway.dstack.example") + assert ( + spec.full_url(server_base_url="http://localhost:3000") + == "https://service.gateway.dstack.example" + ) + + +@pytest.mark.parametrize( + ("server_url", "service_path", "service_url"), + [ + ( + "http://localhost:3000", + "/services/main/service/", + "http://localhost:3000/services/main/service/", + ), + ( + "http://localhost:3000/", + "/services/main/service/", + "http://localhost:3000/services/main/service/", + ), + ( + "http://localhost:3000/prefix", + "/services/main/service/", + "http://localhost:3000/prefix/services/main/service/", + ), + ], +) +def test_service_spec_full_url_from_path(server_url, service_path, service_url): + assert ServiceSpec(url=service_path).full_url(server_base_url=server_url) == service_url diff --git a/src/tests/_internal/core/services/test_logs.py b/src/tests/_internal/core/services/test_logs.py index c1eca72d3..8f0bef0c1 100644 --- a/src/tests/_internal/core/services/test_logs.py +++ b/src/tests/_internal/core/services/test_logs.py @@ -125,3 +125,17 @@ def test_omit_https_default_port(self): ports={8000: 443}, app_specs=[], hostname="secure.host.com", secure=True ) assert replacer(b"http://0.0.0.0:8000/qwerty") == b"https://secure.host.com/qwerty" + + def test_in_server_proxy(self): + replacer = URLReplacer( + ports={8888: 3000}, + app_specs=[], + hostname="0.0.0.0", + secure=False, + path_prefix="/services/main/service/", + ) + assert replacer(b"http://0.0.0.0:8888") == b"http://0.0.0.0:3000/services/main/service/" + assert ( + replacer(b"http://0.0.0.0:8888/qwerty") + == b"http://0.0.0.0:3000/services/main/service/qwerty" + ) diff --git a/src/tests/_internal/utils/test_common.py b/src/tests/_internal/utils/test_common.py index 6e77d2f0a..c38979bc5 100644 --- a/src/tests/_internal/utils/test_common.py +++ b/src/tests/_internal/utils/test_common.py @@ -4,7 +4,14 @@ import pytest from freezegun import freeze_time -from dstack._internal.utils.common import parse_memory, pretty_date, split_chunks +from dstack._internal.utils.common import ( + concat_url_path, + lstrip_one, + parse_memory, + pretty_date, + rstrip_one, + split_chunks, +) @freeze_time(datetime(2023, 10, 4, 12, 0, tzinfo=timezone.utc)) @@ -109,3 +116,51 @@ def test_split_chunks( def test_raises_on_invalid_chunk_size(self, chunk_size: int) -> None: with pytest.raises(ValueError): list(split_chunks([1, 2, 3], chunk_size)) + + +@pytest.mark.parametrize( + ("string", "substring", "result"), + [ + ("ababc", "ab", "abc"), + ("ababc", "bc", "ababc"), + ("ababc", "", "ababc"), + ("", "a", ""), + ("", "", ""), + ], +) +def test_lstrip_one(string: str, substring: str, result: str) -> None: + assert lstrip_one(string, substring) == result + assert lstrip_one(string.encode(), substring.encode()) == result.encode() + + +@pytest.mark.parametrize( + ("string", "substring", "result"), + [ + ("abcbc", "bc", "abc"), + ("abcbc", "ab", "abcbc"), + ("abcbc", "", "abcbc"), + ("", "a", ""), + ("", "", ""), + ], +) +def test_rstrip_one(string: str, substring: str, result: str) -> None: + assert rstrip_one(string, substring) == result + assert rstrip_one(string.encode(), substring.encode()) == result.encode() + + +@pytest.mark.parametrize( + ("a", "b", "result"), + [ + ("/a/b", "c/d", "/a/b/c/d"), + ("/a/b/", "/c/d", "/a/b/c/d"), + ("/a/b//", "//c/d", "/a/b///c/d"), + ("/a", "", "/a"), + ("/a", "/", "/a/"), + ("", "a", "/a"), + ("/", "a", "/a"), + ("", "", ""), + ], +) +def test_concat_url_path(a: str, b: str, result: str) -> None: + assert concat_url_path(a, b) == result + assert concat_url_path(a.encode(), b.encode()) == result.encode()