Skip to content

Commit

Permalink
Add created_at to projects and users (#1857)
Browse files Browse the repository at this point in the history
* WIP

* Add created_at to projects and users
  • Loading branch information
r4victor authored Oct 17, 2024
1 parent a171ba6 commit b3d6a0f
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/models/projects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
from datetime import datetime
from typing import List, Optional

from pydantic import UUID4

Expand All @@ -21,5 +22,6 @@ class Project(CoreModel):
project_id: UUID4
project_name: str
owner: User
created_at: Optional[datetime] = None
backends: List[BackendInfo]
members: List[Member]
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/users.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
from datetime import datetime
from typing import Optional

from pydantic import UUID4
Expand All @@ -24,6 +25,7 @@ class UserPermissions(CoreModel):
class User(CoreModel):
id: UUID4
username: str
created_at: Optional[datetime] = None
global_role: GlobalRole
email: Optional[str]
active: bool
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Add created_at to UserModel and ProjectModel
Revision ID: afbc600ff2b2
Revises: c20626d03cfb
Create Date: 2024-10-16 14:31:49.040804
"""

import uuid
from datetime import timedelta

import sqlalchemy as sa
import sqlalchemy_utils
from alembic import op

import dstack._internal.server.models
from dstack._internal.utils.common import get_current_datetime

# revision identifiers, used by Alembic.
revision = "afbc600ff2b2"
down_revision = "c20626d03cfb"
branch_labels = None
depends_on = None


users_table = sa.Table(
"users",
sa.MetaData(),
# partial description - only columns affected by this migration
sa.Column("id", sqlalchemy_utils.UUIDType(binary=False), primary_key=True, default=uuid.uuid4),
sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=True),
)


projects_table = sa.Table(
"projects",
sa.MetaData(),
# partial description - only columns affected by this migration
sa.Column("id", sqlalchemy_utils.UUIDType(binary=False), primary_key=True, default=uuid.uuid4),
sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=True),
)


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.add_column(
sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=True)
)
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column(
sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=True)
)

# Set created_at on existing rows.
# The absolute value does not matter since it cannot be recovered.
# Just ensure that created_at order matches the insertion order.
# SELECT should fetch the rows in the insertion order when there are no additional conditions.
last_created_at = get_current_datetime()

users_update_params = []
users = op.get_bind().execute(sa.select(users_table))
for i, row in enumerate(reversed(users.all())):
created_at = last_created_at - timedelta(seconds=i)
users_update_params.append({"_id": row.id, "created_at": created_at})
update_stmt = (
users_table.update()
.where(users_table.c.id == sa.bindparam("_id"))
.values(created_at=sa.bindparam("created_at"))
)
if users_update_params:
op.get_bind().execute(update_stmt, users_update_params)

projects_update_params = []
projects = op.get_bind().execute(sa.select(projects_table))
for i, row in enumerate(reversed(projects.all())):
created_at = last_created_at - timedelta(seconds=i)
projects_update_params.append({"_id": row.id, "created_at": created_at})
update_stmt = (
projects_table.update()
.where(projects_table.c.id == sa.bindparam("_id"))
.values(created_at=sa.bindparam("created_at"))
)
if projects_update_params:
op.get_bind().execute(update_stmt, projects_update_params)

with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.alter_column("created_at", nullable=False)
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.alter_column("created_at", nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_column("created_at")

with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.drop_column("created_at")

# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class UserModel(BaseModel):
UUIDType(binary=False), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(String(50), unique=True)
created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
token: Mapped[DecryptedString] = mapped_column(EncryptedString(200), unique=True)
# token_hash is needed for fast search by token when stored token is encrypted
token_hash: Mapped[str] = mapped_column(String(2000), unique=True)
Expand All @@ -173,6 +174,7 @@ class ProjectModel(BaseModel):
UUIDType(binary=False), primary_key=True, default=uuid.uuid4
)
name: Mapped[str] = mapped_column(String(50), unique=True)
created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)

owner_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
Expand Down
3 changes: 3 additions & 0 deletions src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from datetime import timezone
from typing import Awaitable, Callable, List, Optional, Tuple

from sqlalchemy import delete, select, update
Expand Down Expand Up @@ -55,6 +56,7 @@ async def list_user_projects(
projects = await list_project_models(session=session)
else:
projects = await list_user_project_models(session=session, user=user)
projects = sorted(projects, key=lambda p: p.created_at)
return [
project_model_to_project(p, include_backends=False, include_members=False)
for p in projects
Expand Down Expand Up @@ -393,6 +395,7 @@ def project_model_to_project(
project_id=project_model.id,
project_name=project_model.name,
owner=users.user_model_to_user(project_model.owner),
created_at=project_model.created_at.replace(tzinfo=timezone.utc),
backends=backends,
members=members,
)
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import uuid
from datetime import timezone
from typing import Awaitable, Callable, List, Optional, Tuple

from sqlalchemy import delete, select, update
Expand Down Expand Up @@ -48,6 +49,7 @@ async def list_all_users(
) -> List[User]:
res = await session.execute(select(UserModel))
user_models = res.scalars().all()
user_models = sorted(user_models, key=lambda u: u.created_at)
return [user_model_to_user(u) for u in user_models]


Expand Down Expand Up @@ -184,6 +186,7 @@ def user_model_to_user(user_model: UserModel) -> User:
return User(
id=user_model.id,
username=user_model.name,
created_at=user_model.created_at.replace(tzinfo=timezone.utc),
global_role=user_model.global_role,
email=user_model.email,
active=user_model.active,
Expand All @@ -195,6 +198,7 @@ def user_model_to_user_with_creds(user_model: UserModel) -> UserWithCreds:
return UserWithCreds(
id=user_model.id,
username=user_model.name,
created_at=user_model.created_at.replace(tzinfo=timezone.utc),
global_role=user_model.global_role,
email=user_model.email,
active=user_model.active,
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def get_auth_headers(token: Union[DecryptedString, str]) -> Dict:
async def create_user(
session: AsyncSession,
name: str = "test_user",
created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
global_role: GlobalRole = GlobalRole.ADMIN,
token: Optional[str] = None,
email: Optional[str] = None,
Expand All @@ -91,6 +92,7 @@ async def create_user(
token = str(uuid.uuid4())
user = UserModel(
name=name,
created_at=created_at,
global_role=global_role,
token=DecryptedString(plaintext=token),
token_hash=get_token_hash(token),
Expand All @@ -106,6 +108,7 @@ async def create_project(
session: AsyncSession,
owner: Optional[UserModel] = None,
name: str = "test_project",
created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
ssh_private_key: str = "",
ssh_public_key: str = "",
) -> ProjectModel:
Expand All @@ -114,6 +117,7 @@ async def create_project(
project = ProjectModel(
name=name,
owner_id=owner.id,
created_at=created_at,
ssh_private_key=ssh_private_key,
ssh_public_key=ssh_public_key,
)
Expand Down
Loading

0 comments on commit b3d6a0f

Please sign in to comment.