Skip to content

Commit

Permalink
Store project's ssh keys in DB (#658)
Browse files Browse the repository at this point in the history
* Add ssh keys to every project

* Use a project ssh key for gateway

* Fix server tests
  • Loading branch information
Egor-S authored Aug 17, 2023
1 parent 3ac32ff commit c9e1510
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 17 deletions.
8 changes: 7 additions & 1 deletion cli/dstack/_internal/backend/aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,20 @@ def run_job(
self,
job: Job,
failed_to_start_job_new_status: JobStatus,
project_private_key: str,
offer: Optional[InstanceOffer] = None,
):
self._logging.create_log_groups_if_not_exist(
aws_utils.get_logs_client(self._session),
self.backend_config.bucket_name,
job.repo_ref.repo_id,
)
super().run_job(job, failed_to_start_job_new_status, offer=offer)
super().run_job(
job,
failed_to_start_job_new_status,
project_private_key=project_private_key,
offer=offer,
)

def create_run(self, repo_id: str, run_name: Optional[str]) -> str:
self._logging.create_log_groups_if_not_exist(
Expand Down
3 changes: 3 additions & 0 deletions cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def run_job(
self,
job: Job,
failed_to_start_job_new_status: JobStatus,
project_private_key: str,
offer: Optional[InstanceOffer] = None,
):
pass
Expand Down Expand Up @@ -311,6 +312,7 @@ def run_job(
self,
job: Job,
failed_to_start_job_new_status: JobStatus,
project_private_key: str,
offer: Optional[InstanceOffer] = None,
):
self.predict_build_plan(job) # raises exception on missing build
Expand All @@ -320,6 +322,7 @@ def run_job(
self.secrets_manager(),
job,
failed_to_start_job_new_status,
project_private_key=project_private_key,
offer=offer,
)

Expand Down
22 changes: 14 additions & 8 deletions cli/dstack/_internal/backend/base/gateway.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
import subprocess
import time
from tempfile import NamedTemporaryFile
from typing import List, Optional

import pkg_resources
Expand All @@ -15,7 +17,6 @@
from dstack._internal.core.error import SSHCommandError
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.core.job import Job
from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH
from dstack._internal.utils.common import PathLike, removeprefix
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
from dstack._internal.utils.interpolator import VariablesInterpolator
Expand Down Expand Up @@ -59,17 +60,18 @@ def publish(
port: int,
ssh_key: bytes,
secure: bool,
project_private_key: str,
user: str = "ubuntu",
id_rsa: Optional[PathLike] = HUB_PRIVATE_KEY_PATH,
) -> str:
command = ["sudo", "python3", "-", hostname, str(port), f'"{ssh_key.decode().strip()}"']
if secure:
command.append("--secure")
with open(
pkg_resources.resource_filename("dstack._internal", "scripts/gateway_publish.py"), "r"
) as f:
script_path = pkg_resources.resource_filename("dstack._internal", "scripts/gateway_publish.py")
with open(script_path, "r") as script, NamedTemporaryFile("w") as id_rsa:
id_rsa.write(project_private_key)
id_rsa.flush()
output = exec_ssh_command(
hostname, command=" ".join(command), user=user, id_rsa=id_rsa, stdin=f
hostname, command=" ".join(command), user=user, id_rsa=id_rsa.name, stdin=script
)
return output.decode().strip()

Expand Down Expand Up @@ -117,7 +119,7 @@ def is_ip_address(hostname: str) -> bool:
return re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", hostname) is not None


def setup_service_job(job: Job, secrets_manager: SecretsManager) -> Job:
def setup_service_job(job: Job, secrets_manager: SecretsManager, project_private_key: str) -> Job:
job.gateway.hostname = resolve_hostname(
secrets_manager, job.repo_ref.repo_id, job.gateway.hostname
)
Expand All @@ -126,7 +128,11 @@ def setup_service_job(job: Job, secrets_manager: SecretsManager) -> Job:
job.gateway.public_port = 443
private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=job.run_name)
job.gateway.sock_path = publish(
job.gateway.hostname, job.gateway.public_port, public_bytes, secure=job.gateway.secure
job.gateway.hostname,
job.gateway.public_port,
public_bytes,
project_private_key=project_private_key,
secure=job.gateway.secure,
)
job.gateway.ssh_key = private_bytes.decode()
return job
5 changes: 4 additions & 1 deletion cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,14 @@ def run_job(
secrets_manager: SecretsManager,
job: Job,
failed_to_start_job_new_status: JobStatus,
project_private_key: str,
offer: Optional[InstanceOffer] = None,
):
try:
if job.configuration_type == ConfigurationType.SERVICE:
job = gateway.setup_service_job(job, secrets_manager)
job = gateway.setup_service_job(
job, secrets_manager, project_private_key=project_private_key
)
update_job(storage, job)

_try_run_job(
Expand Down
8 changes: 7 additions & 1 deletion cli/dstack/_internal/backend/lambdalabs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,20 @@ def run_job(
self,
job: Job,
failed_to_start_job_new_status: JobStatus,
project_private_key: str,
offer: Optional[InstanceOffer] = None,
):
self._logging.create_log_groups_if_not_exist(
aws_utils.get_logs_client(self._session),
self.backend_config.storage_config.bucket,
job.repo_ref.repo_id,
)
super().run_job(job, failed_to_start_job_new_status, offer=offer)
super().run_job(
job,
failed_to_start_job_new_status,
project_private_key=project_private_key,
offer=offer,
)

def create_run(self, repo_id: str, run_name: Optional[str]) -> str:
self._logging.create_log_groups_if_not_exist(
Expand Down
4 changes: 3 additions & 1 deletion cli/dstack/_internal/hub/db/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from sqlalchemy import ForeignKey, Integer, MetaData, String
from sqlalchemy import ForeignKey, Integer, MetaData, String, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

constraint_naming_convention = {
Expand Down Expand Up @@ -60,5 +60,7 @@ class Project(Base):
__tablename__ = "projects"

name: Mapped[str] = mapped_column(String(50), primary_key=True)
ssh_private_key: Mapped[str] = mapped_column(Text)
ssh_public_key: Mapped[str] = mapped_column(Text)
members: Mapped[List[Member]] = relationship(back_populates="project", lazy="selectin")
backends: Mapped[List[Backend]] = relationship(back_populates="project", lazy="selectin")
2 changes: 1 addition & 1 deletion cli/dstack/_internal/hub/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def lifespan(app: FastAPI):
"\nTo start using dstack:\n"
f"\n 1. Configure one or more clouds at {add_backend_url}."
"\n 2. Initialize a repo with `dstack init`."
"\n 3. Define and run a dev enviroment, a task, or a service. For details, see https://dstack.ai/docs/.\n"
"\n 3. Define and run a dev environment, a task, or a service. For details, see https://dstack.ai/docs/.\n"
)
yield
scheduler.shutdown()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Add ssh keys to project
Revision ID: 32e5940896ad
Revises: e6df5271c730
Create Date: 2023-08-17 16:05:15.118951
"""
import sqlalchemy as sa
from alembic import op

from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes

# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the revision it
# revises (both match the "Revision ID" / "Revises" lines in the module docstring).
revision = "32e5940896ad"
down_revision = "e6df5271c730"
# No branch labels or cross-revision dependencies are declared for this migration.
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add ``ssh_private_key`` / ``ssh_public_key`` to ``projects`` and backfill them.

    The migration runs in three steps:

    1. Add both columns as nullable so that existing rows stay valid.
    2. For every existing project, call ``generate_rsa_key_pair_bytes`` with the
       comment ``"<project name>@dstack"`` and store the decoded key pair.
    3. Tighten both columns to ``NOT NULL`` once every row has a value.
    """
    # Step 1: the columns must start out nullable; existing rows have no keys yet.
    with op.batch_alter_table("projects", schema=None) as batch_op:
        batch_op.add_column(sa.Column("ssh_private_key", sa.Text(), nullable=True))
        batch_op.add_column(sa.Column("ssh_public_key", sa.Text(), nullable=True))

    # Step 2: a minimal, migration-local table description, limited to the
    # columns this data backfill reads or writes.
    t_project = sa.Table(
        "projects",
        sa.MetaData(),
        sa.Column("name", sa.String(50)),
        sa.Column("ssh_private_key", sa.Text()),
        sa.Column("ssh_public_key", sa.Text()),
    )
    conn = op.get_bind()
    projects = conn.execute(sa.select(t_project.c.name)).fetchall()
    for (project,) in projects:
        private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=f"{project}@dstack")
        conn.execute(
            t_project.update()
            .where(t_project.c.name == project)
            .values(
                ssh_private_key=private_bytes.decode(),
                ssh_public_key=public_bytes.decode(),
            )
        )

    # Step 3: every row is populated now, so the columns can become NOT NULL.
    # `existing_type` is passed explicitly because some backends need the full
    # column definition restated when only nullability changes.
    with op.batch_alter_table("projects", schema=None) as batch_op:
        batch_op.alter_column("ssh_private_key", existing_type=sa.Text(), nullable=False)
        batch_op.alter_column("ssh_public_key", existing_type=sa.Text(), nullable=False)


def downgrade() -> None:
    """Remove the SSH key columns from ``projects``.

    The stored key material is discarded together with the columns.
    """
    # Drop order: public key first, then private key.
    columns_to_drop = ("ssh_public_key", "ssh_private_key")
    with op.batch_alter_table("projects", schema=None) as batch_op:
        for column_name in columns_to_drop:
            batch_op.drop_column(column_name)
10 changes: 9 additions & 1 deletion cli/dstack/_internal/hub/repository/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dstack._internal.hub.security.utils import ROLE_ADMIN
from dstack._internal.hub.services.backends import get_configurator
from dstack._internal.hub.utils.common import run_async
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes


class ProjectManager:
Expand All @@ -34,7 +35,14 @@ async def create(
members: List[Member],
session: Optional[AsyncSession] = None,
) -> Project:
project = Project(name=project_name)
private_bytes, public_bytes = await run_async(
generate_rsa_key_pair_bytes, f"{project_name}@dstack"
)
project = Project(
name=project_name,
ssh_private_key=private_bytes.decode(),
ssh_public_key=public_bytes.decode(),
)
await ProjectManager._create(project, session=session)
await ProjectManager._add_member(
project=project,
Expand Down
2 changes: 1 addition & 1 deletion cli/dstack/_internal/hub/routers/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def gateways_create(project_name: str, backend_name: str = Body()) -> Gate
if backend.name != backend_name:
continue
try:
return await call_backend(backend.create_gateway, get_hub_ssh_public_key())
return await call_backend(backend.create_gateway, project.ssh_public_key)
except NotImplementedError:
pass
raise HTTPException(
Expand Down
8 changes: 7 additions & 1 deletion cli/dstack/_internal/hub/routers/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ async def run(project_name: str, body: RunRunners):
offer.price,
)
try:
await call_backend(backend.run_job, body.job, failed_to_start_job_new_status, offer)
await call_backend(
backend.run_job,
body.job,
failed_to_start_job_new_status,
project.ssh_private_key,
offer,
)
return
except NoMatchingInstanceError:
continue
Expand Down
2 changes: 1 addition & 1 deletion cli/tests/hub/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def create_user(
async def create_project(
    name: str = "test_project",
) -> Project:
    """Build a ``Project`` test fixture and hand it to ``ProjectManager._create``.

    The SSH key fields are set to empty strings as placeholders.

    Args:
        name: Project name to use; defaults to ``"test_project"``.

    Returns:
        The ``Project`` instance that was passed to ``ProjectManager._create``.
    """
    # Construct the project once, with both ssh key fields set; the earlier
    # key-less construction was immediately overwritten and has been removed.
    project = Project(name=name, ssh_private_key="", ssh_public_key="")
    await ProjectManager._create(project)
    return project

Expand Down

0 comments on commit c9e1510

Please sign in to comment.