Skip to content

Commit

Permalink
Fix Migration Script for New Deployment
Browse files Browse the repository at this point in the history
  • Loading branch information
noah-paige committed Dec 12, 2023
1 parent 94c93d9 commit 578e480
Showing 1 changed file with 62 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,77 +61,73 @@ def upgrade():
1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information
"""
try:
envname = os.getenv('envname', 'local')
engine = get_engine(envname=envname).engine

bind = op.get_bind()
session = orm.Session(bind=bind)

if has_table('sagemaker_studio_domain', engine):
print("Updating sagemaker_studio_domain table...")
op.alter_column(
'sagemaker_studio_domain',
'sagemakerStudioDomainID',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'SagemakerStudioStatus',
nullable=True,
existing_type=sa.String()
print("Updating sagemaker_studio_domain table...")
op.alter_column(
'sagemaker_studio_domain',
'sagemakerStudioDomainID',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'SagemakerStudioStatus',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'RoleArn',
new_column_name='DefaultDomainRoleName',
nullable=False,
existing_type=sa.String()
)

op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False))
op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True))
op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False))

op.create_foreign_key(
"fk_sagemaker_studio_domain_env_uri",
"sagemaker_studio_domain", "environment",
["environmentUri"], ["environmentUri"],
)

print("Update sagemaker_studio_domain table done.")
print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...")

env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter(
and_(
EnvironmentParameter.key == "mlStudiosEnabled",
EnvironmentParameter.value == "true"
)
op.alter_column(
'sagemaker_studio_domain',
'RoleArn',
new_column_name='DefaultDomainRoleName',
nullable=False,
existing_type=sa.String()
)

op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False))
op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True))
op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False))

op.create_foreign_key(
"fk_sagemaker_studio_domain_env_uri",
"sagemaker_studio_domain", "environment",
["environmentUri"], ["environmentUri"],
)

print("Update sagemaker_studio_domain table done.")
print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...")

env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter(
and_(
EnvironmentParameter.key == "mlStudiosEnabled",
EnvironmentParameter.value == "true"
).all()
for param in env_mlstudio_parameters:
env: Environment = session.query(Environment).filter(
Environment.environmentUri == param.environmentUri
).first()

domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
SagemakerStudioDomain.environmentUri == env.environmentUri
).first()
if not domain:
domain = SagemakerStudioDomain(
label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
owner=env.owner,
description='No description provided',
environmentUri=env.environmentUri,
AWSAccountId=env.AwsAccountId,
region=env.region,
DefaultDomainRoleName="RoleSagemakerStudioUsers",
sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
vpcType="unknown",
SamlGroupName=env.SamlGroupName
)
).all()
for param in env_mlstudio_parameters:
env: Environment = session.query(Environment).filter(
Environment.environmentUri == param.environmentUri
).first()

domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
SagemakerStudioDomain.environmentUri == env.environmentUri
).first()
if not domain:
domain = SagemakerStudioDomain(
label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
owner=env.owner,
description='No description provided',
environmentUri=env.environmentUri,
AWSAccountId=env.AwsAccountId,
region=env.region,
DefaultDomainRoleName="RoleSagemakerStudioUsers",
sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
vpcType="unknown",
SamlGroupName=env.SamlGroupName
)
session.add(domain)
session.add(domain)
session.flush()
session.commit()
print("Fill of sagemaker_studio_domain table is done")
Expand Down

0 comments on commit 578e480

Please sign in to comment.