Skip to content

Commit

Permalink
Final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
noah-paige committed Dec 7, 2023
1 parent cd4fe7f commit 9417ab9
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 107 deletions.
116 changes: 67 additions & 49 deletions backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
)
from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository

from dataall.base.aws.sts import SessionHelper
from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension
from dataall.core.environment.services.environment_service import EnvironmentService
from dataall.base.aws.ec2_client import EC2

logger = logging.getLogger(__name__)

Expand All @@ -36,61 +38,77 @@ def extent(setup: EnvironmentSetup):

if domain.vpcId and domain.subnetIds:
logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain')
vpc_id = domain.vpcId
vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId)
subnet_ids = domain.subnetIds
security_groups = []
else:
logger.info("VPC not specified or not found, Exception. Creating a VPC for SageMaker resources...")
# Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways
log_group = logs.LogGroup(
setup,
f'SageMakerStudio{_environment.name}',
log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio',
retention=logs.RetentionDays.ONE_MONTH,
removal_policy=RemovalPolicy.DESTROY,
cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
accountid=_environment.AwsAccountId, region=_environment.region
)
vpc_flow_role = iam.Role(
setup, 'FlowLog',
assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com')
)
vpc = ec2.Vpc(
setup,
"SageMakerVPC",
max_azs=3,
cidr="10.10.0.0/16",
subnet_configuration=[
ec2.SubnetConfiguration(
subnet_type=ec2.SubnetType.PUBLIC,
name="Public",
cidr_mask=24
),
ec2.SubnetConfiguration(
subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT,
name="Private",
cidr_mask=24
),
],
enable_dns_hostnames=True,
enable_dns_support=True,
)
ec2.FlowLog(
setup, "StudioVPCFlowLog",
resource_type=ec2.FlowLogResourceType.from_vpc(vpc),
destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role)
)
# setup security group to be used for sagemaker studio domain
sagemaker_sg = ec2.SecurityGroup(
setup,
"SecurityGroup",
vpc=vpc,
description="Security Group for SageMaker Studio",
security_group_name=domain.sagemakerStudioDomainName,
existing_default_vpc = EC2.check_default_vpc_exists(
AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn
)
if existing_default_vpc:
logger.info("Using default VPC for Sagemaker Studio domain")
# Use default VPC - initial configuration (to be migrated)
vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True)
subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets]
subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets]
security_groups = []
else:
logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...")
# Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways
log_group = logs.LogGroup(
setup,
f'SageMakerStudio{_environment.name}',
log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio',
retention=logs.RetentionDays.ONE_MONTH,
removal_policy=RemovalPolicy.DESTROY,
)
vpc_flow_role = iam.Role(
setup, 'FlowLog',
assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com')
)
vpc = ec2.Vpc(
setup,
"SageMakerVPC",
max_azs=3,
cidr="10.10.0.0/16",
subnet_configuration=[
ec2.SubnetConfiguration(
subnet_type=ec2.SubnetType.PUBLIC,
name="Public",
cidr_mask=24
),
ec2.SubnetConfiguration(
subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT,
name="Private",
cidr_mask=24
),
],
enable_dns_hostnames=True,
enable_dns_support=True,
)
ec2.FlowLog(
setup, "StudioVPCFlowLog",
resource_type=ec2.FlowLogResourceType.from_vpc(vpc),
destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role)
)
# setup security group to be used for sagemaker studio domain
sagemaker_sg = ec2.SecurityGroup(
setup,
"SecurityGroup",
vpc=vpc,
description="Security Group for SageMaker Studio",
security_group_name=domain.sagemakerStudioDomainName,
)

sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
security_groups = [sagemaker_sg.security_group_id]
subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]

sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
security_groups = [sagemaker_sg.security_group_id]
subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
vpc_id = vpc.vpc_id
vpc_id = vpc.vpc_id

sagemaker_domain_role = iam.Role(
setup,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,14 @@ def update_env(session, environment, **kwargs):
current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled")
domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri)
previous_mlstudio_enabled = True if domain else False
vpcType = domain.vpcType
if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled):
SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri)
return True
elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled):
SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs)
return True
elif current_mlstudio_enabled:
elif current_mlstudio_enabled and vpcType == "unknown":
SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs)
return True
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Environment(Resource, Base):
environmentUri = Column(String, primary_key=True)
AwsAccountId = Column(Boolean)
region = Column(Boolean)
SamlGroupName = Column(String)


class EnvironmentParameter(Base):
Expand Down
111 changes: 54 additions & 57 deletions frontend/src/modules/Environments/components/EnvironmentMLStudio.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export const EnvironmentMLStudio = ({ environment }) => {
<Card>
<CardHeader
action={<RefreshTableMenu refresh={fetchMLStudioDomain} />}
title={<Box>ML Studio Domain</Box>}
title={<Box>ML Studio Domain Information</Box>}
/>
<Divider />
<Box
Expand Down Expand Up @@ -90,62 +90,59 @@ export const EnvironmentMLStudio = ({ environment }) => {
) : (
<Grid container spacing={3}>
<Grid item lg={8} xl={9} xs={12}>
<Card>
<CardHeader title="ML Studio Information" />
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
SageMaker ML Studio Domain Name
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.sagemakerStudioDomainName}
</Typography>
</CardContent>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
SageMaker ML Studio Default Execution Role
</Typography>
<Typography color="textPrimary" variant="body2">
arn:aws:iam::{environment.AwsAccountId}:role/
{mlStudioDomain.DefaultDomainRoleName}
</Typography>
</CardContent>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
Domain VPC Type
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.vpcType}
</Typography>
</CardContent>
{(mlStudioDomain.vpcId === 'imported' ||
mlStudioDomain.vpcId === 'default') && (
<>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
Domain VPC Id
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.vpcId}
</Typography>
</CardContent>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
Domain Subnet Ids
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.subnetIds?.map((subnet) => (
<Chip
sx={{ mr: 0.5, mb: 0.5 }}
key={subnet}
label={subnet}
variant="outlined"
/>
))}
</Typography>
</CardContent>
</>
)}
</Card>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
SageMaker ML Studio Domain Name
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.sagemakerStudioDomainName}
</Typography>
</CardContent>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
SageMaker ML Studio Default Execution Role
</Typography>
<Typography color="textPrimary" variant="body2">
arn:aws:iam::{environment.AwsAccountId}:role/
{mlStudioDomain.DefaultDomainRoleName}
</Typography>
</CardContent>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
Domain VPC Type
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.vpcType}
</Typography>
</CardContent>
{(mlStudioDomain.vpcType === 'imported' ||
mlStudioDomain.vpcType === 'default') && (
<>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
Domain VPC Id
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.vpcId}
</Typography>
</CardContent>
<CardContent>
<Typography color="textSecondary" variant="subtitle2">
Domain Subnet Ids
</Typography>
<Typography color="textPrimary" variant="body2">
{mlStudioDomain.subnetIds?.map((subnet) => (
<Chip
sx={{ mr: 0.5, mb: 0.5 }}
key={subnet}
label={subnet}
variant="outlined"
/>
))}
</Typography>
</CardContent>
</>
)}
</Grid>
</Grid>
)}
Expand Down

0 comments on commit 9417ab9

Please sign in to comment.