From 9417ab9d020becf133e5a804b27e73ef66355f62 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 14:35:58 -0500 Subject: [PATCH] Final fixes --- .../mlstudio/cdk/mlstudio_extension.py | 116 ++++++++++-------- .../mlstudio/services/mlstudio_service.py | 3 +- ...f5de322f_update_sagemaker_studio_domain.py | 1 + .../components/EnvironmentMLStudio.js | 111 ++++++++--------- 4 files changed, 124 insertions(+), 107 deletions(-) diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 495cad19b..1a0da209b 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -14,8 +14,10 @@ ) from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService +from dataall.base.aws.ec2_client import EC2 logger = logging.getLogger(__name__) @@ -36,61 +38,77 @@ def extent(setup: EnvironmentSetup): if domain.vpcId and domain.subnetIds: logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') - vpc_id = domain.vpcId + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) subnet_ids = domain.subnetIds security_groups = [] else: - logger.info("VPC not specified or not found, Exception. Creating a VPC for SageMaker resources...") - # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways - log_group = logs.LogGroup( - setup, - f'SageMakerStudio{_environment.name}', - log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', - retention=logs.RetentionDays.ONE_MONTH, - removal_policy=RemovalPolicy.DESTROY, + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=_environment.AwsAccountId, region=_environment.region ) - vpc_flow_role = iam.Role( - setup, 'FlowLog', - assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') - ) - vpc = ec2.Vpc( - setup, - "SageMakerVPC", - max_azs=3, - cidr="10.10.0.0/16", - subnet_configuration=[ - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PUBLIC, - name="Public", - cidr_mask=24 - ), - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, - name="Private", - cidr_mask=24 - ), - ], - enable_dns_hostnames=True, - enable_dns_support=True, - ) - ec2.FlowLog( - setup, "StudioVPCFlowLog", - resource_type=ec2.FlowLogResourceType.from_vpc(vpc), - destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) - ) - # setup security group to be used for sagemaker studio domain - sagemaker_sg = ec2.SecurityGroup( - setup, - "SecurityGroup", - vpc=vpc, - description="Security Group for SageMaker Studio", - security_group_name=domain.sagemakerStudioDomainName, + existing_default_vpc = EC2.check_default_vpc_exists( + AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn ) + if existing_default_vpc: + logger.info("Using default VPC for Sagemaker Studio domain") + # Use default VPC - initial configuration (to be migrated) + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] + subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + security_groups = [] + else: + logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") + # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways + log_group = logs.LogGroup( + setup, + f'SageMakerStudio{_environment.name}', + log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', + retention=logs.RetentionDays.ONE_MONTH, + removal_policy=RemovalPolicy.DESTROY, + ) + vpc_flow_role = iam.Role( + setup, 'FlowLog', + assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') + ) + vpc = ec2.Vpc( + setup, + "SageMakerVPC", + max_azs=3, + cidr="10.10.0.0/16", + subnet_configuration=[ + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PUBLIC, + name="Public", + cidr_mask=24 + ), + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, + name="Private", + cidr_mask=24 + ), + ], + enable_dns_hostnames=True, + enable_dns_support=True, + ) + ec2.FlowLog( + setup, "StudioVPCFlowLog", + resource_type=ec2.FlowLogResourceType.from_vpc(vpc), + destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + ) + # setup security group to be used for sagemaker studio domain + sagemaker_sg = ec2.SecurityGroup( + setup, + "SecurityGroup", + vpc=vpc, + description="Security Group for SageMaker Studio", + security_group_name=domain.sagemakerStudioDomainName, + ) + + sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) + security_groups = [sagemaker_sg.security_group_id] + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) - security_groups = [sagemaker_sg.security_group_id] - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - vpc_id = vpc.vpc_id + vpc_id = vpc.vpc_id sagemaker_domain_role = iam.Role( setup, diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 729c3971e..75328d530 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -75,13 +75,14 @@ def update_env(session, environment, **kwargs): current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) previous_mlstudio_enabled = True if domain else False + vpcType = domain.vpcType if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled): SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) return True elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled): SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) return True - elif current_mlstudio_enabled: + elif current_mlstudio_enabled and vpcType == "unknown": SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs) return True return False diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 298a481ce..28bf39957 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -29,6 +29,7 @@ class Environment(Resource, Base): environmentUri = Column(String, primary_key=True) AwsAccountId = Column(Boolean) region = Column(Boolean) + SamlGroupName = Column(String) class EnvironmentParameter(Base): diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 79427116f..44dac97b9 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -61,7 +61,7 @@ export const EnvironmentMLStudio = ({ environment }) => { } - title={ML Studio Domain} + title={ML Studio Domain Information} /> { ) : ( - - - - - SageMaker ML Studio Domain Name - - - {mlStudioDomain.sagemakerStudioDomainName} - - - - - SageMaker ML Studio Default Execution Role - - - arn:aws:iam::{environment.AwsAccountId}:role/ - {mlStudioDomain.DefaultDomainRoleName} - - - - - Domain VPC Type - - - {mlStudioDomain.vpcType} - - - {(mlStudioDomain.vpcId === 'imported' || - mlStudioDomain.vpcId === 'default') && ( - <> - - - Domain VPC Id - - - {mlStudioDomain.vpcId} - - - - - Domain Subnet Ids - - - {mlStudioDomain.subnetIds?.map((subnet) => ( - - ))} - - - - )} - + + + SageMaker ML Studio Domain Name + + + {mlStudioDomain.sagemakerStudioDomainName} + + + + + SageMaker ML Studio Default Execution Role + + + arn:aws:iam::{environment.AwsAccountId}:role/ + {mlStudioDomain.DefaultDomainRoleName} + + + + + Domain VPC Type + + + {mlStudioDomain.vpcType} + + + {(mlStudioDomain.vpcType === 'imported' || + mlStudioDomain.vpcType === 'default') && ( + <> + + + Domain VPC Id + + + {mlStudioDomain.vpcId} + + + + + Domain Subnet Ids + + + {mlStudioDomain.subnetIds?.map((subnet) => ( + + ))} + + + + )} )}