Per-project SSH key pairs for gateway access

Summary (derived from the hunks below):
  * hub/db/models.py: Project gains Text columns ssh_private_key and
    ssh_public_key.
  * hub/migration/versions/32e5940896ad_add_ssh_keys_to_project.py: adds both
    columns as nullable, generates an RSA key pair (comment "<name>@dstack")
    for every existing project row, then alters both columns to NOT NULL.
    downgrade() drops both columns.
  * hub/repository/projects.py: ProjectManager.create generates a key pair via
    run_async and stores it on the new Project.
  * hub/routers/gateways.py: gateways_create passes project.ssh_public_key to
    backend.create_gateway instead of get_hub_ssh_public_key().
  * hub/routers/runners.py: run passes project.ssh_private_key to
    backend.run_job (positionally, before offer).
  * backend/{aws,base,lambdalabs}/__init__.py and backend/base/jobs.py:
    run_job gains a required project_private_key parameter placed before
    offer and forwards it down to gateway.setup_service_job.
  * backend/base/gateway.py: publish() takes project_private_key, writes it to
    a NamedTemporaryFile and passes that file's name as id_rsa to
    exec_ssh_command. The HUB_PRIVATE_KEY_PATH import and the id_rsa
    parameter are removed.
  * hub/main.py: fixes the "enviroment" typo in the startup message.
  * tests/hub/common.py: create_project passes empty strings for both keys.

Review notes:
  * NOTE(review): backend/base/gateway.py adds `import time`, but no hunk in
    this patch references `time` - confirm it is needed or drop it.
  * NOTE(review): routers/gateways.py no longer calls
    get_hub_ssh_public_key() on the changed line; its import is not shown
    here - check whether it is now unused.
  * NOTE(review): tests create projects with empty-string keys and the
    private key is stored as plain Text - confirm both are intended.
  * NOTE(review): this patch was restored from a whitespace-collapsed form.
    Hunk line counts were checked against every @@ header, but indentation
    and runs of spaces (including inside the string literals in hub/main.py)
    are inferred from Python syntax. Verify against upstream, or apply with
    `git apply --ignore-whitespace`.

diff --git a/cli/dstack/_internal/backend/aws/__init__.py b/cli/dstack/_internal/backend/aws/__init__.py
index 58656a234..2586f156d 100644
--- a/cli/dstack/_internal/backend/aws/__init__.py
+++ b/cli/dstack/_internal/backend/aws/__init__.py
@@ -83,6 +83,7 @@ def run_job(
         self,
         job: Job,
         failed_to_start_job_new_status: JobStatus,
+        project_private_key: str,
         offer: Optional[InstanceOffer] = None,
     ):
         self._logging.create_log_groups_if_not_exist(
@@ -90,7 +91,12 @@ def run_job(
             self.backend_config.bucket_name,
             job.repo_ref.repo_id,
         )
-        super().run_job(job, failed_to_start_job_new_status, offer=offer)
+        super().run_job(
+            job,
+            failed_to_start_job_new_status,
+            project_private_key=project_private_key,
+            offer=offer,
+        )
 
     def create_run(self, repo_id: str, run_name: Optional[str]) -> str:
         self._logging.create_log_groups_if_not_exist(
diff --git a/cli/dstack/_internal/backend/base/__init__.py b/cli/dstack/_internal/backend/base/__init__.py
index 9f2da2e20..c9da9ad58 100644
--- a/cli/dstack/_internal/backend/base/__init__.py
+++ b/cli/dstack/_internal/backend/base/__init__.py
@@ -78,6 +78,7 @@ def run_job(
         self,
         job: Job,
         failed_to_start_job_new_status: JobStatus,
+        project_private_key: str,
         offer: Optional[InstanceOffer] = None,
     ):
         pass
@@ -311,6 +312,7 @@ def run_job(
         self,
         job: Job,
         failed_to_start_job_new_status: JobStatus,
+        project_private_key: str,
         offer: Optional[InstanceOffer] = None,
     ):
         self.predict_build_plan(job)  # raises exception on missing build
@@ -320,6 +322,7 @@ def run_job(
             self.secrets_manager(),
             job,
             failed_to_start_job_new_status,
+            project_private_key=project_private_key,
             offer=offer,
         )
 
diff --git a/cli/dstack/_internal/backend/base/gateway.py b/cli/dstack/_internal/backend/base/gateway.py
index ad0e61aa9..d3bace9ba 100644
--- a/cli/dstack/_internal/backend/base/gateway.py
+++ b/cli/dstack/_internal/backend/base/gateway.py
@@ -1,5 +1,7 @@
 import re
 import subprocess
+import time
+from tempfile import NamedTemporaryFile
 from typing import List, Optional
 
 import pkg_resources
@@ -15,7 +17,6 @@
 from dstack._internal.core.error import SSHCommandError
 from dstack._internal.core.gateway import GatewayHead
 from dstack._internal.core.job import Job
-from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH
 from dstack._internal.utils.common import PathLike, removeprefix
 from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
 from dstack._internal.utils.interpolator import VariablesInterpolator
@@ -59,17 +60,18 @@ def publish(
     port: int,
     ssh_key: bytes,
     secure: bool,
+    project_private_key: str,
     user: str = "ubuntu",
-    id_rsa: Optional[PathLike] = HUB_PRIVATE_KEY_PATH,
 ) -> str:
     command = ["sudo", "python3", "-", hostname, str(port), f'"{ssh_key.decode().strip()}"']
     if secure:
         command.append("--secure")
-    with open(
-        pkg_resources.resource_filename("dstack._internal", "scripts/gateway_publish.py"), "r"
-    ) as f:
+    script_path = pkg_resources.resource_filename("dstack._internal", "scripts/gateway_publish.py")
+    with open(script_path, "r") as script, NamedTemporaryFile("w") as id_rsa:
+        id_rsa.write(project_private_key)
+        id_rsa.flush()
         output = exec_ssh_command(
-            hostname, command=" ".join(command), user=user, id_rsa=id_rsa, stdin=f
+            hostname, command=" ".join(command), user=user, id_rsa=id_rsa.name, stdin=script
         )
     return output.decode().strip()
 
@@ -117,7 +119,7 @@ def is_ip_address(hostname: str) -> bool:
     return re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", hostname) is not None
 
 
-def setup_service_job(job: Job, secrets_manager: SecretsManager) -> Job:
+def setup_service_job(job: Job, secrets_manager: SecretsManager, project_private_key: str) -> Job:
     job.gateway.hostname = resolve_hostname(
         secrets_manager, job.repo_ref.repo_id, job.gateway.hostname
     )
@@ -126,7 +128,11 @@ def setup_service_job(job: Job, secrets_manager: SecretsManager) -> Job:
         job.gateway.public_port = 443
     private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=job.run_name)
     job.gateway.sock_path = publish(
-        job.gateway.hostname, job.gateway.public_port, public_bytes, secure=job.gateway.secure
+        job.gateway.hostname,
+        job.gateway.public_port,
+        public_bytes,
+        project_private_key=project_private_key,
+        secure=job.gateway.secure,
     )
     job.gateway.ssh_key = private_bytes.decode()
     return job
diff --git a/cli/dstack/_internal/backend/base/jobs.py b/cli/dstack/_internal/backend/base/jobs.py
index e9c2e2e02..4c3744eae 100644
--- a/cli/dstack/_internal/backend/base/jobs.py
+++ b/cli/dstack/_internal/backend/base/jobs.py
@@ -122,11 +122,14 @@ def run_job(
     secrets_manager: SecretsManager,
     job: Job,
     failed_to_start_job_new_status: JobStatus,
+    project_private_key: str,
     offer: Optional[InstanceOffer] = None,
 ):
     try:
         if job.configuration_type == ConfigurationType.SERVICE:
-            job = gateway.setup_service_job(job, secrets_manager)
+            job = gateway.setup_service_job(
+                job, secrets_manager, project_private_key=project_private_key
+            )
             update_job(storage, job)
 
         _try_run_job(
diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py b/cli/dstack/_internal/backend/lambdalabs/__init__.py
index e4bb6c0f8..15f7ca311 100644
--- a/cli/dstack/_internal/backend/lambdalabs/__init__.py
+++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py
@@ -72,6 +72,7 @@ def run_job(
         self,
         job: Job,
         failed_to_start_job_new_status: JobStatus,
+        project_private_key: str,
         offer: Optional[InstanceOffer] = None,
     ):
         self._logging.create_log_groups_if_not_exist(
@@ -79,7 +80,12 @@ def run_job(
             self.backend_config.storage_config.bucket,
             job.repo_ref.repo_id,
         )
-        super().run_job(job, failed_to_start_job_new_status, offer=offer)
+        super().run_job(
+            job,
+            failed_to_start_job_new_status,
+            project_private_key=project_private_key,
+            offer=offer,
+        )
 
     def create_run(self, repo_id: str, run_name: Optional[str]) -> str:
         self._logging.create_log_groups_if_not_exist(
diff --git a/cli/dstack/_internal/hub/db/models.py b/cli/dstack/_internal/hub/db/models.py
index a032c9b35..a1de1d16f 100644
--- a/cli/dstack/_internal/hub/db/models.py
+++ b/cli/dstack/_internal/hub/db/models.py
@@ -1,6 +1,6 @@
 from typing import List
 
-from sqlalchemy import ForeignKey, Integer, MetaData, String
+from sqlalchemy import ForeignKey, Integer, MetaData, String, Text
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
 
 constraint_naming_convention = {
@@ -60,5 +60,7 @@ class Project(Base):
     __tablename__ = "projects"
 
     name: Mapped[str] = mapped_column(String(50), primary_key=True)
+    ssh_private_key: Mapped[str] = mapped_column(Text)
+    ssh_public_key: Mapped[str] = mapped_column(Text)
     members: Mapped[List[Member]] = relationship(back_populates="project", lazy="selectin")
     backends: Mapped[List[Backend]] = relationship(back_populates="project", lazy="selectin")
diff --git a/cli/dstack/_internal/hub/main.py b/cli/dstack/_internal/hub/main.py
index e3796bff4..e2a9f1458 100644
--- a/cli/dstack/_internal/hub/main.py
+++ b/cli/dstack/_internal/hub/main.py
@@ -57,7 +57,7 @@ async def lifespan(app: FastAPI):
         "\nTo start using dstack:\n"
         f"\n 1. Configure one or more clouds at {add_backend_url}."
         "\n 2. Initialize a repo with `dstack init`."
-        "\n 3. Define and run a dev enviroment, a task, or a service. For details, see https://dstack.ai/docs/.\n"
+        "\n 3. Define and run a dev environment, a task, or a service. For details, see https://dstack.ai/docs/.\n"
     )
     yield
     scheduler.shutdown()
diff --git a/cli/dstack/_internal/hub/migration/versions/32e5940896ad_add_ssh_keys_to_project.py b/cli/dstack/_internal/hub/migration/versions/32e5940896ad_add_ssh_keys_to_project.py
new file mode 100644
index 000000000..cd62dafc0
--- /dev/null
+++ b/cli/dstack/_internal/hub/migration/versions/32e5940896ad_add_ssh_keys_to_project.py
@@ -0,0 +1,59 @@
+"""Add ssh keys to project
+
+Revision ID: 32e5940896ad
+Revises: e6df5271c730
+Create Date: 2023-08-17 16:05:15.118951
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
+
+# revision identifiers, used by Alembic.
+revision = "32e5940896ad"
+down_revision = "e6df5271c730"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("ssh_private_key", sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column("ssh_public_key", sa.Text(), nullable=True))
+
+    t_project = sa.Table(
+        "projects",
+        sa.MetaData(),
+        sa.Column("name", sa.String(50)),
+        sa.Column("ssh_private_key", sa.Text()),
+        sa.Column("ssh_public_key", sa.Text()),
+    )
+    conn = op.get_bind()
+    projects = conn.execute(sa.select(t_project.c.name)).fetchall()
+    for (project,) in projects:
+        private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=f"{project}@dstack")
+        conn.execute(
+            t_project.update()
+            .where(t_project.c.name == project)
+            .values(
+                ssh_private_key=private_bytes.decode(),
+                ssh_public_key=public_bytes.decode(),
+            )
+        )
+
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.alter_column("ssh_private_key", nullable=False)
+        batch_op.alter_column("ssh_public_key", nullable=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.drop_column("ssh_public_key")
+        batch_op.drop_column("ssh_private_key")
+
+    # ### end Alembic commands ###
diff --git a/cli/dstack/_internal/hub/repository/projects.py b/cli/dstack/_internal/hub/repository/projects.py
index 35dfe7dfa..9dbf3ee32 100644
--- a/cli/dstack/_internal/hub/repository/projects.py
+++ b/cli/dstack/_internal/hub/repository/projects.py
@@ -19,6 +19,7 @@
 from dstack._internal.hub.security.utils import ROLE_ADMIN
 from dstack._internal.hub.services.backends import get_configurator
 from dstack._internal.hub.utils.common import run_async
+from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
 
 
 class ProjectManager:
@@ -34,7 +35,14 @@ async def create(
         members: List[Member],
         session: Optional[AsyncSession] = None,
     ) -> Project:
-        project = Project(name=project_name)
+        private_bytes, public_bytes = await run_async(
+            generate_rsa_key_pair_bytes, f"{project_name}@dstack"
+        )
+        project = Project(
+            name=project_name,
+            ssh_private_key=private_bytes.decode(),
+            ssh_public_key=public_bytes.decode(),
+        )
         await ProjectManager._create(project, session=session)
         await ProjectManager._add_member(
             project=project,
diff --git a/cli/dstack/_internal/hub/routers/gateways.py b/cli/dstack/_internal/hub/routers/gateways.py
index e522adb71..996ad1a30 100644
--- a/cli/dstack/_internal/hub/routers/gateways.py
+++ b/cli/dstack/_internal/hub/routers/gateways.py
@@ -22,7 +22,7 @@ async def gateways_create(project_name: str, backend_name: str = Body()) -> Gate
         if backend.name != backend_name:
             continue
         try:
-            return await call_backend(backend.create_gateway, get_hub_ssh_public_key())
+            return await call_backend(backend.create_gateway, project.ssh_public_key)
         except NotImplementedError:
             pass
     raise HTTPException(
diff --git a/cli/dstack/_internal/hub/routers/runners.py b/cli/dstack/_internal/hub/routers/runners.py
index b499bf3f2..78a4a5e2d 100644
--- a/cli/dstack/_internal/hub/routers/runners.py
+++ b/cli/dstack/_internal/hub/routers/runners.py
@@ -58,7 +58,13 @@ async def run(project_name: str, body: RunRunners):
             offer.price,
         )
         try:
-            await call_backend(backend.run_job, body.job, failed_to_start_job_new_status, offer)
+            await call_backend(
+                backend.run_job,
+                body.job,
+                failed_to_start_job_new_status,
+                project.ssh_private_key,
+                offer,
+            )
             return
         except NoMatchingInstanceError:
             continue
diff --git a/cli/tests/hub/common.py b/cli/tests/hub/common.py
index cc86cd0ab..476fc28b2 100644
--- a/cli/tests/hub/common.py
+++ b/cli/tests/hub/common.py
@@ -23,7 +23,7 @@ async def create_user(
 async def create_project(
     name: str = "test_project",
 ) -> Project:
-    project = Project(name=name)
+    project = Project(name=name, ssh_private_key="", ssh_public_key="")
     await ProjectManager._create(project)
     return project
 