From 045f701edc417eec74e65158f50ee9b5c42059a9 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Tue, 21 Nov 2023 11:54:44 -0500 Subject: [PATCH 01/38] Start work on byo vpc ml studio --- .../core/environment/api/input_types.py | 10 ++-- .../views/EnvironmentCreateForm.js | 55 +++++++++++++++++-- frontend/yarn.lock | 23 +++++--- 3 files changed, 69 insertions(+), 19 deletions(-) diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py index 13f0d6df7..767b5812b 100644 --- a/backend/dataall/core/environment/api/input_types.py +++ b/backend/dataall/core/environment/api/input_types.py @@ -28,9 +28,8 @@ gql.Argument('description', gql.String), gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)), gql.Argument('region', gql.NonNullableType(gql.String)), - gql.Argument('vpcId', gql.String), - gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)), - gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), + gql.Argument('mlStudioVPCId', gql.String), + gql.Argument('mlStudioSubnetId', gql.ArrayType(gql.String)), gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) @@ -45,9 +44,8 @@ gql.Argument('description', gql.String), gql.Argument('tags', gql.ArrayType(gql.String)), gql.Argument('SamlGroupName', gql.String), - gql.Argument('vpcId', gql.String), - gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)), - gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), + gql.Argument('mlStudioVPCId', gql.String), + gql.Argument('mlStudioSubnetId', gql.ArrayType(gql.String)), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) ], diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index d7c5eb38d..a803784b9 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -482,9 +482,11 @@ const EnvironmentCreateForm = (props) => { dashboardsEnabled: isModuleEnabled(ModuleNames.DASHBOARDS), notebooksEnabled: isModuleEnabled(ModuleNames.NOTEBOOKS), mlStudiosEnabled: isModuleEnabled(ModuleNames.MLSTUDIO), - pipelinesEnabled: isModuleEnabled(ModuleNames.PIPELINES), + pipelinesEnabled: isModuleEnabled(ModuleNames.DATAPIPELINES), EnvironmentDefaultIAMRoleArn: '', - resourcePrefix: 'dataall' + resourcePrefix: 'dataall', + mlStudioVPCId: '', + mlStudioSubnetId: '' }} validationSchema={Yup.object().shape({ label: Yup.string() @@ -508,9 +510,8 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), - privateSubnetIds: Yup.array().nullable(), - publicSubnetIds: Yup.array().nullable(), - vpcId: Yup.string().nullable(), + mlStudioSubnetId: Yup.array().nullable(), + mlStudioVPCId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() @@ -860,6 +861,50 @@ const EnvironmentCreateForm = (props) => { variant="outlined" /> + {values.mlStudiosEnabled && ( + <> + + + + + + + + )} {errors.submit && ( diff --git a/frontend/yarn.lock b/frontend/yarn.lock index e68c1974e..8a2e6c66d 100644 --- a/frontend/yarn.lock +++ b/frontend/yarn.lock @@ -7,7 +7,7 @@ resolved "https://registry.npmjs.org/@aashutoshrathi/word-wrap/-/word-wrap-1.2.6.tgz" integrity sha512-1Yjs2SvM8TflER/OD3cOjhWWOZb58A2t7wpE2S9XfBYTiIl+XFhQG2bjy4Pu1I+EAlCNUzRDYDdFwFYUKvXcIA== -"@adobe/css-tools@4.3.1", "@adobe/css-tools@^4.0.1": +"@adobe/css-tools@^4.0.1": version "4.3.1" resolved "https://registry.yarnpkg.com/@adobe/css-tools/-/css-tools-4.3.1.tgz#abfccb8ca78075a2b6187345c26243c1a0842f28" integrity sha512-/62yikz7NLScCGAAST5SHdnjaDJQBDq0M2muyRTpf2VQhw6StBg2ALiu73zSJQ4fMVLA+0uBhBHAle7Wg+2kSg== @@ -3074,7 +3074,7 @@ "@babel/parser" "^7.22.5" "@babel/types" "^7.22.5" -"@babel/traverse@7.23.2", "@babel/traverse@^7.22.5", "@babel/traverse@^7.22.6", "@babel/traverse@^7.22.8", "@babel/traverse@^7.7.2": +"@babel/traverse@^7.22.5", "@babel/traverse@^7.22.6", "@babel/traverse@^7.22.8", "@babel/traverse@^7.7.2": version "7.23.2" resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8" integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw== @@ -5694,7 +5694,7 @@ bonjour-service@^1.0.11: fast-deep-equal "^3.1.3" multicast-dns "^7.2.5" -boolbase@^1.0.0: +boolbase@^1.0.0, boolbase@~1.0.0: version "1.0.0" resolved "https://registry.npmjs.org/boolbase/-/boolbase-1.0.0.tgz" integrity sha512-JZOSA7Mo9sNGB8+UjSgzdLtokWAky1zbztM3WRLCbZ70/3cTANmQmOdR7y2g+J0e2WXywy1yS468tY+IruqEww== @@ -9469,10 +9469,10 @@ merge2@^1.3.0, merge2@^1.4.1: resolved "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz" integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== -merge@2.1.1, merge@^1.2.0: - version "2.1.1" - resolved "https://registry.yarnpkg.com/merge/-/merge-2.1.1.tgz#59ef4bf7e0b3e879186436e8481c06a6c162ca98" - integrity sha512-jz+Cfrg9GWOZbQAnDQ4hlVnQky+341Yk5ru8bZSe6sIDTCIg8n9i/u7hSQGSVOF3C7lH6mGtqjkiT9G4wFLL0w== +merge@^1.2.0: + version "1.2.1" + resolved "https://registry.yarnpkg.com/merge/-/merge-1.2.1.tgz#38bebf80c3220a8a487b6fcfb3941bb11720c145" + integrity sha512-VjFo4P5Whtj4vsLzsYBu5ayHhoHJ0UqNm7ibvShmbmoz7tGi0vXaoJbGdB+GmDMLUdg8DpQXEIeVDAe8MaABvQ== methods@~1.1.2: version "1.1.2" @@ -9684,7 +9684,14 @@ nprogress@^0.2.0: resolved "https://registry.npmjs.org/nprogress/-/nprogress-0.2.0.tgz" integrity sha512-I19aIingLgR1fmhftnbWWO3dXc0hSxqHQHQb3H8m+K3TnEn/iSeTZZOyvKXWqQESMwuUVnatlCnZdLBZZt2VSA== -nth-check@^1.0.2, nth-check@^2.0.1: +nth-check@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/nth-check/-/nth-check-1.0.2.tgz#b2bd295c37e3dd58a3bf0700376663ba4d9cf05c" + integrity sha512-WeBOdju8SnzPN5vTUJYxYUxLeXpCaVP5i5e0LF8fg7WORF2Wd7wFX/pk0tYZk7s8T+J7VLy0Da6J1+wCT0AtHg== + dependencies: + boolbase "~1.0.0" + +nth-check@^2.0.1: version "2.1.1" resolved "https://registry.npmjs.org/nth-check/-/nth-check-2.1.1.tgz" integrity sha512-lqjrjmaOoAnWfMmBPL+XNnynZh2+swxiX3WUE0s4yEHI6m+AwrK2UZOimIRl3X/4QctVqS8AiZjFqyOGrMXb/w== From 86de12324bef1e7b41a7e4f7ef6f2d846cb97dbd Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Fri, 24 Nov 2023 10:13:56 -0500 Subject: [PATCH 02/38] Add Payload --- .../dataall/core/environment/api/resolvers.py | 6 ++++- .../services/environment_service.py | 23 ------------------- .../views/EnvironmentCreateForm.js | 1 + 3 files changed, 6 insertions(+), 24 deletions(-) diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index c42c4895a..666ddfd02 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -75,7 +75,7 @@ def check_environment(context: Context, source, account_id, region): return cdk_role_name -def create_environment(context: Context, source, input=None): +def create_environment(context: Context, source, input={}): if input.get('SamlGroupName') and input.get('SamlGroupName') not in context.groups: raise exceptions.UnauthorizedOperation( action=permissions.LINK_ENVIRONMENT, @@ -99,6 +99,10 @@ def create_environment(context: Context, source, input=None): target_type='environment', target_uri=env.environmentUri, target_label=env.label, + payload={ + 'mlstudio_vpc_id': input.get('mlStudioVPCId', None), + 'mlstudio_vpc_id': input.get('mlStudioSubnetId', None), + }, ) stack_helper.deploy_stack(targetUri=env.environmentUri) env.userRoleInEnvironment = EnvironmentPermission.Owner.value diff --git a/backend/dataall/core/environment/services/environment_service.py b/backend/dataall/core/environment/services/environment_service.py index db3cbf050..2862bfe41 100644 --- a/backend/dataall/core/environment/services/environment_service.py +++ b/backend/dataall/core/environment/services/environment_service.py @@ -98,29 +98,6 @@ def create_environment(session, uri, data=None): env.EnvironmentDefaultIAMRoleArn = data['EnvironmentDefaultIAMRoleArn'] env.EnvironmentDefaultIAMRoleImported = True - if data.get('vpcId'): - vpc = Vpc( - environmentUri=env.environmentUri, - region=env.region, - AwsAccountId=env.AwsAccountId, - VpcId=data.get('vpcId'), - privateSubnetIds=data.get('privateSubnetIds', []), - publicSubnetIds=data.get('publicSubnetIds', []), - SamlGroupName=data['SamlGroupName'], - owner=context.username, - label=f"{env.name}-{data.get('vpcId')}", - name=f"{env.name}-{data.get('vpcId')}", - default=True, - ) - session.add(vpc) - session.commit() - ResourcePolicy.attach_resource_policy( - session=session, - group=data['SamlGroupName'], - permissions=permissions.NETWORK_ALL, - resource_uri=vpc.vpcUri, - resource_type=Vpc.__name__, - ) env_group = EnvironmentGroup( environmentUri=env.environmentUri, groupUri=data['SamlGroupName'], diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index a803784b9..5cfb1ac88 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -863,6 +863,7 @@ const EnvironmentCreateForm = (props) => { {values.mlStudiosEnabled && ( <> + Date: Tue, 28 Nov 2023 17:35:08 -0500 Subject: [PATCH 03/38] Add backend logic deploy mlstudio custom vpc --- backend/dataall/base/aws/vpc.py | 46 ++++++ backend/dataall/base/cdkproxy/app.py | 6 +- .../dataall/base/cdkproxy/cdk_cli_wrapper.py | 3 +- .../core/environment/api/input_types.py | 4 +- .../dataall/core/environment/api/resolvers.py | 24 +++- .../core/environment/cdk/environment_stack.py | 3 +- .../mlstudio/cdk/mlstudio_extension.py | 133 ++++++++++-------- .../views/EnvironmentCreateForm.js | 31 ++-- .../Environments/views/EnvironmentEditForm.js | 52 +++++++ 9 files changed, 219 insertions(+), 83 deletions(-) create mode 100644 backend/dataall/base/aws/vpc.py diff --git a/backend/dataall/base/aws/vpc.py b/backend/dataall/base/aws/vpc.py new file mode 100644 index 000000000..80875c578 --- /dev/null +++ b/backend/dataall/base/aws/vpc.py @@ -0,0 +1,46 @@ +import logging + +from botocore.exceptions import ClientError + +from .sts import SessionHelper + +log = logging.getLogger(__name__) + + +class VPCManager: + def __init__(self): + pass + + @staticmethod + def client(AwsAccountId, region, role=None): + session = SessionHelper.remote_session(accountid=AwsAccountId, role=role) + return session.client('ec2', region_name=region) + + @staticmethod + def check_vpc_exists( AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): + try: + ec2 = VPCManager.client(AwsAccountId=AwsAccountId, region=region, role=role) + response = ec2.describe_vpcs(VpcIds=[vpc_id]) + except ClientError as e: + log.exception(f'VPC Id {vpc_id} Not Found: {e}') + raise Exception(f'VPCNotFound: {vpc_id}') + + try: + if subnet_ids: + response = ec2.describe_subnets( + Filters=[ + { + 'Name': 'vpc-id', + 'Values': [vpc_id] + }, + ], + SubnetIds=subnet_ids + ) + except ClientError as e: + log.exception(f'Subnet Id {subnet_ids} Not Found: {e}') + raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') + + if not subnet_ids or len(response['Subnets']) != len(subnet_ids): + raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') + + \ No newline at end of file diff --git a/backend/dataall/base/cdkproxy/app.py b/backend/dataall/base/cdkproxy/app.py index 080295841..d20a01648 100644 --- a/backend/dataall/base/cdkproxy/app.py +++ b/backend/dataall/base/cdkproxy/app.py @@ -48,10 +48,10 @@ def create(): logger.info(f' **kwargs: {_data}') if _data: data = json.loads(_data) - # logger.info(f" Kwargs: {_data}") + logger.info(f" Kwargs: {_data}") else: data = {} - # logger.info(f" Kwargs: None provided") + logger.info(f" Kwargs: None provided") # Creating CDK target environment env = Environment(account=account, region=region) @@ -60,7 +60,7 @@ def create(): tbl = tabulate(table, headers=['Setting', 'Value']) # , tablefmt="fancy_grid") logger.info(tbl) - instanciate_stack(stack_name, app, appid, env=env, target_uri=target_uri) + instanciate_stack(stack_name, app, appid, env=env, target_uri=target_uri, payload=data) app.synth() diff --git a/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py b/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py index e12454ba1..467854a60 100644 --- a/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py +++ b/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py @@ -6,6 +6,7 @@ import ast import logging import os +import json import subprocess import sys from abc import abstractmethod @@ -164,7 +165,7 @@ def deploy_cdk_stack(engine: Engine, stackid: str, app_path: str = None, path: s '-c', f"target_uri='{stack.targetUri}'", '-c', - "data='{}'", + f"data='{json.dumps(stack.payload)}'", # skips synth step when no changes apply '--app', f'"{sys.executable} {app_path}"', diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py index 767b5812b..28b6078b6 100644 --- a/backend/dataall/core/environment/api/input_types.py +++ b/backend/dataall/core/environment/api/input_types.py @@ -29,7 +29,7 @@ gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)), gql.Argument('region', gql.NonNullableType(gql.String)), gql.Argument('mlStudioVPCId', gql.String), - gql.Argument('mlStudioSubnetId', gql.ArrayType(gql.String)), + gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)), gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) @@ -45,7 +45,7 @@ gql.Argument('tags', gql.ArrayType(gql.String)), gql.Argument('SamlGroupName', gql.String), gql.Argument('mlStudioVPCId', gql.String), - gql.Argument('mlStudioSubnetId', gql.ArrayType(gql.String)), + gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) ], diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index 666ddfd02..05f525828 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -8,6 +8,7 @@ from sqlalchemy import and_, exc from dataall.base.aws.iam import IAM +from dataall.base.aws.vpc import VPCManager from dataall.base.aws.parameter_store import ParameterStoreManager from dataall.base.aws.sts import SessionHelper from dataall.base.utils import Parameter @@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs): return True if ssm_param == "True" else False -def check_environment(context: Context, source, account_id, region): +def check_environment(context: Context, source, account_id, region, data): """ Checks necessary resources for environment deployment. - Check CDKToolkit exists in Account assuming cdk_look_up_role - Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False @@ -71,6 +72,21 @@ def check_environment(context: Context, source, account_id, region): action='CHECK_PIVOT_ROLE', message='Pivot Role has not been created in the Environment AWS Account', ) + mlStudioEnabled = None + for parameter in data.get("parameters", []): + if parameter['key'] == 'mlStudiosEnabled': + mlStudioEnabled = parameter['value'] + + if mlStudioEnabled == 'true' and data.get("mlStudioVPCId", None): + log.info("Check if ML Studio VPC Exists in the Account") + + VPCManager.check_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + vpc_id=data.get("mlStudioVPCId", None), + subnet_ids=data.get('mlStudioSubnetIds', []), + ) return cdk_role_name @@ -85,8 +101,10 @@ def create_environment(context: Context, source, input={}): with context.engine.scoped_session() as session: cdk_role_name = check_environment(context, source, account_id=input.get('AwsAccountId'), - region=input.get('region') + region=input.get('region'), + data=input ) + input['cdk_role_name'] = cdk_role_name env = EnvironmentService.create_environment( session=session, @@ -101,7 +119,7 @@ def create_environment(context: Context, source, input={}): target_label=env.label, payload={ 'mlstudio_vpc_id': input.get('mlStudioVPCId', None), - 'mlstudio_vpc_id': input.get('mlStudioSubnetId', None), + 'mlstudio_subnet_ids': input.get('mlStudioSubnetIds', []), }, ) stack_helper.deploy_stack(targetUri=env.environmentUri) diff --git a/backend/dataall/core/environment/cdk/environment_stack.py b/backend/dataall/core/environment/cdk/environment_stack.py index ce0c67215..b7657b6d2 100644 --- a/backend/dataall/core/environment/cdk/environment_stack.py +++ b/backend/dataall/core/environment/cdk/environment_stack.py @@ -105,7 +105,7 @@ def get_environment_admins_group(engine, environment: Environment) -> [Environme group_uri=environment.SamlGroupName, ) - def __init__(self, scope, id, target_uri: str = None, **kwargs): + def __init__(self, scope, id, target_uri: str = None, payload: dict = {}, **kwargs): super().__init__( scope, id, @@ -117,6 +117,7 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): **kwargs, ) # Read input + self.payload = payload self.target_uri = target_uri self.pivot_role_name = SessionHelper.get_delegation_role_name() self.external_id = SessionHelper.get_external_id_secret() diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index fe9040ab9..0112faaf8 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -36,70 +36,85 @@ def extent(setup: EnvironmentSetup): sagemaker_principals = [setup.default_role] + setup.group_roles logger.info(f'Creating SageMaker base resources for sagemaker_principals = {sagemaker_principals}..') - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=_environment.AwsAccountId, region=_environment.region - ) - existing_default_vpc = EC2.check_default_vpc_exists( - AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn - ) - if existing_default_vpc: - logger.info("Using default VPC for Sagemaker Studio domain") - # Use default VPC - initial configuration (to be migrated) - vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] - subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + + existing_vpc_id = None + existing_subnet_ids = None + if setup.payload: + existing_vpc_id = setup.payload.get('mlstudio_vpc_id', None) + existing_subnet_ids = setup.payload.get('mlstudio_subnet_ids', []) + logger.info(f'VPC ID = {existing_vpc_id}') + logger.info(f'Subnet IDs = {existing_subnet_ids}') + + if existing_vpc_id and existing_subnet_ids: + logger.info(f'Using VPC {existing_vpc_id} and subnets {existing_subnet_ids} for SageMaker Studio domain') + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=existing_vpc_id) + subnet_ids = existing_subnet_ids security_groups = [] else: - logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") - # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways - log_group = logs.LogGroup( - setup, - f'SageMakerStudio{_environment.name}', - log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', - retention=logs.RetentionDays.ONE_MONTH, - removal_policy=RemovalPolicy.DESTROY, - ) - vpc_flow_role = iam.Role( - setup, 'FlowLog', - assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') - ) - vpc = ec2.Vpc( - setup, - "SageMakerVPC", - max_azs=3, - cidr="10.10.0.0/16", - subnet_configuration=[ - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PUBLIC, - name="Public", - cidr_mask=24 - ), - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, - name="Private", - cidr_mask=24 - ), - ], - enable_dns_hostnames=True, - enable_dns_support=True, - ) - ec2.FlowLog( - setup, "StudioVPCFlowLog", - resource_type=ec2.FlowLogResourceType.from_vpc(vpc), - destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=_environment.AwsAccountId, region=_environment.region ) - # setup security group to be used for sagemaker studio domain - sagemaker_sg = ec2.SecurityGroup( - setup, - "SecurityGroup", - vpc=vpc, - description="Security Group for SageMaker Studio", + existing_default_vpc = EC2.check_default_vpc_exists( + AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn ) + if existing_default_vpc: + logger.info("Using default VPC for Sagemaker Studio domain") + # Use default VPC - initial configuration (to be migrated) + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] + subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + security_groups = [] + else: + logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") + # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways + log_group = logs.LogGroup( + setup, + f'SageMakerStudio{_environment.name}', + log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', + retention=logs.RetentionDays.ONE_MONTH, + removal_policy=RemovalPolicy.DESTROY, + ) + vpc_flow_role = iam.Role( + setup, 'FlowLog', + assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') + ) + vpc = ec2.Vpc( + setup, + "SageMakerVPC", + max_azs=3, + cidr="10.10.0.0/16", + subnet_configuration=[ + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PUBLIC, + name="Public", + cidr_mask=24 + ), + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, + name="Private", + cidr_mask=24 + ), + ], + enable_dns_hostnames=True, + enable_dns_support=True, + ) + ec2.FlowLog( + setup, "StudioVPCFlowLog", + resource_type=ec2.FlowLogResourceType.from_vpc(vpc), + destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + ) + # setup security group to be used for sagemaker studio domain + sagemaker_sg = ec2.SecurityGroup( + setup, + "SecurityGroup", + vpc=vpc, + description="Security Group for SageMaker Studio", + ) - sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) - security_groups = [sagemaker_sg.security_group_id] - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) + security_groups = [sagemaker_sg.security_group_id] + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] vpc_id = vpc.vpc_id diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index 5cfb1ac88..e892274df 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -179,6 +179,8 @@ const EnvironmentCreateForm = (props) => { region: values.region, EnvironmentDefaultIAMRoleArn: values.EnvironmentDefaultIAMRoleArn, resourcePrefix: values.resourcePrefix, + mlStudioVPCId: values.mlStudioVPCId, + mlStudioSubnetIds: values.mlStudioSubnetIds, parameters: [ { key: 'notebooksEnabled', @@ -486,7 +488,7 @@ const EnvironmentCreateForm = (props) => { EnvironmentDefaultIAMRoleArn: '', resourcePrefix: 'dataall', mlStudioVPCId: '', - mlStudioSubnetId: '' + mlStudioSubnetIds: [] }} validationSchema={Yup.object().shape({ label: Yup.string() @@ -510,7 +512,10 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), - mlStudioSubnetId: Yup.array().nullable(), + mlStudioSubnetIds: Yup.array().when('mlStudioVPCId',{ + is: (value) => !!value, + then: Yup.array().min(1).required('At least 1 Subnet Id required if VPC Id specified') + }), mlStudioVPCId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), resourcePrefix: Yup.string() @@ -884,24 +889,22 @@ const EnvironmentCreateForm = (props) => { /> - { + setFieldValue('mlStudioSubnetIds', [...chip]); + }} /> diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index caa5d8441..85665568e 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -81,6 +81,8 @@ const EnvironmentEditForm = (props) => { tags: values.tags, description: values.description, resourcePrefix: values.resourcePrefix, + mlStudioVPCId: values.mlStudioVPCId, + mlStudioSubnetIds: values.mlStudioSubnetIds, parameters: [ { key: 'notebooksEnabled', @@ -213,6 +215,8 @@ const EnvironmentEditForm = (props) => { label: env.label, description: env.description, tags: env.tags || [], + mlStudioVPCId: '', + mlStudioSubnetIds: [], notebooksEnabled: env.parameters['notebooksEnabled'] === 'true', mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true', pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true', @@ -226,6 +230,11 @@ const EnvironmentEditForm = (props) => { .required('*Environment name is required'), description: Yup.string().max(5000), tags: Yup.array().nullable(), + mlStudioSubnetIds: Yup.array().when('mlStudioVPCId',{ + is: (value) => !!value, + then: Yup.array().min(1).required('At least 1 Subnet Id required if VPC Id specified') + }), + mlStudioVPCId: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() .matches( @@ -381,6 +390,49 @@ const EnvironmentEditForm = (props) => { variant="outlined" /> + {values.mlStudiosEnabled && ( + <> + + + + + + { + setFieldValue('mlStudioSubnetIds', [...chip]); + }} + /> + + + )} {isAnyEnvironmentModuleEnabled() && ( From b717f25dceafae940b6f54cc5b86cdefb95d2608 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Tue, 28 Nov 2023 17:40:00 -0500 Subject: [PATCH 04/38] lint --- backend/dataall/base/aws/vpc.py | 4 +--- backend/dataall/base/cdkproxy/app.py | 2 +- .../dataall/core/environment/api/resolvers.py | 2 +- .../views/EnvironmentCreateForm.js | 18 +++++++++++++----- .../Environments/views/EnvironmentEditForm.js | 19 ++++++++++++++----- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/backend/dataall/base/aws/vpc.py b/backend/dataall/base/aws/vpc.py index 80875c578..d543cac8d 100644 --- a/backend/dataall/base/aws/vpc.py +++ b/backend/dataall/base/aws/vpc.py @@ -17,7 +17,7 @@ def client(AwsAccountId, region, role=None): return session.client('ec2', region_name=region) @staticmethod - def check_vpc_exists( AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): + def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): try: ec2 = VPCManager.client(AwsAccountId=AwsAccountId, region=region, role=role) response = ec2.describe_vpcs(VpcIds=[vpc_id]) @@ -42,5 +42,3 @@ def check_vpc_exists( AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): if not subnet_ids or len(response['Subnets']) != len(subnet_ids): raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') - - \ No newline at end of file diff --git a/backend/dataall/base/cdkproxy/app.py b/backend/dataall/base/cdkproxy/app.py index d20a01648..b02760e06 100644 --- a/backend/dataall/base/cdkproxy/app.py +++ b/backend/dataall/base/cdkproxy/app.py @@ -51,7 +51,7 @@ def create(): logger.info(f" Kwargs: {_data}") else: data = {} - logger.info(f" Kwargs: None provided") + logger.info(" Kwargs: None provided") # Creating CDK target environment env = Environment(account=account, region=region) diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index 05f525828..205edd5ac 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -84,7 +84,7 @@ def check_environment(context: Context, source, account_id, region, data): AwsAccountId=account_id, region=region, role=cdk_look_up_role_arn, - vpc_id=data.get("mlStudioVPCId", None), + vpc_id=data.get("mlStudioVPCId", None), subnet_ids=data.get('mlStudioSubnetIds', []), ) diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index e892274df..451aea7c2 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -512,9 +512,13 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), - mlStudioSubnetIds: Yup.array().when('mlStudioVPCId',{ + mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { is: (value) => !!value, - then: Yup.array().min(1).required('At least 1 Subnet Id required if VPC Id specified') + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) }), mlStudioVPCId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), @@ -877,10 +881,12 @@ const EnvironmentCreateForm = (props) => { name="mlStudioVPCId" fullWidth error={Boolean( - touched.mlStudioVPCId && errors.mlStudioVPCId + touched.mlStudioVPCId && + errors.mlStudioVPCId )} helperText={ - touched.mlStudioVPCId && errors.mlStudioVPCId + touched.mlStudioVPCId && + errors.mlStudioVPCId } onBlur={handleBlur} onChange={handleChange} @@ -903,7 +909,9 @@ const EnvironmentCreateForm = (props) => { label="(Optional) ML Studio Subnet ID(s)" placeholder="(Optional) Bring your own VPC - Specify Subnet ID (Hit enter after typing value)" onChange={(chip) => { - setFieldValue('mlStudioSubnetIds', [...chip]); + setFieldValue('mlStudioSubnetIds', [ + ...chip + ]); }} /> diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index 85665568e..dd98be682 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -8,6 +8,7 @@ import { CardHeader, CircularProgress, Container, + Divider, FormControlLabel, FormGroup, FormHelperText, @@ -230,9 +231,13 @@ const EnvironmentEditForm = (props) => { .required('*Environment name is required'), description: Yup.string().max(5000), tags: Yup.array().nullable(), - mlStudioSubnetIds: Yup.array().when('mlStudioVPCId',{ + mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { is: (value) => !!value, - then: Yup.array().min(1).required('At least 1 Subnet Id required if VPC Id specified') + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) }), mlStudioVPCId: Yup.string().nullable(), resourcePrefix: Yup.string() @@ -401,10 +406,12 @@ const EnvironmentEditForm = (props) => { name="mlStudioVPCId" fullWidth error={Boolean( - touched.mlStudioVPCId && errors.mlStudioVPCId + touched.mlStudioVPCId && + errors.mlStudioVPCId )} helperText={ - touched.mlStudioVPCId && errors.mlStudioVPCId + touched.mlStudioVPCId && + errors.mlStudioVPCId } onBlur={handleBlur} onChange={handleChange} @@ -427,7 +434,9 @@ const EnvironmentEditForm = (props) => { label="(Optional) ML Studio Subnet ID(s)" placeholder="(Optional) Bring your own VPC - Specify Subnet ID (Hit enter after typing value)" onChange={(chip) => { - setFieldValue('mlStudioSubnetIds', [...chip]); + setFieldValue('mlStudioSubnetIds', [ + ...chip + ]); }} /> From 0b2d6a7750804527c1aba7097b4ed445f48a0694 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 29 Nov 2023 17:34:14 -0500 Subject: [PATCH 05/38] Test new ML Studio Domain Views --- backend/dataall/base/cdkproxy/app.py | 6 +- .../dataall/base/cdkproxy/cdk_cli_wrapper.py | 3 +- .../core/environment/api/input_types.py | 4 -- .../core/environment/cdk/environment_stack.py | 3 +- .../modules/mlstudio/api/input_types.py | 19 ++++++ .../dataall/modules/mlstudio/api/mutations.py | 24 +++++++ .../dataall/modules/mlstudio/api/queries.py | 10 +++ backend/dataall/modules/mlstudio/api/types.py | 30 +++++++++ .../modules/mlstudio/db/mlstudio_models.py | 13 ++-- .../modules/Environments/components/index.js | 2 + .../modules/Environments/services/index.js | 3 + .../views/EnvironmentCreateForm.js | 62 +------------------ .../Environments/views/EnvironmentEditForm.js | 61 ------------------ .../Environments/views/EnvironmentView.js | 10 +++ 14 files changed, 113 insertions(+), 137 deletions(-) diff --git a/backend/dataall/base/cdkproxy/app.py b/backend/dataall/base/cdkproxy/app.py index b02760e06..688bdfdef 100644 --- a/backend/dataall/base/cdkproxy/app.py +++ b/backend/dataall/base/cdkproxy/app.py @@ -48,10 +48,10 @@ def create(): logger.info(f' **kwargs: {_data}') if _data: data = json.loads(_data) - logger.info(f" Kwargs: {_data}") + # logger.info(f" Kwargs: {_data}") else: data = {} - logger.info(" Kwargs: None provided") + # logger.info(" Kwargs: None provided") # Creating CDK target environment env = Environment(account=account, region=region) @@ -60,7 +60,7 @@ def create(): tbl = tabulate(table, headers=['Setting', 'Value']) # , tablefmt="fancy_grid") logger.info(tbl) - instanciate_stack(stack_name, app, appid, env=env, target_uri=target_uri, payload=data) + instanciate_stack(stack_name, app, appid, env=env, target_uri=target_uri) app.synth() diff --git a/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py b/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py index 467854a60..e12454ba1 100644 --- a/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py +++ b/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py @@ -6,7 +6,6 @@ import ast import logging import os -import json import subprocess import sys from abc import abstractmethod @@ -165,7 +164,7 @@ def deploy_cdk_stack(engine: Engine, stackid: str, app_path: str = None, path: s '-c', f"target_uri='{stack.targetUri}'", '-c', - f"data='{json.dumps(stack.payload)}'", + "data='{}'", # skips synth step when no changes apply '--app', f'"{sys.executable} {app_path}"', diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py index 435f1dbdc..891682bf5 100644 --- a/backend/dataall/core/environment/api/input_types.py +++ b/backend/dataall/core/environment/api/input_types.py @@ -28,8 +28,6 @@ gql.Argument('description', gql.String), gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)), gql.Argument('region', gql.NonNullableType(gql.String)), - gql.Argument('mlStudioVPCId', gql.String), - gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)), gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) @@ -44,8 +42,6 @@ gql.Argument('description', gql.String), gql.Argument('tags', gql.ArrayType(gql.String)), gql.Argument('SamlGroupName', gql.String), - gql.Argument('mlStudioVPCId', gql.String), - gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) ], diff --git a/backend/dataall/core/environment/cdk/environment_stack.py b/backend/dataall/core/environment/cdk/environment_stack.py index 2fb43e4fd..e55971c45 100644 --- a/backend/dataall/core/environment/cdk/environment_stack.py +++ b/backend/dataall/core/environment/cdk/environment_stack.py @@ -105,7 +105,7 @@ def get_environment_admins_group(engine, environment: Environment) -> [Environme group_uri=environment.SamlGroupName, ) - def __init__(self, scope, id, target_uri: str = None, payload: dict = {}, **kwargs): + def __init__(self, scope, id, target_uri: str = None, **kwargs): super().__init__( scope, id, @@ -117,7 +117,6 @@ def __init__(self, scope, id, target_uri: str = None, payload: dict = {}, **kwar **kwargs, ) # Read input - self.payload = payload self.target_uri = target_uri self.pivot_role_name = SessionHelper.get_delegation_role_name() self.external_id = SessionHelper.get_external_id_secret() diff --git a/backend/dataall/modules/mlstudio/api/input_types.py b/backend/dataall/modules/mlstudio/api/input_types.py index f05fd53f6..e19c3eb1c 100644 --- a/backend/dataall/modules/mlstudio/api/input_types.py +++ b/backend/dataall/modules/mlstudio/api/input_types.py @@ -33,3 +33,22 @@ gql.Argument('offset', gql.Integer), ], ) + +SagemakerStudioDomainFilter = gql.InputType( + name='SagemakerStudioDomainFilter', + arguments=[ + gql.Argument('term', gql.String), + gql.Argument(name='page', type=gql.Integer), + gql.Argument(name='pageSize', type=gql.Integer), + ], +) + +NewStudioDomainInput = gql.InputType( + name='NewStudioDomainInput', + arguments=[ + gql.Argument('label', gql.NonNullableType(gql.String)), + gql.Argument('environmentUri', gql.NonNullableType(gql.String)), + gql.Argument('subnetIds', gql.ArrayType(gql.String)), + gql.Argument('vpcId', gql.String), + ], +) diff --git a/backend/dataall/modules/mlstudio/api/mutations.py b/backend/dataall/modules/mlstudio/api/mutations.py index abcc3cc99..942a3a1ff 100644 --- a/backend/dataall/modules/mlstudio/api/mutations.py +++ b/backend/dataall/modules/mlstudio/api/mutations.py @@ -29,3 +29,27 @@ type=gql.String, resolver=delete_sagemaker_studio_user, ) + +createMLStudioDomain = gql.MutationField( + name='createMLStudioDomain', + args=[ + gql.Argument( + name='input', + type=gql.NonNullableType(gql.Ref('NewStudioDomainInput')), + ) + ], + type=gql.Ref('SagemakerStudioDomain'), + resolver=create_sagemaker_studio_domain, +) + +deleteMLStudioDomain = gql.MutationField( + name='deleteMLStudioDomain', + args=[ + gql.Argument( + name='sagemakerStudioUri', + type=gql.NonNullableType(gql.String), + ) + ], + type=gql.Boolean, + resolver=delete_sagemaker_studio_domain, +) diff --git a/backend/dataall/modules/mlstudio/api/queries.py b/backend/dataall/modules/mlstudio/api/queries.py index 457559def..9c7c6ba5d 100644 --- a/backend/dataall/modules/mlstudio/api/queries.py +++ b/backend/dataall/modules/mlstudio/api/queries.py @@ -34,3 +34,13 @@ type=gql.String, resolver=get_sagemaker_studio_user_presigned_url, ) + +listEnvironmentMLStudioDomains = gql.QueryField( + name='listEnvironmentMLStudioDomains', + args=[ + gql.Argument('filter', gql.Ref('SagemakerStudioDomainFilter')), + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), + ], + type=gql.Ref('SagemakerStudioUserSearchResult'), + resolver=list_environment_sagemaker_studio_domains, +) \ No newline at end of file diff --git a/backend/dataall/modules/mlstudio/api/types.py b/backend/dataall/modules/mlstudio/api/types.py index 21290711e..1829feda3 100644 --- a/backend/dataall/modules/mlstudio/api/types.py +++ b/backend/dataall/modules/mlstudio/api/types.py @@ -79,3 +79,33 @@ gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioUser)), ], ) + +SagemakerStudioDomain = gql.ObjectType( + name='SagemakerStudioDomain', + fields=[ + gql.Field(name='sagemakerStudioUri', type=gql.ID), + gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)), + gql.Field(name='label', type=gql.String), + gql.Field(name='name', type=gql.String), + gql.Field(name='created', type=gql.String), + gql.Field(name='updated', type=gql.String), + gql.Field(name='SamlAdminGroupName', type=gql.String), + gql.Field( + name='environment', + type=gql.Ref('Environment'), + resolver=resolve_environment, + ) + ], +) + +SagemakerStudioDomainSearchResult = gql.ObjectType( + name='SagemakerStudioDomainSearchResult', + fields=[ + gql.Field(name='count', type=gql.Integer), + gql.Field(name='page', type=gql.Integer), + gql.Field(name='pages', type=gql.Integer), + gql.Field(name='hasNext', type=gql.Boolean), + gql.Field(name='hasPrevious', type=gql.Boolean), + gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioDomain)), + ], +) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index 032826588..172ef6e1c 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, String, ForeignKey from sqlalchemy.orm import query_expression +from sqlalchemy.dialects.postgresql import ARRAY from dataall.base.db import Base from dataall.base.db import Resource, utils @@ -10,16 +11,20 @@ class SagemakerStudioDomain(Resource, Base): """Describes ORM model for sagemaker ML Studio domain""" __tablename__ = 'sagemaker_studio_domain' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri")) sagemakerStudioUri = Column( String, primary_key=True, default=utils.uuid('sagemakerstudio') ) - sagemakerStudioDomainID = Column(String, nullable=False) - SagemakerStudioStatus = Column(String, nullable=False) + sagemakerStudioDomainID = Column(String, nullable=True) + SagemakerStudioStatus = Column(String, nullable=True) + sagemakerStudioDomainName = Column(String, nullable=False) AWSAccountId = Column(String, nullable=False) RoleArn = Column(String, nullable=False) region = Column(String, default='eu-west-1') - userRoleForSagemakerStudio = query_expression() + vpcType = Column(String, nullable=False) + vpcId = Column(String, nullable=False) + subnetIds = Column(ARRAY(String), nullable=False) + class SagemakerStudioUser(Resource, Base): diff --git a/frontend/src/modules/Environments/components/index.js b/frontend/src/modules/Environments/components/index.js index afccd1235..e8e41c362 100644 --- a/frontend/src/modules/Environments/components/index.js +++ b/frontend/src/modules/Environments/components/index.js @@ -12,3 +12,5 @@ export * from './EnvironmentTeamInviteEditForm'; export * from './EnvironmentTeamInviteForm'; export * from './EnvironmentTeams'; export * from './NetworkCreateModal'; +export * from './EnvironmentMLStudio'; +export * from './MLStudioDomainCreateModal'; diff --git a/frontend/src/modules/Environments/services/index.js b/frontend/src/modules/Environments/services/index.js index 14f5b659f..dbdd41431 100644 --- a/frontend/src/modules/Environments/services/index.js +++ b/frontend/src/modules/Environments/services/index.js @@ -22,3 +22,6 @@ export * from './removeConsumptionRole'; export * from './removeGroup'; export * from './updateEnvironment'; export * from './updateGroupEnvironmentPermissions'; +export * from './createMLStudioDomain'; +export * from './deleteMLStudioDomain'; +export * from './listEnvironmentMLStudioDomains'; diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index 451aea7c2..4a97500d6 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -179,8 +179,6 @@ const EnvironmentCreateForm = (props) => { region: values.region, EnvironmentDefaultIAMRoleArn: values.EnvironmentDefaultIAMRoleArn, resourcePrefix: values.resourcePrefix, - mlStudioVPCId: values.mlStudioVPCId, - mlStudioSubnetIds: values.mlStudioSubnetIds, parameters: [ { key: 'notebooksEnabled', @@ -486,9 +484,7 @@ const EnvironmentCreateForm = (props) => { mlStudiosEnabled: isModuleEnabled(ModuleNames.MLSTUDIO), pipelinesEnabled: isModuleEnabled(ModuleNames.DATAPIPELINES), EnvironmentDefaultIAMRoleArn: '', - resourcePrefix: 'dataall', - mlStudioVPCId: '', - mlStudioSubnetIds: [] + resourcePrefix: 'dataall' }} validationSchema={Yup.object().shape({ label: Yup.string() @@ -512,15 +508,6 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), - mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { - is: (value) => !!value, - then: Yup.array() - .min(1) - .required( - 'At least 1 Subnet Id required if VPC Id specified' - ) - }), - mlStudioVPCId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() @@ -870,53 +857,6 @@ const EnvironmentCreateForm = (props) => { variant="outlined" /> - {values.mlStudiosEnabled && ( - <> - - - - - - { - setFieldValue('mlStudioSubnetIds', [ - ...chip - ]); - }} - /> - - - )} {errors.submit && ( diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index dd98be682..caa5d8441 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -8,7 +8,6 @@ import { CardHeader, CircularProgress, Container, - Divider, FormControlLabel, FormGroup, FormHelperText, @@ -82,8 +81,6 @@ const EnvironmentEditForm = (props) => { tags: values.tags, description: values.description, resourcePrefix: values.resourcePrefix, - mlStudioVPCId: values.mlStudioVPCId, - mlStudioSubnetIds: values.mlStudioSubnetIds, parameters: [ { key: 'notebooksEnabled', @@ -216,8 +213,6 @@ const EnvironmentEditForm = (props) => { label: env.label, description: env.description, tags: env.tags || [], - mlStudioVPCId: '', - mlStudioSubnetIds: [], notebooksEnabled: env.parameters['notebooksEnabled'] === 'true', mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true', pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true', @@ -231,15 +226,6 @@ const EnvironmentEditForm = (props) => { .required('*Environment name is required'), description: Yup.string().max(5000), tags: Yup.array().nullable(), - mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { - is: (value) => !!value, - then: Yup.array() - .min(1) - .required( - 'At least 1 Subnet Id required if VPC Id specified' - ) - }), - mlStudioVPCId: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() .matches( @@ -395,53 +381,6 @@ const EnvironmentEditForm = (props) => { variant="outlined" /> - {values.mlStudiosEnabled && ( - <> - - - - - - { - setFieldValue('mlStudioSubnetIds', [ - ...chip - ]); - }} - /> - - - )} {isAnyEnvironmentModuleEnabled() && ( diff --git a/frontend/src/modules/Environments/views/EnvironmentView.js b/frontend/src/modules/Environments/views/EnvironmentView.js index 0ba724320..792918c13 100644 --- a/frontend/src/modules/Environments/views/EnvironmentView.js +++ b/frontend/src/modules/Environments/views/EnvironmentView.js @@ -39,6 +39,7 @@ import { archiveEnvironment, getEnvironment } from '../services'; import { KeyValueTagList, Stack, StackStatus } from 'modules/Shared'; import { EnvironmentDatasets, + EnvironmentMLStudio, EnvironmentOverview, EnvironmentSubscriptions, EnvironmentTeams, @@ -59,6 +60,12 @@ const tabs = [ icon: , active: isModuleEnabled(ModuleNames.DATASETS) }, + { + label: 'ML Studio Domain', + value: 'mlstudio', + icon: , + active: isModuleEnabled(ModuleNames.MLSTUDIO) + }, { label: 'Networks', value: 'networks', icon: }, { label: 'Subscriptions', @@ -267,6 +274,9 @@ const EnvironmentView = () => { fetchItem={fetchItem} /> )} + {isAdmin && currentTab === 'mlstudio' && ( + + )} {isAdmin && currentTab === 'tags' && ( Date: Wed, 29 Nov 2023 17:41:37 -0500 Subject: [PATCH 06/38] Test new ML Studio Domain Views - add missing files --- backend/dataall/base/cdkproxy/app.py | 2 +- .../components/EnvironmentMLStudio.js | 235 ++++++++++++++++++ .../components/MLStudioDomainCreateModal.js | 206 +++++++++++++++ .../services/createMLStudioDomain.js | 22 ++ .../services/deleteMLStudioDomain.js | 12 + .../listEnvironmentMLStudioDomains.js | 36 +++ 6 files changed, 512 insertions(+), 1 deletion(-) create mode 100644 frontend/src/modules/Environments/components/EnvironmentMLStudio.js create mode 100644 frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js create mode 100644 frontend/src/modules/Environments/services/createMLStudioDomain.js create mode 100644 frontend/src/modules/Environments/services/deleteMLStudioDomain.js create mode 100644 frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js diff --git a/backend/dataall/base/cdkproxy/app.py b/backend/dataall/base/cdkproxy/app.py index 688bdfdef..080295841 100644 --- a/backend/dataall/base/cdkproxy/app.py +++ b/backend/dataall/base/cdkproxy/app.py @@ -51,7 +51,7 @@ def create(): # logger.info(f" Kwargs: {_data}") else: data = {} - # logger.info(" Kwargs: None provided") + # logger.info(f" Kwargs: None provided") # Creating CDK target environment env = Environment(account=account, region=region) diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js new file mode 100644 index 000000000..90e3e6d4b --- /dev/null +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -0,0 +1,235 @@ +import { LoadingButton } from '@mui/lab'; +import { + Box, + Card, + CardHeader, + Chip, + Divider, + Grid, + Table, + TableBody, + TableCell, + TableHead, + TableRow +} from '@mui/material'; +import CircularProgress from '@mui/material/CircularProgress'; +import { useSnackbar } from 'notistack'; +import PropTypes from 'prop-types'; +import React, { useCallback, useEffect, useState } from 'react'; +import { FaNetworkWired } from 'react-icons/fa'; +import { Defaults, Pager, PlusIcon, RefreshTableMenu, Scrollbar } from 'design'; +import { SET_ERROR, useDispatch } from 'globalErrors'; +import { useClient } from 'services'; +import { + deleteMLStudioDomain, + listEnvironmentMLStudioDomains +} from '../services'; +import { MLStudioDomainCreateModal } from './MLStudioDomainCreateModal'; + +function DomainRow({ domain }) { + return ( + + {domain.label} + {domain.sagemakerStudioDomainName} + {domain.VpcId} + + {domain.subnetIds && ( + + {domain.subnetIds.map((subnet) => ( + + ))} + + )} + + + ); +} + +DomainRow.propTypes = { + domain: PropTypes.any +}; +export const EnvironmentMLStudio = ({ environment }) => { + const client = useClient(); + const dispatch = useDispatch(); + const { enqueueSnackbar } = useSnackbar(); + const [items, setItems] = useState(Defaults.pagedResponse); + const [filter, setFilter] = useState(Defaults.filter); + const [loading, setLoading] = useState(true); + const [isStudioDomainCreateOpen, setStudioDomainCreateOpen] = useState(false); + const handleStudioDomainCreateModalOpen = () => { + setStudioDomainCreateOpen(true); + }; + + const handleStudioDomainCreateModalClose = () => { + setStudioDomainCreateOpen(false); + }; + + const fetchItems = useCallback(async () => { + try { + const response = await client.query( + listEnvironmentMLStudioDomains({ + environmentUri: environment.environmentUri, + filter + }) + ); + if (!response.errors) { + setItems({ ...response.data.listEnvironmentMLStudioDomains }); + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (e) { + dispatch({ type: SET_ERROR, error: e.message }); + } finally { + setLoading(false); + } + }, [client, dispatch, filter, environment.environmentUri]); + + const deleteEnvironmentMLStudioDomain = async (sagemakerStudioUri) => { + const response = await client.mutate( + deleteMLStudioDomain({ sagemakerStudioUri: sagemakerStudioUri }) + ); + if (!response.errors) { + enqueueSnackbar('ML Studio Domain deleted', { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + }); + fetchItems().catch((e) => + dispatch({ type: SET_ERROR, error: e.message }) + ); + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + }; + + useEffect(() => { + if (client) { + fetchItems().catch((e) => + dispatch({ type: SET_ERROR, error: e.message }) + ); + } + }, [client, filter.page, fetchItems, dispatch]); + + const handlePageChange = async (event, value) => { + if (value <= items.pages && value !== items.page) { + await setFilter({ ...filter, page: value }); + } + }; + + return ( + + + } + title={ + + ML Studio + Domains + + } + /> + + + + {items.nodes.length === 0 ? ( + } + sx={{ m: 1 }} + variant="outlined" + > + Add ML Studio Domain + + ) : ( + } + sx={{ m: 1 }} + variant="outlined" + > + Delete ML Studio Domain + + )} + + + + + + + + Name + Domain Name + VPC + Subnets + + + {loading ? ( + + ) : ( + + {items.nodes.length > 0 ? ( + items.nodes.map((domain) => ( + + )) + ) : ( + + No SageMaker Studio Domain Found + + )} + + )} +
+ {!loading && items.nodes.length > 0 && ( + + )} +
+
+
+ {isStudioDomainCreateOpen && ( + + )} +
+ ); +}; + +EnvironmentMLStudio.propTypes = { + environment: PropTypes.object.isRequired +}; diff --git a/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js b/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js new file mode 100644 index 000000000..249d69f1c --- /dev/null +++ b/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js @@ -0,0 +1,206 @@ +import { LoadingButton } from '@mui/lab'; +import { + Box, + CardContent, + CardHeader, + Dialog, + FormHelperText, + Grid, + TextField, + Typography +} from '@mui/material'; +import { Formik } from 'formik'; +import { useSnackbar } from 'notistack'; +import PropTypes from 'prop-types'; +import * as Yup from 'yup'; +import { ChipInput } from 'design'; +import { SET_ERROR, useDispatch } from 'globalErrors'; +import { useClient } from 'services'; +import { createMLStudioDomain } from '../services'; + +export const MLStudioDomainCreateModal = (props) => { + const { environment, onApply, onClose, open, reloadStudioDomains, ...other } = + props; + const { enqueueSnackbar } = useSnackbar(); + const dispatch = useDispatch(); + const client = useClient(); + + async function submit(values, setStatus, setSubmitting, setErrors) { + try { + const response = await client.mutate( + createMLStudioDomain({ + environmentUri: environment.environmentUri, + label: values.label, + vpcId: values.mlStudioVPCId, + subnetIds: values.mlStudioSubnetIds + }) + ); + if (!response.errors) { + setStatus({ success: true }); + setSubmitting(false); + enqueueSnackbar('Network added', { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + }); + if (reloadStudioDomains) { + reloadStudioDomains(); + } + if (onApply) { + onApply(); + } + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (err) { + setStatus({ success: false }); + setErrors({ submit: err.message }); + setSubmitting(false); + dispatch({ type: SET_ERROR, error: err.message }); + } + } + + if (!environment) { + return null; + } + + return ( + + + + Create a SageMaker ML Studio Domain for your Environment + + + !!value, + then: Yup.array() + .min(1) + .required('At least 1 Subnet Id required if VPC Id specified') + }) + })} + onSubmit={async ( + values, + { setErrors, setStatus, setSubmitting } + ) => { + await submit(values, setStatus, setSubmitting, setErrors); + }} + > + {({ + errors, + handleBlur, + handleChange, + handleSubmit, + isSubmitting, + setFieldValue, + touched, + values + }) => ( +
+ + + + + + + + + + + + { + setFieldValue('mlStudioSubnetIds', [...chip]); + }} + /> + + + + + {errors.submit && ( + + {errors.submit} + + )} + + + Create + + + + +
+ )} +
+
+
+
+ ); +}; + +MLStudioDomainCreateModal.propTypes = { + environment: PropTypes.object.isRequired, + onApply: PropTypes.func, + onClose: PropTypes.func, + reloadStudioDomains: PropTypes.func, + open: PropTypes.bool.isRequired +}; diff --git a/frontend/src/modules/Environments/services/createMLStudioDomain.js b/frontend/src/modules/Environments/services/createMLStudioDomain.js new file mode 100644 index 000000000..52b28eece --- /dev/null +++ b/frontend/src/modules/Environments/services/createMLStudioDomain.js @@ -0,0 +1,22 @@ +import { gql } from 'apollo-boost'; + +export const createMLStudioDomain = (input) => ({ + variables: { + input + }, + mutation: gql` + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + vpcUri + VpcId + label + description + tags + owner + SamlGroupName + privateSubnetIds + privateSubnetIds + } + } + ` +}); diff --git a/frontend/src/modules/Environments/services/deleteMLStudioDomain.js b/frontend/src/modules/Environments/services/deleteMLStudioDomain.js new file mode 100644 index 000000000..2a0c6e7d7 --- /dev/null +++ b/frontend/src/modules/Environments/services/deleteMLStudioDomain.js @@ -0,0 +1,12 @@ +import { gql } from 'apollo-boost'; + +export const deleteMLStudioDomain = ({ sagemakerStudioUri }) => ({ + variables: { + sagemakerStudioUri + }, + mutation: gql` + mutation deleteMLStudioDomain($sagemakerStudioUri: String!) { + deleteMLStudioDomain(sagemakerStudioUri: $sagemakerStudioUri) + } + ` +}); diff --git a/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js b/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js new file mode 100644 index 000000000..dce649d45 --- /dev/null +++ b/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js @@ -0,0 +1,36 @@ +import { gql } from 'apollo-boost'; + +export const listEnvironmentMLStudioDomains = ({ filter, environmentUri }) => ({ + variables: { + environmentUri, + filter + }, + query: gql` + query listEnvironmentMLStudioDomains( + $filter: VpcFilter + $environmentUri: String! + ) { + listEnvironmentMLStudioDomains( + environmentUri: $environmentUri + filter: $filter + ) { + count + page + pages + hasNext + hasPrevious + nodes { + VpcId + vpcUri + label + name + default + SamlGroupName + publicSubnetIds + privateSubnetIds + region + } + } + } + ` +}); From 973386a3e4daa11779718b0dc2b8a7232b011c65 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 29 Nov 2023 18:18:26 -0500 Subject: [PATCH 07/38] Test new ML Studio Domain Views - add missing files --- ...f5de322f_update_sagemaker_studio_domain.py | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py new file mode 100644 index 000000000..eb522ba91 --- /dev/null +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -0,0 +1,114 @@ +"""env_mlstudio_domain_table + +Revision ID: 71a5f5de322f +Revises: 8c79fb896983 +Create Date: 2023-11-29 09:44:04.160286 + +""" +import os +from sqlalchemy import orm, Column, String, Boolean, ForeignKey, DateTime, and_, inspect +from sqlalchemy.orm import query_expression +from sqlalchemy.ext.declarative import declarative_base +import sqlalchemy as sa +from alembic import op + +from sqlalchemy.dialects import postgresql +from dataall.base.db import get_engine, has_table +from dataall.base.db import utils, Resource + +# revision identifiers, used by Alembic. +revision = '71a5f5de322f' +down_revision = '8c79fb896983' +branch_labels = None +depends_on = None + +Base = declarative_base() + + +def upgrade(): + """ + The script does the following migration: + 1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information + """ + try: + envname = os.getenv('envname', 'local') + engine = get_engine(envname=envname).engine + + bind = op.get_bind() + session = orm.Session(bind=bind) + + if has_table('sagemaker_studio_domain', engine): + print("Updating sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'sagemakerStudioDomainID', + new_column_name='sagemakerStudioDomainID', + nullable=True, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'SagemakerStudioStatus', + new_column_name='SagemakerStudioStatus', + nullable=True, + existing_type=sa.String() + ) + + op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), default=True)) + op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), default=True)) + op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), default=True)) + op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), default=True)) + + op.create_foreign_key( + f"fk_sagemaker_studio_domain_env_uri", + "sagemaker_studio_domain", "environment", + ["environmentUri"], ["environmentUri"], + ) + + session.commit() + print("Update of sagemaker_studio_domain table is done") + + except Exception as exception: + print('Failed to upgrade due to:', exception) + raise exception + + +def downgrade(): + try: + envname = os.getenv('envname', 'local') + engine = get_engine(envname=envname).engine + + bind = op.get_bind() + session = orm.Session(bind=bind) + + if has_table('sagemaker_studio_domain', engine): + print("Updating of sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'sagemakerStudioDomainID', + new_column_name='sagemakerStudioDomainID', + nullable=False, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'SagemakerStudioStatus', + new_column_name='SagemakerStudioStatus', + nullable=False, + existing_type=sa.String() + ) + + op.drop_column("sagemaker_studio_domain", "sagemakerStudioDomainName") + op.drop_column("sagemaker_studio_domain", "vpcType") + op.drop_column("sagemaker_studio_domain", "vpcId") + op.drop_column("sagemaker_studio_domain", "subnetIds") + + op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain") + + session.commit() + print("Update of sagemaker_studio_domain table is done") + + except Exception as exception: + print('Failed to downgrade due to:', exception) + raise exception + From 2de32a4a5fdd004a265b08e74546f6874f48bc18 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 30 Nov 2023 12:30:12 -0500 Subject: [PATCH 08/38] Write backend resolvers for ml studio domain apis --- backend/dataall/base/aws/vpc.py | 44 ---------- .../dataall/core/environment/api/resolvers.py | 25 +----- .../dataall/modules/mlstudio/api/mutations.py | 2 + .../dataall/modules/mlstudio/api/queries.py | 9 +- .../dataall/modules/mlstudio/api/resolvers.py | 41 +++++++++- backend/dataall/modules/mlstudio/api/types.py | 5 +- .../modules/mlstudio/aws/ec2_client.py | 29 ++++++- .../mlstudio/cdk/mlstudio_extension.py | 4 +- .../modules/mlstudio/db/mlstudio_models.py | 1 - .../mlstudio/db/mlstudio_repositories.py | 60 +++++++++++++- .../mlstudio/services/mlstudio_service.py | 82 +++++++++++++++++++ ...f5de322f_update_sagemaker_studio_domain.py | 9 +- .../components/EnvironmentMLStudio.js | 2 +- .../services/createMLStudioDomain.js | 13 ++- .../listEnvironmentMLStudioDomains.js | 15 ++-- 15 files changed, 238 insertions(+), 103 deletions(-) delete mode 100644 backend/dataall/base/aws/vpc.py diff --git a/backend/dataall/base/aws/vpc.py b/backend/dataall/base/aws/vpc.py deleted file mode 100644 index d543cac8d..000000000 --- a/backend/dataall/base/aws/vpc.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging - -from botocore.exceptions import ClientError - -from .sts import SessionHelper - -log = logging.getLogger(__name__) - - -class VPCManager: - def __init__(self): - pass - - @staticmethod - def client(AwsAccountId, region, role=None): - session = SessionHelper.remote_session(accountid=AwsAccountId, role=role) - return session.client('ec2', region_name=region) - - @staticmethod - def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): - try: - ec2 = VPCManager.client(AwsAccountId=AwsAccountId, region=region, role=role) - response = ec2.describe_vpcs(VpcIds=[vpc_id]) - except ClientError as e: - log.exception(f'VPC Id {vpc_id} Not Found: {e}') - raise Exception(f'VPCNotFound: {vpc_id}') - - try: - if subnet_ids: - response = ec2.describe_subnets( - Filters=[ - { - 'Name': 'vpc-id', - 'Values': [vpc_id] - }, - ], - SubnetIds=subnet_ids - ) - except ClientError as e: - log.exception(f'Subnet Id {subnet_ids} Not Found: {e}') - raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') - - if not subnet_ids or len(response['Subnets']) != len(subnet_ids): - raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index da50735b3..e4201950e 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -8,7 +8,6 @@ from sqlalchemy import and_, exc from dataall.base.aws.iam import IAM -from dataall.base.aws.vpc import VPCManager from dataall.base.aws.parameter_store import ParameterStoreManager from dataall.base.aws.sts import SessionHelper from dataall.base.utils import Parameter @@ -44,7 +43,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs): return True if ssm_param == "True" else False -def check_environment(context: Context, source, account_id, region, data): +def check_environment(context: Context, source, account_id, region): """ Checks necessary resources for environment deployment. - Check CDKToolkit exists in Account assuming cdk_look_up_role - Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False @@ -72,21 +71,6 @@ def check_environment(context: Context, source, account_id, region, data): action='CHECK_PIVOT_ROLE', message='Pivot Role has not been created in the Environment AWS Account', ) - mlStudioEnabled = None - for parameter in data.get("parameters", []): - if parameter['key'] == 'mlStudiosEnabled': - mlStudioEnabled = parameter['value'] - - if mlStudioEnabled == 'true' and data.get("mlStudioVPCId", None): - log.info("Check if ML Studio VPC Exists in the Account") - - VPCManager.check_vpc_exists( - AwsAccountId=account_id, - region=region, - role=cdk_look_up_role_arn, - vpc_id=data.get("mlStudioVPCId", None), - subnet_ids=data.get('mlStudioSubnetIds', []), - ) return cdk_role_name @@ -101,8 +85,7 @@ def create_environment(context: Context, source, input={}): with context.engine.scoped_session() as session: cdk_role_name = check_environment(context, source, account_id=input.get('AwsAccountId'), - region=input.get('region'), - data=input + region=input.get('region') ) input['cdk_role_name'] = cdk_role_name @@ -117,10 +100,6 @@ def create_environment(context: Context, source, input={}): target_type='environment', target_uri=env.environmentUri, target_label=env.label, - payload={ - 'mlstudio_vpc_id': input.get('mlStudioVPCId', None), - 'mlstudio_subnet_ids': input.get('mlStudioSubnetIds', []), - }, ) stack_helper.deploy_stack(targetUri=env.environmentUri) env.userRoleInEnvironment = EnvironmentPermission.Owner.value diff --git a/backend/dataall/modules/mlstudio/api/mutations.py b/backend/dataall/modules/mlstudio/api/mutations.py index 942a3a1ff..62c03fdb6 100644 --- a/backend/dataall/modules/mlstudio/api/mutations.py +++ b/backend/dataall/modules/mlstudio/api/mutations.py @@ -3,6 +3,8 @@ from dataall.modules.mlstudio.api.resolvers import ( create_sagemaker_studio_user, delete_sagemaker_studio_user, + create_sagemaker_studio_domain, + delete_sagemaker_studio_domain ) createSagemakerStudioUser = gql.MutationField( diff --git a/backend/dataall/modules/mlstudio/api/queries.py b/backend/dataall/modules/mlstudio/api/queries.py index 9c7c6ba5d..dd9d647ab 100644 --- a/backend/dataall/modules/mlstudio/api/queries.py +++ b/backend/dataall/modules/mlstudio/api/queries.py @@ -4,6 +4,7 @@ get_sagemaker_studio_user, list_sagemaker_studio_users, get_sagemaker_studio_user_presigned_url, + list_environment_sagemaker_studio_domains ) getSagemakerStudioUser = gql.QueryField( @@ -38,9 +39,9 @@ listEnvironmentMLStudioDomains = gql.QueryField( name='listEnvironmentMLStudioDomains', args=[ - gql.Argument('filter', gql.Ref('SagemakerStudioDomainFilter')), - gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), + gql.Argument('filter', gql.Ref('SagemakerStudioDomainFilter')), + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), ], - type=gql.Ref('SagemakerStudioUserSearchResult'), + type=gql.Ref('SagemakerStudioDomainSearchResult'), resolver=list_environment_sagemaker_studio_domains, -) \ No newline at end of file +) diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py index 63dc25ed7..ee7401eb7 100644 --- a/backend/dataall/modules/mlstudio/api/resolvers.py +++ b/backend/dataall/modules/mlstudio/api/resolvers.py @@ -18,7 +18,7 @@ def required_uri(uri): raise exceptions.RequiredParameter('URI') @staticmethod - def validate_creation_request(data): + def validate_user_creation_request(data): required = RequestValidator._required if not data: raise exceptions.RequiredParameter('data') @@ -28,6 +28,16 @@ def validate_creation_request(data): required(data, "environmentUri") required(data, "SamlAdminGroupName") + @staticmethod + def validate_domain_creation_request(data): + required = RequestValidator._required + if not data: + raise exceptions.RequiredParameter('data') + if not data.get('label'): + raise exceptions.RequiredParameter('name') + + required(data, "environmentUri") + @staticmethod def _required(data: dict, name: str): if not data.get(name): @@ -36,7 +46,7 @@ def _required(data: dict, name: str): def create_sagemaker_studio_user(context: Context, source, input: dict = None): """Creates a SageMaker Studio user. Deploys the SageMaker Studio user stack into AWS""" - RequestValidator.validate_creation_request(input) + RequestValidator.validate_user_creation_request(input) request = SagemakerStudioCreationRequest.from_dict(input) return SagemakerStudioService.create_sagemaker_studio_user( uri=input["environmentUri"], @@ -90,6 +100,33 @@ def delete_sagemaker_studio_user( ) +def create_sagemaker_studio_domain(context: Context, source, input: dict = None): + """Creates a SageMaker Studio user. Deploys the SageMaker Studio user stack into AWS""" + RequestValidator.validate_domain_creation_request(input) + return SagemakerStudioService.create_sagemaker_studio_domain( + uri=input["environmentUri"], + data=input + ) + + +def delete_sagemaker_studio_domain( + context, + source: SagemakerStudioUser, + sagemakerStudioUri: str = None +): + RequestValidator.required_uri(sagemakerStudioUri) + return SagemakerStudioService.delete_sagemaker_studio_domain( + uri=sagemakerStudioUri + ) + + +def list_environment_sagemaker_studio_domains(context, source, filter: dict = None, environment_uri: str = None): + RequestValidator.required_uri(environment_uri) + if not filter: + filter = {} + return SagemakerStudioService.list_environment_sagemaker_studio_domains(filter=filter, environment_uri=environment_uri) + + def resolve_user_role(context: Context, source: SagemakerStudioUser): """ Resolves the role of the current user in reference with the SageMaker Studio User diff --git a/backend/dataall/modules/mlstudio/api/types.py b/backend/dataall/modules/mlstudio/api/types.py index 1829feda3..9daf76d70 100644 --- a/backend/dataall/modules/mlstudio/api/types.py +++ b/backend/dataall/modules/mlstudio/api/types.py @@ -85,11 +85,14 @@ fields=[ gql.Field(name='sagemakerStudioUri', type=gql.ID), gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)), + gql.Field(name='sagemakerStudioDomainName', type=gql.String), gql.Field(name='label', type=gql.String), gql.Field(name='name', type=gql.String), gql.Field(name='created', type=gql.String), gql.Field(name='updated', type=gql.String), - gql.Field(name='SamlAdminGroupName', type=gql.String), + gql.Field(name='vpcType', type=gql.String), + gql.Field(name='vpcId', type=gql.String), + gql.Field(name='subnetIds', type=gql.ArrayType(gql.String)), gql.Field( name='environment', type=gql.Ref('Environment'), diff --git a/backend/dataall/modules/mlstudio/aws/ec2_client.py b/backend/dataall/modules/mlstudio/aws/ec2_client.py index 3dc484254..23d290a7e 100644 --- a/backend/dataall/modules/mlstudio/aws/ec2_client.py +++ b/backend/dataall/modules/mlstudio/aws/ec2_client.py @@ -1,7 +1,7 @@ import logging from dataall.base.aws.sts import SessionHelper - +from botocore.exceptions import ClientError log = logging.getLogger(__name__) @@ -25,3 +25,30 @@ def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): if vpcs: return True return False + + @staticmethod + def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): + try: + ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = ec2.describe_vpcs(VpcIds=[vpc_id]) + except ClientError as e: + log.exception(f'VPC Id {vpc_id} Not Found: {e}') + raise Exception(f'VPCNotFound: {vpc_id}') + + try: + if subnet_ids: + response = ec2.describe_subnets( + Filters=[ + { + 'Name': 'vpc-id', + 'Values': [vpc_id] + }, + ], + SubnetIds=subnet_ids + ) + except ClientError as e: + log.exception(f'Subnet Id {subnet_ids} Not Found: {e}') + raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') + + if not subnet_ids or len(response['Subnets']) != len(subnet_ids): + raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 0112faaf8..6ca52d367 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -13,6 +13,7 @@ RemovalPolicy, ) from botocore.exceptions import ClientError +from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository from dataall.base.aws.parameter_store import ParameterStoreManager from dataall.base.aws.sts import SessionHelper @@ -31,7 +32,8 @@ def extent(setup: EnvironmentSetup): _environment = setup.environment() with setup.get_engine().scoped_session() as session: enabled = EnvironmentService.get_boolean_env_param(session, _environment, "mlStudiosEnabled") - if not enabled: + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, _environment.environmentUri) + if not enabled or not domain: return sagemaker_principals = [setup.default_role] + setup.group_roles diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index 172ef6e1c..aeaffec06 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -26,7 +26,6 @@ class SagemakerStudioDomain(Resource, Base): subnetIds = Column(ARRAY(String), nullable=False) - class SagemakerStudioUser(Resource, Base): """Describes ORM model for sagemaker ML Studio user""" __tablename__ = 'sagemaker_studio_user_profile' diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 763ca6f92..0381cc3cd 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -2,14 +2,15 @@ DAO layer that encapsulates the logic and interaction with the database for ML Studio Provides the API to retrieve / update / delete ml studio """ +from typing import Optional from sqlalchemy import or_ from sqlalchemy.sql import and_ from sqlalchemy.orm import Query from dataall.base.db import paginate -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioDomain, SagemakerStudioUser from dataall.core.environment.services.environment_resource_manager import EnvironmentResource - +from dataall.base.db.exceptions import ObjectNotFound class SageMakerStudioRepository(EnvironmentResource): """DAO layer for ML Studio""" @@ -44,7 +45,7 @@ def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query: ) return query - def paginated_sagemaker_studio_users(self, username, groups, filter=None) -> dict: + def paginated_sagemaker_studio_users(self, username, groups, filter={}) -> dict: """Returns a page of sagemaker studio users for a data.all user""" return paginate( query=self._query_user_sagemaker_studio_users(username, groups, filter), @@ -67,3 +68,56 @@ def count_resources(self, environment, group_uri): ) .count() ) + + def create_sagemaker_studio_domain(self, username, environment, data): + # TODO: BUILD ROLE ARN + domain = SagemakerStudioDomain( + label=data.get('label'), + owner=username, + description=data.get('description', 'No description provided'), + tags=data.get('tags', []), + environmentUri=environment.environmentUri, + AwsAccountId=environment.AwsAccountId, + region=environment.region, + SagemakerStudioStatus="PENDING", + RoleArn="TODO", + vpcType=data.get('vpcType'), + vpcId=data.get('vpcId'), + subnetIds=data.get('subnetIds', []) + ) + self._session.add(domain) + self._session.commit() + + def paginated_environment_sagemaker_studio_domains(self, uri, filter={}) -> dict: + """Returns a page of sagemaker studio users for a data.all user""" + return paginate( + query=self._query_environment_sagemaker_studio_domains(uri, filter), + page=filter.get('page', SageMakerStudioRepository._DEFAULT_PAGE), + page_size=filter.get('pageSize', SageMakerStudioRepository._DEFAULT_PAGE_SIZE), + ).to_dict() + + def _query_environment_sagemaker_studio_domains(self, uri, filter) -> Query: + query = self._session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == uri, + ) + if filter and filter.get('term'): + query = query.filter( + or_( + SagemakerStudioDomain.description.ilike( + filter.get('term') + '%%' + ), + SagemakerStudioDomain.label.ilike( + filter.get('term') + '%%' + ), + ) + ) + return query + + @staticmethod + def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env_uri, + ).first() + if not domain: + return None + return domain diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 06750b822..b17119a1a 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -11,6 +11,7 @@ from dataall.core.environment.env_permission_checker import has_group_permission from dataall.core.environment.services.environment_service import EnvironmentService from dataall.core.permissions.db.resource_policy_repositories import ResourcePolicy +from dataall.core.permissions import permissions from dataall.core.permissions.permission_checker import has_resource_permission, has_tenant_permission from dataall.core.stacks.api import stack_helper from dataall.core.stacks.db.stack_repositories import Stack @@ -18,6 +19,9 @@ from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.aws.ec2_client import EC2 +from dataall.base.aws.sts import SessionHelper + from dataall.modules.mlstudio.services.mlstudio_permissions import ( MANAGE_SGMSTUDIO_USERS, CREATE_SGMSTUDIO_USER, @@ -77,10 +81,16 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak action=CREATE_SGMSTUDIO_USER, message=f'ML Studio feature is disabled for the environment {env.label}', ) + # FOR OLD ONES response = get_sagemaker_studio_domain( AwsAccountId=env.AwsAccountId, region=env.region ) + + # FOR NEW ONES (default, created, imported) + # - CHECK RDS FIRST + # - IF NOT BOTO3 + existing_domain = response.get('DomainId', False) if not existing_domain: @@ -135,6 +145,78 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak return sagemaker_studio_user + @staticmethod + @has_tenant_permission(permissions.MANAGE_ENVIRONMENTS) + @has_resource_permission(permissions.UPDATE_ENVIRONMENT) + def create_sagemaker_studio_domain(*, uri: str, data: dict): + context = get_context() + with context.db_engine.scoped_session() as session: + environment = EnvironmentService.get_environment_by_uri(session, uri) + enabled = EnvironmentService.get_boolean_env_param(session, environment, "pipelinesEnabled") + if not enabled: + raise exceptions.UnauthorizedOperation( + action=permissions.UPDATE_ENVIRONMENT, + message=f'ML Studio feature is disabled for the environment {environment.label}', + ) + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=environment.AwsAccountId, region=environment.region + ) + if data.get("vpcId", None): + SagemakerStudioService.check_mlstudio_domain_vpc( + account_id=environment.AwsAccountId, + region=environment.region, + cdk_look_up_role_arn=cdk_look_up_role_arn, + data=data + ) + data["vpcType"] = "imported" + elif EC2.check_default_vpc_exists( + AwsAccountId=environment.AwsAccountId, + region=environment.region, + role=cdk_look_up_role_arn, + ): + data["vpcType"] = "default" + else: + data["vpcType"] = "created" + + domain = SageMakerStudioRepository(session).create_sagemaker_studio_domain( + username=get_context().username, + environment=environment, + data=data, + ) + # TODO: DEPLOY ENV STACK + return domain + + @staticmethod + def check_mlstudio_domain_vpc(account_id: str, region: str, cdk_look_up_role_arn: str, data: dict): + if data.get("mlStudioVPCId", None) and data.get("mlStudioVPCId", None): + EC2.check_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + vpc_id=data.get("vpcId", None), + subnet_ids=data.get('subnetIds', []), + ) + data["vpcType"] = "imported" + return True + + @staticmethod + @has_resource_permission(permissions.UPDATE_ENVIRONMENT) + def delete_sagemaker_studio_domain(*, uri: str): + with _session() as session: + domain = SageMakerStudioRepository.get_sagemaker_studio_domain(session, uri) + # TODO: CHECK NUMBER OF USERS BEFORE DELETE + session.delete(domain) + # TODO: DEPLOY ENV STACK + return domain + + @staticmethod + def list_environment_sagemaker_studio_domains(*, filter: dict, environment_uri: str) -> dict: + with _session() as session: + return SageMakerStudioRepository(session).paginated_environment_sagemaker_studio_domains( + uri=environment_uri, + filter=filter, + ) + @staticmethod def list_sagemaker_studio_users(*, filter: dict) -> dict: with _session() as session: diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index eb522ba91..fd65ee0e4 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -23,7 +23,7 @@ depends_on = None Base = declarative_base() - + def upgrade(): """ @@ -33,7 +33,7 @@ def upgrade(): try: envname = os.getenv('envname', 'local') engine = get_engine(envname=envname).engine - + bind = op.get_bind() session = orm.Session(bind=bind) @@ -60,7 +60,7 @@ def upgrade(): op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), default=True)) op.create_foreign_key( - f"fk_sagemaker_studio_domain_env_uri", + "fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain", "environment", ["environmentUri"], ["environmentUri"], ) @@ -104,11 +104,10 @@ def downgrade(): op.drop_column("sagemaker_studio_domain", "subnetIds") op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain") - + session.commit() print("Update of sagemaker_studio_domain table is done") except Exception as exception: print('Failed to downgrade due to:', exception) raise exception - diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 90e3e6d4b..105e89139 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -31,7 +31,7 @@ function DomainRow({ domain }) { {domain.label} {domain.sagemakerStudioDomainName} - {domain.VpcId} + {domain.vpcId} {domain.subnetIds && ( ({ mutation: gql` mutation createMLStudioDomain($input: NewStudioDomainInput) { createMLStudioDomain(input: $input) { - vpcUri - VpcId + sagemakerStudioUri + environmentUri label - description - tags - owner - SamlGroupName - privateSubnetIds - privateSubnetIds + vpcType + vpcId + subnetIds } } ` diff --git a/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js b/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js index dce649d45..a12a098bd 100644 --- a/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js +++ b/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js @@ -7,7 +7,7 @@ export const listEnvironmentMLStudioDomains = ({ filter, environmentUri }) => ({ }, query: gql` query listEnvironmentMLStudioDomains( - $filter: VpcFilter + $filter: SagemakerStudioDomainFilter $environmentUri: String! ) { listEnvironmentMLStudioDomains( @@ -20,15 +20,12 @@ export const listEnvironmentMLStudioDomains = ({ filter, environmentUri }) => ({ hasNext hasPrevious nodes { - VpcId - vpcUri + sagemakerStudioUri + environmentUri label - name - default - SamlGroupName - publicSubnetIds - privateSubnetIds - region + vpcType + vpcId + subnetIds } } } From 92120127a5f1db55a5b1bcf1fdd9f4de6ca8803c Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 30 Nov 2023 14:00:28 -0500 Subject: [PATCH 09/38] Fixes to API params and permissions checks --- .../dataall/modules/mlstudio/api/resolvers.py | 6 +- .../mlstudio/db/mlstudio_repositories.py | 8 ++- .../mlstudio/services/mlstudio_service.py | 24 +++++-- .../components/EnvironmentMLStudio.js | 64 +++++++++++-------- .../components/MLStudioDomainCreateModal.js | 2 - .../listEnvironmentMLStudioDomains.js | 1 + 6 files changed, 65 insertions(+), 40 deletions(-) diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py index ee7401eb7..706219d4b 100644 --- a/backend/dataall/modules/mlstudio/api/resolvers.py +++ b/backend/dataall/modules/mlstudio/api/resolvers.py @@ -120,11 +120,11 @@ def delete_sagemaker_studio_domain( ) -def list_environment_sagemaker_studio_domains(context, source, filter: dict = None, environment_uri: str = None): - RequestValidator.required_uri(environment_uri) +def list_environment_sagemaker_studio_domains(context, source, filter: dict = None, environmentUri: str = None): + RequestValidator.required_uri(environmentUri) if not filter: filter = {} - return SagemakerStudioService.list_environment_sagemaker_studio_domains(filter=filter, environment_uri=environment_uri) + return SagemakerStudioService.list_environment_sagemaker_studio_domains(filter=filter, environment_uri=environmentUri) def resolve_user_role(context: Context, source: SagemakerStudioUser): diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 0381cc3cd..e118177bc 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -70,16 +70,17 @@ def count_resources(self, environment, group_uri): ) def create_sagemaker_studio_domain(self, username, environment, data): - # TODO: BUILD ROLE ARN + # TODO: BUILD ROLE ARN Domain Name domain = SagemakerStudioDomain( label=data.get('label'), owner=username, description=data.get('description', 'No description provided'), tags=data.get('tags', []), environmentUri=environment.environmentUri, - AwsAccountId=environment.AwsAccountId, + AWSAccountId=environment.AwsAccountId, region=environment.region, SagemakerStudioStatus="PENDING", + sagemakerStudioDomainName=data.get('label'), RoleArn="TODO", vpcType=data.get('vpcType'), vpcId=data.get('vpcId'), @@ -113,6 +114,9 @@ def _query_environment_sagemaker_studio_domains(self, uri, filter) -> Query: ) return query + def find_sagemaker_studio_domain(self, uri) -> Optional[SagemakerStudioDomain]: + return self._session.query(SagemakerStudioDomain).get(uri) + @staticmethod def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index b17119a1a..6b6b221dc 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -152,7 +152,7 @@ def create_sagemaker_studio_domain(*, uri: str, data: dict): context = get_context() with context.db_engine.scoped_session() as session: environment = EnvironmentService.get_environment_by_uri(session, uri) - enabled = EnvironmentService.get_boolean_env_param(session, environment, "pipelinesEnabled") + enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") if not enabled: raise exceptions.UnauthorizedOperation( action=permissions.UPDATE_ENVIRONMENT, @@ -188,7 +188,7 @@ def create_sagemaker_studio_domain(*, uri: str, data: dict): @staticmethod def check_mlstudio_domain_vpc(account_id: str, region: str, cdk_look_up_role_arn: str, data: dict): - if data.get("mlStudioVPCId", None) and data.get("mlStudioVPCId", None): + if data.get("vpcId", None) and data.get("subnetIds", None): EC2.check_vpc_exists( AwsAccountId=account_id, region=region, @@ -200,14 +200,20 @@ def check_mlstudio_domain_vpc(account_id: str, region: str, cdk_look_up_role_arn return True @staticmethod - @has_resource_permission(permissions.UPDATE_ENVIRONMENT) + def _get_domain_env_uri(session, uri): + domain = SagemakerStudioService._get_sagemaker_studio_domain(session, uri) + return domain.environmentUri + + @staticmethod + @has_tenant_permission(permissions.MANAGE_ENVIRONMENTS) + @has_resource_permission(permissions.UPDATE_ENVIRONMENT, parent_resource=_get_domain_env_uri) def delete_sagemaker_studio_domain(*, uri: str): with _session() as session: - domain = SageMakerStudioRepository.get_sagemaker_studio_domain(session, uri) + domain = SagemakerStudioService._get_sagemaker_studio_domain(session, uri) # TODO: CHECK NUMBER OF USERS BEFORE DELETE session.delete(domain) # TODO: DEPLOY ENV STACK - return domain + return True @staticmethod def list_environment_sagemaker_studio_domains(*, filter: dict, environment_uri: str) -> dict: @@ -283,3 +289,11 @@ def _get_sagemaker_studio_user(session, uri): if not user: raise exceptions.ObjectNotFound('SagemakerStudioUser', uri) return user + + @staticmethod + def _get_sagemaker_studio_domain(session, uri): + domain = SageMakerStudioRepository(session).find_sagemaker_studio_domain(uri=uri) + if not domain: + raise exceptions.ObjectNotFound('SagemakerStudioDomain', uri) + return domain + diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 105e89139..a61b292dc 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -16,8 +16,14 @@ import CircularProgress from '@mui/material/CircularProgress'; import { useSnackbar } from 'notistack'; import PropTypes from 'prop-types'; import React, { useCallback, useEffect, useState } from 'react'; -import { FaNetworkWired } from 'react-icons/fa'; -import { Defaults, Pager, PlusIcon, RefreshTableMenu, Scrollbar } from 'design'; +import { + Defaults, + MinusIcon, + Pager, + PlusIcon, + RefreshTableMenu, + Scrollbar +} from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; import { useClient } from 'services'; import { @@ -97,7 +103,7 @@ export const EnvironmentMLStudio = ({ environment }) => { const deleteEnvironmentMLStudioDomain = async (sagemakerStudioUri) => { const response = await client.mutate( - deleteMLStudioDomain({ sagemakerStudioUri: sagemakerStudioUri }) + deleteMLStudioDomain({ sagemakerStudioUri }) ); if (!response.errors) { enqueueSnackbar('ML Studio Domain deleted', { @@ -136,8 +142,32 @@ export const EnvironmentMLStudio = ({ environment }) => { action={} title={ - ML Studio - Domains + ML Studio Domains + {items.nodes.length === 0 ? ( + } + sx={{ m: 1 }} + variant="outlined" + > + Add ML Studio Domain + + ) : ( + { + deleteEnvironmentMLStudioDomain( + items.nodes[0].sagemakerStudioUri + ); + }} + startIcon={} + sx={{ m: 1 }} + variant="outlined" + > + Delete ML Studio Domain + + )} } /> @@ -151,29 +181,7 @@ export const EnvironmentMLStudio = ({ environment }) => { p: 2 }} > - - {items.nodes.length === 0 ? ( - } - sx={{ m: 1 }} - variant="outlined" - > - Add ML Studio Domain - - ) : ( - } - sx={{ m: 1 }} - variant="outlined" - > - Delete ML Studio Domain - - )} - + diff --git a/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js b/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js index 249d69f1c..c580eb709 100644 --- a/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js +++ b/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js @@ -169,8 +169,6 @@ export const MLStudioDomainCreateModal = (props) => { /> - - {errors.submit && ( {errors.submit} diff --git a/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js b/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js index a12a098bd..a98dba4b2 100644 --- a/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js +++ b/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js @@ -26,6 +26,7 @@ export const listEnvironmentMLStudioDomains = ({ filter, environmentUri }) => ({ vpcType vpcId subnetIds + sagemakerStudioDomainName } } } From 56099d6acc2b7d23edd2792b8d8d74d23b7715fe Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 30 Nov 2023 15:22:48 -0500 Subject: [PATCH 10/38] Add Stack Deploy on create/delete studio domain --- .../dataall/base/utils/naming_convention.py | 1 + .../mlstudio/cdk/mlstudio_extension.py | 42 ++++--------------- .../mlstudio/db/mlstudio_repositories.py | 26 ++++++++++-- .../mlstudio/services/mlstudio_service.py | 4 +- .../components/MLStudioDomainCreateModal.js | 2 +- 5 files changed, 35 insertions(+), 40 deletions(-) diff --git a/backend/dataall/base/utils/naming_convention.py b/backend/dataall/base/utils/naming_convention.py index 3501fa71b..262964560 100644 --- a/backend/dataall/base/utils/naming_convention.py +++ b/backend/dataall/base/utils/naming_convention.py @@ -10,6 +10,7 @@ class NamingConventionPattern(Enum): GLUE = {'regex': '[^a-zA-Z0-9_]', 'separator': '_', 'max_length': 63} GLUE_ETL = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 52} NOTEBOOK = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63} + MLSTUDIO_DOMAIN = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63} DEFAULT = {'regex': '[^a-zA-Z0-9-_]', 'separator': '-', 'max_length': 63} OPENSEARCH = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 27} OPENSEARCH_SERVERLESS = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 31} diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 6ca52d367..278acc11a 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -39,18 +39,10 @@ def extent(setup: EnvironmentSetup): sagemaker_principals = [setup.default_role] + setup.group_roles logger.info(f'Creating SageMaker base resources for sagemaker_principals = {sagemaker_principals}..') - existing_vpc_id = None - existing_subnet_ids = None - if setup.payload: - existing_vpc_id = setup.payload.get('mlstudio_vpc_id', None) - existing_subnet_ids = setup.payload.get('mlstudio_subnet_ids', []) - logger.info(f'VPC ID = {existing_vpc_id}') - logger.info(f'Subnet IDs = {existing_subnet_ids}') - - if existing_vpc_id and existing_subnet_ids: - logger.info(f'Using VPC {existing_vpc_id} and subnets {existing_subnet_ids} for SageMaker Studio domain') - vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=existing_vpc_id) - subnet_ids = existing_subnet_ids + if domain.vpcId and domain.subnetIds: + logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) + subnet_ids = domain.subnetIds security_groups = [] else: cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( @@ -122,9 +114,9 @@ def extent(setup: EnvironmentSetup): sagemaker_domain_role = iam.Role( setup, - 'RoleForSagemakerStudioUsers', + domain.RoleArn, assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'), - role_name='RoleSagemakerStudioUsers', + role_name=domain.RoleArn, managed_policies=[ iam.ManagedPolicy.from_managed_policy_arn( setup, @@ -191,8 +183,8 @@ def extent(setup: EnvironmentSetup): sagemaker_domain = sagemaker.CfnDomain( setup, - 'SagemakerStudioDomain', - domain_name=f'SagemakerStudioDomain-{_environment.region}-{_environment.AwsAccountId}', + domain.sagemakerStudioDomainName, + domain_name=domain.sagemakerStudioDomainName, auth_mode='IAM', default_user_settings=sagemaker.CfnDomain.UserSettingsProperty( execution_role=sagemaker_domain_role.role_arn, @@ -217,21 +209,3 @@ def extent(setup: EnvironmentSetup): ) return sagemaker_domain - @staticmethod - def check_existing_sagemaker_studio_domain(environment): - logger.info('Check if there is an existing sagemaker studio domain in the account') - try: - logger.info('check sagemaker studio domain created as part of data.all environment stack.') - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=environment.AwsAccountId, region=environment.region - ) - dataall_created_domain = ParameterStoreManager.client( - AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn - ).get_parameter(Name=f'/{environment.resourcePrefix}/{environment.environmentUri}/sagemaker/sagemakerstudio/domain_id') - return False - except ClientError as e: - logger.info(f'check sagemaker studio domain created outside of data.all. Parameter data.all not found: {e}') - existing_domain = get_sagemaker_studio_domain( - AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn - ) - return existing_domain.get('DomainId', False) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index e118177bc..49dfcfc98 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -7,10 +7,14 @@ from sqlalchemy.sql import and_ from sqlalchemy.orm import Query +from dataall.base.utils import slugify from dataall.base.db import paginate from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioDomain, SagemakerStudioUser from dataall.core.environment.services.environment_resource_manager import EnvironmentResource -from dataall.base.db.exceptions import ObjectNotFound +from dataall.base.utils.naming_convention import ( + NamingConventionService, + NamingConventionPattern, +) class SageMakerStudioRepository(EnvironmentResource): """DAO layer for ML Studio""" @@ -80,8 +84,8 @@ def create_sagemaker_studio_domain(self, username, environment, data): AWSAccountId=environment.AwsAccountId, region=environment.region, SagemakerStudioStatus="PENDING", - sagemakerStudioDomainName=data.get('label'), - RoleArn="TODO", + RoleArn="DefaultMLStudioRole", + sagemakerStudioDomainName=slugify(data.get('label'), separator=''), vpcType=data.get('vpcType'), vpcId=data.get('vpcId'), subnetIds=data.get('subnetIds', []) @@ -89,6 +93,22 @@ def create_sagemaker_studio_domain(self, username, environment, data): self._session.add(domain) self._session.commit() + domain.sagemakerStudioDomainName = NamingConventionService( + target_uri=domain.sagemakerStudioUri, + target_label=domain.label, + pattern=NamingConventionPattern.MLSTUDIO_DOMAIN, + resource_prefix=environment.resourcePrefix, + ).build_compliant_name() + + domain.RoleArn = NamingConventionService( + target_uri=domain.sagemakerStudioUri, + target_label=f"DefaultMLStudioRole-{domain.label}", + pattern=NamingConventionPattern.IAM, + resource_prefix=environment.resourcePrefix, + ).build_compliant_name() + + return domain + def paginated_environment_sagemaker_studio_domains(self, uri, filter={}) -> dict: """Returns a page of sagemaker studio users for a data.all user""" return paginate( diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 6b6b221dc..0c74e4804 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -183,7 +183,7 @@ def create_sagemaker_studio_domain(*, uri: str, data: dict): environment=environment, data=data, ) - # TODO: DEPLOY ENV STACK + stack_helper.deploy_stack(domain.environmentUri) return domain @staticmethod @@ -212,7 +212,7 @@ def delete_sagemaker_studio_domain(*, uri: str): domain = SagemakerStudioService._get_sagemaker_studio_domain(session, uri) # TODO: CHECK NUMBER OF USERS BEFORE DELETE session.delete(domain) - # TODO: DEPLOY ENV STACK + stack_helper.deploy_stack(domain.environmentUri) return True @staticmethod diff --git a/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js b/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js index c580eb709..d592a4d67 100644 --- a/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js +++ b/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js @@ -38,7 +38,7 @@ export const MLStudioDomainCreateModal = (props) => { if (!response.errors) { setStatus({ success: true }); setSubmitting(false); - enqueueSnackbar('Network added', { + enqueueSnackbar('ML Studio Domain Added', { anchorOrigin: { horizontal: 'right', vertical: 'top' From ecb9be0312b0f18fc25cf4a08eff69b3eed6b8c6 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 30 Nov 2023 17:38:15 -0500 Subject: [PATCH 11/38] Fix Migration script and clean up naming and lint checks --- .../mlstudio/cdk/mlstudio_extension.py | 10 +-- .../modules/mlstudio/db/mlstudio_models.py | 8 +- .../mlstudio/db/mlstudio_repositories.py | 7 +- .../mlstudio/services/mlstudio_service.py | 3 +- ...f5de322f_update_sagemaker_studio_domain.py | 89 ++++++++++++++++--- .../components/EnvironmentMLStudio.js | 45 +++++----- 6 files changed, 111 insertions(+), 51 deletions(-) diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 278acc11a..166ac814b 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -104,6 +104,7 @@ def extent(setup: EnvironmentSetup): "SecurityGroup", vpc=vpc, description="Security Group for SageMaker Studio", + security_group_name=domain.sagemakerStudioDomainName, ) sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) @@ -114,9 +115,9 @@ def extent(setup: EnvironmentSetup): sagemaker_domain_role = iam.Role( setup, - domain.RoleArn, + 'RoleForSagemakerStudioUsers', assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'), - role_name=domain.RoleArn, + role_name=domain.DefaultDomainRoleName, managed_policies=[ iam.ManagedPolicy.from_managed_policy_arn( setup, @@ -132,7 +133,7 @@ def extent(setup: EnvironmentSetup): sagemaker_domain_key = kms.Key( setup, 'SagemakerDomainKmsKey', - alias='SagemakerStudioDomain', + alias=domain.sagemakerStudioDomainName, enable_key_rotation=True, admins=[ iam.ArnPrincipal(_environment.CDKRoleArn) @@ -183,7 +184,7 @@ def extent(setup: EnvironmentSetup): sagemaker_domain = sagemaker.CfnDomain( setup, - domain.sagemakerStudioDomainName, + 'SagemakerStudioDomain', domain_name=domain.sagemakerStudioDomainName, auth_mode='IAM', default_user_settings=sagemaker.CfnDomain.UserSettingsProperty( @@ -208,4 +209,3 @@ def extent(setup: EnvironmentSetup): parameter_name=f'/{_environment.resourcePrefix}/{_environment.environmentUri}/sagemaker/sagemakerstudio/domain_id', ) return sagemaker_domain - diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index aeaffec06..c99e11eca 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -15,15 +15,15 @@ class SagemakerStudioDomain(Resource, Base): sagemakerStudioUri = Column( String, primary_key=True, default=utils.uuid('sagemakerstudio') ) - sagemakerStudioDomainID = Column(String, nullable=True) + sagemakerStudioDomainID = Column(String, nullable=False) SagemakerStudioStatus = Column(String, nullable=True) sagemakerStudioDomainName = Column(String, nullable=False) AWSAccountId = Column(String, nullable=False) RoleArn = Column(String, nullable=False) region = Column(String, default='eu-west-1') - vpcType = Column(String, nullable=False) - vpcId = Column(String, nullable=False) - subnetIds = Column(ARRAY(String), nullable=False) + vpcType = Column(String, nullable=True) + vpcId = Column(String, nullable=True) + subnetIds = Column(ARRAY(String), nullable=True) class SagemakerStudioUser(Resource, Base): diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 49dfcfc98..fadfff920 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -16,6 +16,7 @@ NamingConventionPattern, ) + class SageMakerStudioRepository(EnvironmentResource): """DAO layer for ML Studio""" _DEFAULT_PAGE = 1 @@ -84,7 +85,7 @@ def create_sagemaker_studio_domain(self, username, environment, data): AWSAccountId=environment.AwsAccountId, region=environment.region, SagemakerStudioStatus="PENDING", - RoleArn="DefaultMLStudioRole", + DefaultDomainRoleName="DefaultMLStudioRole", sagemakerStudioDomainName=slugify(data.get('label'), separator=''), vpcType=data.get('vpcType'), vpcId=data.get('vpcId'), @@ -100,9 +101,9 @@ def create_sagemaker_studio_domain(self, username, environment, data): resource_prefix=environment.resourcePrefix, ).build_compliant_name() - domain.RoleArn = NamingConventionService( + domain.DefaultDomainRoleName = NamingConventionService( target_uri=domain.sagemakerStudioUri, - target_label=f"DefaultMLStudioRole-{domain.label}", + target_label=domain.label, pattern=NamingConventionPattern.IAM, resource_prefix=environment.resourcePrefix, ).build_compliant_name() diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 0c74e4804..a7eb42e6b 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -289,11 +289,10 @@ def _get_sagemaker_studio_user(session, uri): if not user: raise exceptions.ObjectNotFound('SagemakerStudioUser', uri) return user - + @staticmethod def _get_sagemaker_studio_domain(session, uri): domain = SageMakerStudioRepository(session).find_sagemaker_studio_domain(uri=uri) if not domain: raise exceptions.ObjectNotFound('SagemakerStudioDomain', uri) return domain - diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index fd65ee0e4..9c31cc7a2 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -6,8 +6,7 @@ """ import os -from sqlalchemy import orm, Column, String, Boolean, ForeignKey, DateTime, and_, inspect -from sqlalchemy.orm import query_expression +from sqlalchemy import orm, Column, String, Boolean, ForeignKey, and_ from sqlalchemy.ext.declarative import declarative_base import sqlalchemy as sa from alembic import op @@ -25,6 +24,35 @@ Base = declarative_base() +class Environment(Resource, Base): + __tablename__ = "environment" + environmentUri = Column(String, primary_key=True) + AwsAccountId = Column(Boolean) + region = Column(Boolean) + + +class EnvironmentParameter(Base): + __tablename__ = 'environment_parameters' + environmentUri = Column(String, primary_key=True) + key = Column('paramKey', String, primary_key=True) + value = Column('paramValue', String, nullable=True) + + +class SagemakerStudioDomain(Resource, Base): + __tablename__ = 'sagemaker_studio_domain' + environmentUri = Column(String, ForeignKey("environment.environmentUri")) + sagemakerStudioUri = Column( + String, primary_key=True, default=utils.uuid('sagemakerstudio') + ) + sagemakerStudioDomainID = Column(String, nullable=True) + SagemakerStudioStatus = Column(String, nullable=True) + sagemakerStudioDomainName = Column(String, nullable=False) + AWSAccountId = Column(String, nullable=False) + DefaultDomainRoleName = Column(String, nullable=False) + region = Column(String, default='eu-west-1') + vpcType = Column(String, nullable=True) + + def upgrade(): """ The script does the following migration: @@ -42,22 +70,27 @@ def upgrade(): op.alter_column( 'sagemaker_studio_domain', 'sagemakerStudioDomainID', - new_column_name='sagemakerStudioDomainID', nullable=True, existing_type=sa.String() ) op.alter_column( 'sagemaker_studio_domain', 'SagemakerStudioStatus', - new_column_name='SagemakerStudioStatus', nullable=True, existing_type=sa.String() ) + op.alter_column( + 'sagemaker_studio_domain', + 'RoleArn', + new_column_name='DefaultDomainRoleName', + nullable=False, + existing_type=sa.String() + ) - op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), default=True)) - op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), default=True)) - op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), default=True)) - op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), default=True)) + op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False)) + op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True)) op.create_foreign_key( "fk_sagemaker_studio_domain_env_uri", @@ -65,8 +98,35 @@ def upgrade(): ["environmentUri"], ["environmentUri"], ) - session.commit() - print("Update of sagemaker_studio_domain table is done") + print("Update sagemaker_studio_domain table done.") + print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") + + env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( + and_( + EnvironmentParameter.key == "mlStudiosEnabled", + EnvironmentParameter.value == "true" + ) + ).all() + for param in env_mlstudio_parameters: + env: Environment = session.query(Environment).filter( + Environment.environmentUri == param.environmentUri + ).first() + + domain = SagemakerStudioDomain( + label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + owner=env.owner, + description='No description provided', + environmentUri=env.environmentUri, + AWSAccountId=env.AwsAccountId, + region=env.region, + DefaultDomainRoleName="RoleSagemakerStudioUsers", + sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + vpcType="unknown" + ) + session.add(domain) + session.flush() + session.commit() + print("Fill of sagemaker_studio_domain table is done") except Exception as exception: print('Failed to upgrade due to:', exception) @@ -86,14 +146,19 @@ def downgrade(): op.alter_column( 'sagemaker_studio_domain', 'sagemakerStudioDomainID', - new_column_name='sagemakerStudioDomainID', nullable=False, existing_type=sa.String() ) op.alter_column( 'sagemaker_studio_domain', 'SagemakerStudioStatus', - new_column_name='SagemakerStudioStatus', + nullable=False, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'DefaultDomainRoleName', + new_column_name='RoleArn', nullable=False, existing_type=sa.String() ) diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index a61b292dc..f3e979c9b 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -1,4 +1,5 @@ import { LoadingButton } from '@mui/lab'; +import { DeleteOutlined } from '@mui/icons-material'; import { Box, Card, @@ -6,6 +7,7 @@ import { Chip, Divider, Grid, + IconButton, Table, TableBody, TableCell, @@ -16,14 +18,7 @@ import CircularProgress from '@mui/material/CircularProgress'; import { useSnackbar } from 'notistack'; import PropTypes from 'prop-types'; import React, { useCallback, useEffect, useState } from 'react'; -import { - Defaults, - MinusIcon, - Pager, - PlusIcon, - RefreshTableMenu, - Scrollbar -} from 'design'; +import { Defaults, Pager, PlusIcon, RefreshTableMenu, Scrollbar } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; import { useClient } from 'services'; import { @@ -32,7 +27,7 @@ import { } from '../services'; import { MLStudioDomainCreateModal } from './MLStudioDomainCreateModal'; -function DomainRow({ domain }) { +function DomainRow({ domain, deleteEnvironmentMLStudioDomain }) { return ( {domain.label} @@ -58,12 +53,22 @@ function DomainRow({ domain }) { )} + + { + deleteEnvironmentMLStudioDomain(domain.sagemakerStudioUri); + }} + > + + + ); } DomainRow.propTypes = { - domain: PropTypes.any + domain: PropTypes.any, + deleteEnvironmentMLStudioDomain: PropTypes.func }; export const EnvironmentMLStudio = ({ environment }) => { const client = useClient(); @@ -143,7 +148,7 @@ export const EnvironmentMLStudio = ({ environment }) => { title={ ML Studio Domains - {items.nodes.length === 0 ? ( + {items.nodes.length === 0 && ( { > Add ML Studio Domain - ) : ( - { - deleteEnvironmentMLStudioDomain( - items.nodes[0].sagemakerStudioUri - ); - }} - startIcon={} - sx={{ m: 1 }} - variant="outlined" - > - Delete ML Studio Domain - )} } @@ -192,6 +183,7 @@ export const EnvironmentMLStudio = ({ environment }) => { Domain Name VPC Subnets + Actions {loading ? ( @@ -204,6 +196,9 @@ export const EnvironmentMLStudio = ({ environment }) => { domain={domain} environment={environment} fetchItems={fetchItems} + deleteEnvironmentMLStudioDomain={ + deleteEnvironmentMLStudioDomain + } /> )) ) : ( From 3280a056dd886723c90d28471b6c846d3dfb46a3 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Fri, 1 Dec 2023 16:24:14 -0500 Subject: [PATCH 12/38] Create Studio Domain on Env Create/Update, rework Frontend Views --- .../dataall/base/cdkproxy/cdk.context.json | 96 +++++++ .../core/environment/api/input_types.py | 9 +- .../dataall/core/environment/api/resolvers.py | 23 +- .../dataall/core/environment/aws/__init__.py | 0 .../environment}/aws/ec2_client.py | 0 backend/dataall/modules/mlstudio/__init__.py | 3 + .../dataall/modules/mlstudio/api/mutations.py | 13 +- .../dataall/modules/mlstudio/api/queries.py | 12 +- .../dataall/modules/mlstudio/api/resolvers.py | 15 +- backend/dataall/modules/mlstudio/api/types.py | 7 +- .../mlstudio/cdk/mlstudio_extension.py | 2 +- .../modules/mlstudio/db/mlstudio_models.py | 2 +- .../mlstudio/db/mlstudio_repositories.py | 60 +++-- .../mlstudio/services/mlstudio_service.py | 42 ++-- .../components/EnvironmentMLStudio.js | 236 ++++++++---------- .../components/MLStudioDomainCreateModal.js | 204 --------------- .../modules/Environments/components/index.js | 1 - .../deleteEnvironmentMLStudioDomain.js | 12 + .../services/deleteMLStudioDomain.js | 12 - .../services/getEnvironmentMLStudioDomain.js | 23 ++ .../modules/Environments/services/index.js | 3 +- .../views/EnvironmentCreateForm.js | 88 ++++++- .../Environments/views/EnvironmentEditForm.js | 120 ++++++++- 23 files changed, 547 insertions(+), 436 deletions(-) create mode 100644 backend/dataall/base/cdkproxy/cdk.context.json create mode 100644 backend/dataall/core/environment/aws/__init__.py rename backend/dataall/{modules/mlstudio => core/environment}/aws/ec2_client.py (100%) delete mode 100644 frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js create mode 100644 frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js delete mode 100644 frontend/src/modules/Environments/services/deleteMLStudioDomain.js create mode 100644 frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js diff --git a/backend/dataall/base/cdkproxy/cdk.context.json b/backend/dataall/base/cdkproxy/cdk.context.json new file mode 100644 index 000000000..6a02b4fe1 --- /dev/null +++ b/backend/dataall/base/cdkproxy/cdk.context.json @@ -0,0 +1,96 @@ +{ + "vpc-provider:account=139956106467:filter.isDefault=true:region=us-east-1:returnAsymmetricSubnets=true": { + "vpcId": "vpc-47a2473a", + "vpcCidrBlock": "172.31.0.0/16", + "ownerAccountId": "139956106467", + "availabilityZones": [], + "subnetGroups": [ + { + "name": "Public", + "type": "Public", + "subnets": [ + { + "subnetId": "subnet-ce854ca8", + "cidr": "172.31.0.0/20", + "availabilityZone": "us-east-1a", + "routeTableId": "rtb-eb234395" + }, + { + "subnetId": "subnet-dd2df9fc", + "cidr": "172.31.80.0/20", + "availabilityZone": "us-east-1b", + "routeTableId": "rtb-eb234395" + }, + { + "subnetId": "subnet-e357ceae", + "cidr": "172.31.16.0/20", + "availabilityZone": "us-east-1c", + "routeTableId": "rtb-eb234395" + }, + { + "subnetId": "subnet-9af53fc5", + "cidr": "172.31.32.0/20", + "availabilityZone": "us-east-1d", + "routeTableId": "rtb-eb234395" + }, + { + "subnetId": "subnet-95968bab", + "cidr": "172.31.48.0/20", + "availabilityZone": "us-east-1e", + "routeTableId": "rtb-eb234395" + }, + { + "subnetId": "subnet-6ba22165", + "cidr": "172.31.64.0/20", + "availabilityZone": "us-east-1f", + "routeTableId": "rtb-eb234395" + } + ] + } + ] + }, + "vpc-provider:account=139956106467:filter.vpc-id=vpc-09ddf78440e5c6d5d:region=us-east-1:returnAsymmetricSubnets=true": { + "vpcId": "vpc-09ddf78440e5c6d5d", + "vpcCidrBlock": "10.0.0.0/24", + "ownerAccountId": "139956106467", + "availabilityZones": [], + "subnetGroups": [ + { + "name": "Private", + "type": "Private", + "subnets": [ + { + "subnetId": "subnet-0f2c957fec49cc5b6", + "cidr": "10.0.0.128/28", + "availabilityZone": "us-east-1a", + "routeTableId": "rtb-0f5763c9bce96a6c3" + }, + { + "subnetId": "subnet-06d0ac5a5cdc3e842", + "cidr": "10.0.0.144/28", + "availabilityZone": "us-east-1b", + "routeTableId": "rtb-0e668f02b8963de94" + } + ] + }, + { + "name": "Public", + "type": "Public", + "subnets": [ + { + "subnetId": "subnet-0b83c25e072255092", + "cidr": "10.0.0.0/28", + "availabilityZone": "us-east-1a", + "routeTableId": "rtb-084b8fa8b6c24a230" + }, + { + "subnetId": "subnet-002ff94f2876021d5", + "cidr": "10.0.0.16/28", + "availabilityZone": "us-east-1b", + "routeTableId": "rtb-084b8fa8b6c24a230" + } + ] + } + ] + } +} diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py index 891682bf5..15786955b 100644 --- a/backend/dataall/core/environment/api/input_types.py +++ b/backend/dataall/core/environment/api/input_types.py @@ -30,8 +30,9 @@ gql.Argument('region', gql.NonNullableType(gql.String)), gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String), gql.Argument('resourcePrefix', gql.String), - gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) - + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), + gql.Argument('mlStudioVPCId', gql.String), + gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)) ], ) @@ -43,7 +44,9 @@ gql.Argument('tags', gql.ArrayType(gql.String)), gql.Argument('SamlGroupName', gql.String), gql.Argument('resourcePrefix', gql.String), - gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), + gql.Argument('mlStudioVPCId', gql.String), + gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)) ], ) diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index e4201950e..38bbbd32a 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -20,6 +20,7 @@ from dataall.core.stacks.aws.cloudformation import CloudFormation from dataall.core.stacks.db.stack_repositories import Stack from dataall.core.vpc.db.vpc_repositories import Vpc +from dataall.core.environment.aws.ec2_client import EC2 from dataall.base.db import exceptions from dataall.core.permissions import permissions from dataall.base.feature_toggle_checker import is_feature_enabled @@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs): return True if ssm_param == "True" else False -def check_environment(context: Context, source, account_id, region): +def check_environment(context: Context, source, account_id, region, data): """ Checks necessary resources for environment deployment. - Check CDKToolkit exists in Account assuming cdk_look_up_role - Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False @@ -71,6 +72,20 @@ def check_environment(context: Context, source, account_id, region): action='CHECK_PIVOT_ROLE', message='Pivot Role has not been created in the Environment AWS Account', ) + mlStudioEnabled = None + for parameter in data.get("parameters", []): + if parameter['key'] == 'mlStudiosEnabled': + mlStudioEnabled = parameter['value'] + + if mlStudioEnabled and data.get("mlStudioVPCId", None) and data.get("mlStudioSubnetIds", []): + log.info("Check if ML Studio VPC Exists in the Account") + EC2.check_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + vpc_id=data.get("mlStudioVPCId", None), + subnet_ids=data.get('mlStudioSubnetIds', []), + ) return cdk_role_name @@ -85,7 +100,8 @@ def create_environment(context: Context, source, input={}): with context.engine.scoped_session() as session: cdk_role_name = check_environment(context, source, account_id=input.get('AwsAccountId'), - region=input.get('region') + region=input.get('region'), + data=input ) input['cdk_role_name'] = cdk_role_name @@ -120,7 +136,8 @@ def update_environment( environment = EnvironmentService.get_environment_by_uri(session, environmentUri) cdk_role_name = check_environment(context, source, account_id=environment.AwsAccountId, - region=environment.region + region=environment.region, + data=input ) previous_resource_prefix = environment.resourcePrefix diff --git a/backend/dataall/core/environment/aws/__init__.py b/backend/dataall/core/environment/aws/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/dataall/modules/mlstudio/aws/ec2_client.py b/backend/dataall/core/environment/aws/ec2_client.py similarity index 100% rename from backend/dataall/modules/mlstudio/aws/ec2_client.py rename to backend/dataall/core/environment/aws/ec2_client.py diff --git a/backend/dataall/modules/mlstudio/__init__.py b/backend/dataall/modules/mlstudio/__init__.py index 2db9c0a1e..190267430 100644 --- a/backend/dataall/modules/mlstudio/__init__.py +++ b/backend/dataall/modules/mlstudio/__init__.py @@ -4,6 +4,7 @@ from dataall.base.loader import ImportMode, ModuleInterface from dataall.core.stacks.db.target_type_repositories import TargetType from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager log = logging.getLogger(__name__) @@ -20,6 +21,8 @@ def __init__(self): from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER TargetType("mlstudio", GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER) + EnvironmentResourceManager.register(SageMakerStudioRepository()) + log.info("API of sagemaker mlstudio has been imported") diff --git a/backend/dataall/modules/mlstudio/api/mutations.py b/backend/dataall/modules/mlstudio/api/mutations.py index 62c03fdb6..b14195f70 100644 --- a/backend/dataall/modules/mlstudio/api/mutations.py +++ b/backend/dataall/modules/mlstudio/api/mutations.py @@ -4,7 +4,7 @@ create_sagemaker_studio_user, delete_sagemaker_studio_user, create_sagemaker_studio_domain, - delete_sagemaker_studio_domain + delete_environment_sagemaker_studio_domain ) createSagemakerStudioUser = gql.MutationField( @@ -44,14 +44,11 @@ resolver=create_sagemaker_studio_domain, ) -deleteMLStudioDomain = gql.MutationField( - name='deleteMLStudioDomain', +deleteEnvironmentMLStudioDomain = gql.MutationField( + name='deleteEnvironmentMLStudioDomain', args=[ - gql.Argument( - name='sagemakerStudioUri', - type=gql.NonNullableType(gql.String), - ) + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), ], type=gql.Boolean, - resolver=delete_sagemaker_studio_domain, + resolver=delete_environment_sagemaker_studio_domain, ) diff --git a/backend/dataall/modules/mlstudio/api/queries.py b/backend/dataall/modules/mlstudio/api/queries.py index dd9d647ab..41c4e5cd1 100644 --- a/backend/dataall/modules/mlstudio/api/queries.py +++ b/backend/dataall/modules/mlstudio/api/queries.py @@ -4,7 +4,8 @@ get_sagemaker_studio_user, list_sagemaker_studio_users, get_sagemaker_studio_user_presigned_url, - list_environment_sagemaker_studio_domains + list_environment_sagemaker_studio_domains, + get_environment_sagemaker_studio_domain ) getSagemakerStudioUser = gql.QueryField( @@ -36,6 +37,15 @@ resolver=get_sagemaker_studio_user_presigned_url, ) +getEnvironmentMLStudioDomain = gql.QueryField( + name='getEnvironmentMLStudioDomain', + args=[ + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), + ], + type=gql.Ref('SagemakerStudioDomain'), + resolver=get_environment_sagemaker_studio_domain, +) + listEnvironmentMLStudioDomains = gql.QueryField( name='listEnvironmentMLStudioDomains', args=[ diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py index 706219d4b..2d20d41ff 100644 --- a/backend/dataall/modules/mlstudio/api/resolvers.py +++ b/backend/dataall/modules/mlstudio/api/resolvers.py @@ -109,17 +109,22 @@ def create_sagemaker_studio_domain(context: Context, source, input: dict = None) ) -def delete_sagemaker_studio_domain( +def delete_environment_sagemaker_studio_domain( context, source: SagemakerStudioUser, - sagemakerStudioUri: str = None + environmentUri: str = None ): - RequestValidator.required_uri(sagemakerStudioUri) - return SagemakerStudioService.delete_sagemaker_studio_domain( - uri=sagemakerStudioUri + RequestValidator.required_uri(environmentUri) + return SagemakerStudioService.delete_environment_sagemaker_studio_domain( + uri=environmentUri ) +def get_environment_sagemaker_studio_domain(context, source, environmentUri: str = None): + RequestValidator.required_uri(environmentUri) + return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri) + + def list_environment_sagemaker_studio_domains(context, source, filter: dict = None, environmentUri: str = None): RequestValidator.required_uri(environmentUri) if not filter: diff --git a/backend/dataall/modules/mlstudio/api/types.py b/backend/dataall/modules/mlstudio/api/types.py index 9daf76d70..7446ef290 100644 --- a/backend/dataall/modules/mlstudio/api/types.py +++ b/backend/dataall/modules/mlstudio/api/types.py @@ -86,13 +86,16 @@ gql.Field(name='sagemakerStudioUri', type=gql.ID), gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)), gql.Field(name='sagemakerStudioDomainName', type=gql.String), + gql.Field(name='DefaultDomainRoleName', type=gql.String), gql.Field(name='label', type=gql.String), gql.Field(name='name', type=gql.String), - gql.Field(name='created', type=gql.String), - gql.Field(name='updated', type=gql.String), gql.Field(name='vpcType', type=gql.String), gql.Field(name='vpcId', type=gql.String), gql.Field(name='subnetIds', type=gql.ArrayType(gql.String)), + gql.Field(name='owner', type=gql.String), + gql.Field(name='created', type=gql.String), + gql.Field(name='updated', type=gql.String), + gql.Field(name='deleted', type=gql.String), gql.Field( name='environment', type=gql.Ref('Environment'), diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 166ac814b..37131390a 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -19,7 +19,7 @@ from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.modules.mlstudio.aws.ec2_client import EC2 +from dataall.core.environment.aws.ec2_client import EC2 from dataall.modules.mlstudio.aws.sagemaker_studio_client import get_sagemaker_studio_domain logger = logging.getLogger(__name__) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index c99e11eca..89742b584 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -19,7 +19,7 @@ class SagemakerStudioDomain(Resource, Base): SagemakerStudioStatus = Column(String, nullable=True) sagemakerStudioDomainName = Column(String, nullable=False) AWSAccountId = Column(String, nullable=False) - RoleArn = Column(String, nullable=False) + DefaultDomainRoleName = Column(String, nullable=False) region = Column(String, default='eu-west-1') vpcType = Column(String, nullable=True) vpcId = Column(String, nullable=True) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index fadfff920..7fa24b995 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -2,6 +2,7 @@ DAO layer that encapsulates the logic and interaction with the database for ML Studio Provides the API to retrieve / update / delete ml studio """ +import stat from typing import Optional from sqlalchemy import or_ from sqlalchemy.sql import and_ @@ -11,6 +12,7 @@ from dataall.base.db import paginate from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioDomain, SagemakerStudioUser from dataall.core.environment.services.environment_resource_manager import EnvironmentResource +from dataall.core.environment.services.environment_service import EnvironmentService from dataall.base.utils.naming_convention import ( NamingConventionService, NamingConventionPattern, @@ -22,16 +24,22 @@ class SageMakerStudioRepository(EnvironmentResource): _DEFAULT_PAGE = 1 _DEFAULT_PAGE_SIZE = 10 - def __init__(self, session): - self._session = session + @staticmethod + def update_env(session, environment): + current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) + previous_mlstudio_enabled = True if domain else False + return current_mlstudio_enabled != previous_mlstudio_enabled - def save_sagemaker_studio_user(self, user): + @staticmethod + def save_sagemaker_studio_user(session, user): """Save SageMaker Studio user to the database""" - self._session.add(user) - self._session.commit() + session.add(user) + session.commit() - def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query: - query = self._session.query(SagemakerStudioUser).filter( + @staticmethod + def _query_user_sagemaker_studio_users(session, username, groups, filter) -> Query: + query = session.query(SagemakerStudioUser).filter( or_( SagemakerStudioUser.owner == username, SagemakerStudioUser.SamlAdminGroupName.in_(groups), @@ -50,21 +58,24 @@ def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query: ) return query - def paginated_sagemaker_studio_users(self, username, groups, filter={}) -> dict: + @staticmethod + def paginated_sagemaker_studio_users(session, username, groups, filter={}) -> dict: """Returns a page of sagemaker studio users for a data.all user""" return paginate( - query=self._query_user_sagemaker_studio_users(username, groups, filter), + query=SageMakerStudioRepository._query_user_sagemaker_studio_users(session, username, groups, filter), page=filter.get('page', SageMakerStudioRepository._DEFAULT_PAGE), page_size=filter.get('pageSize', SageMakerStudioRepository._DEFAULT_PAGE_SIZE), ).to_dict() - def find_sagemaker_studio_user(self, uri): + @staticmethod + def find_sagemaker_studio_user(session, uri): """Finds a sagemaker studio user. Returns None if it doesn't exist""" - return self._session.query(SagemakerStudioUser).get(uri) + return session.query(SagemakerStudioUser).get(uri) - def count_resources(self, environment, group_uri): + @staticmethod + def count_resources(session, environment, group_uri): return ( - self._session.query(SagemakerStudioUser) + session.query(SagemakerStudioUser) .filter( and_( SagemakerStudioUser.environmentUri == environment.environmentUri, @@ -74,10 +85,10 @@ def count_resources(self, environment, group_uri): .count() ) - def create_sagemaker_studio_domain(self, username, environment, data): - # TODO: BUILD ROLE ARN Domain Name + @staticmethod + def create_sagemaker_studio_domain(session, username, environment, data): domain = SagemakerStudioDomain( - label=data.get('label'), + label=f"{data.get('label')}-domain", owner=username, description=data.get('description', 'No description provided'), tags=data.get('tags', []), @@ -91,8 +102,8 @@ def create_sagemaker_studio_domain(self, username, environment, data): vpcId=data.get('vpcId'), subnetIds=data.get('subnetIds', []) ) - self._session.add(domain) - self._session.commit() + session.add(domain) + session.commit() domain.sagemakerStudioDomainName = NamingConventionService( target_uri=domain.sagemakerStudioUri, @@ -110,16 +121,18 @@ def create_sagemaker_studio_domain(self, username, environment, data): return domain - def paginated_environment_sagemaker_studio_domains(self, uri, filter={}) -> dict: + @staticmethod + def paginated_environment_sagemaker_studio_domains(session, uri, filter={}) -> dict: """Returns a page of sagemaker studio users for a data.all user""" return paginate( - query=self._query_environment_sagemaker_studio_domains(uri, filter), + query=SageMakerStudioRepository._query_environment_sagemaker_studio_domains(session, uri, filter), page=filter.get('page', SageMakerStudioRepository._DEFAULT_PAGE), page_size=filter.get('pageSize', SageMakerStudioRepository._DEFAULT_PAGE_SIZE), ).to_dict() - def _query_environment_sagemaker_studio_domains(self, uri, filter) -> Query: - query = self._session.query(SagemakerStudioDomain).filter( + @staticmethod + def _query_environment_sagemaker_studio_domains(session, uri, filter) -> Query: + query = session.query(SagemakerStudioDomain).filter( SagemakerStudioDomain.environmentUri == uri, ) if filter and filter.get('term'): @@ -135,9 +148,6 @@ def _query_environment_sagemaker_studio_domains(self, uri, filter) -> Query: ) return query - def find_sagemaker_studio_domain(self, uri) -> Optional[SagemakerStudioDomain]: - return self._session.query(SagemakerStudioDomain).get(uri) - @staticmethod def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index a7eb42e6b..ae6152142 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -19,7 +19,7 @@ from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -from dataall.modules.mlstudio.aws.ec2_client import EC2 +from dataall.core.environment.aws.ec2_client import EC2 from dataall.base.aws.sts import SessionHelper from dataall.modules.mlstudio.services.mlstudio_permissions import ( @@ -88,7 +88,7 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak ) # FOR NEW ONES (default, created, imported) - # - CHECK RDS FIRST + # - CHECK RDS FIRST - ONLY GET DOMAIN NAME # - IF NOT BOTO3 existing_domain = response.get('DomainId', False) @@ -114,7 +114,7 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak SamlAdminGroupName=admin_group, tags=request.tags, ) - SageMakerStudioRepository(session).save_sagemaker_studio_user(user=sagemaker_studio_user) + SageMakerStudioRepository.save_sagemaker_studio_user(session, sagemaker_studio_user) ResourcePolicy.attach_resource_policy( session=session, @@ -178,12 +178,12 @@ def create_sagemaker_studio_domain(*, uri: str, data: dict): else: data["vpcType"] = "created" - domain = SageMakerStudioRepository(session).create_sagemaker_studio_domain( + domain = SageMakerStudioRepository.create_sagemaker_studio_domain( + session=session, username=get_context().username, environment=environment, data=data, ) - stack_helper.deploy_stack(domain.environmentUri) return domain @staticmethod @@ -196,29 +196,28 @@ def check_mlstudio_domain_vpc(account_id: str, region: str, cdk_look_up_role_arn vpc_id=data.get("vpcId", None), subnet_ids=data.get('subnetIds', []), ) - data["vpcType"] = "imported" return True - @staticmethod - def _get_domain_env_uri(session, uri): - domain = SagemakerStudioService._get_sagemaker_studio_domain(session, uri) - return domain.environmentUri - @staticmethod @has_tenant_permission(permissions.MANAGE_ENVIRONMENTS) - @has_resource_permission(permissions.UPDATE_ENVIRONMENT, parent_resource=_get_domain_env_uri) - def delete_sagemaker_studio_domain(*, uri: str): + @has_resource_permission(permissions.UPDATE_ENVIRONMENT) + def delete_environment_sagemaker_studio_domain(*, uri: str): with _session() as session: - domain = SagemakerStudioService._get_sagemaker_studio_domain(session, uri) + domain = SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=uri) # TODO: CHECK NUMBER OF USERS BEFORE DELETE session.delete(domain) - stack_helper.deploy_stack(domain.environmentUri) return True + @staticmethod + def get_environment_sagemaker_studio_domain(*, environment_uri: str): + with _session() as session: + return SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment_uri) + @staticmethod def list_environment_sagemaker_studio_domains(*, filter: dict, environment_uri: str) -> dict: with _session() as session: - return SageMakerStudioRepository(session).paginated_environment_sagemaker_studio_domains( + return SageMakerStudioRepository.paginated_environment_sagemaker_studio_domains( + session=session, uri=environment_uri, filter=filter, ) @@ -226,7 +225,8 @@ def list_environment_sagemaker_studio_domains(*, filter: dict, environment_uri: @staticmethod def list_sagemaker_studio_users(*, filter: dict) -> dict: with _session() as session: - return SageMakerStudioRepository(session).paginated_sagemaker_studio_users( + return SageMakerStudioRepository.paginated_sagemaker_studio_users( + session=session, username=get_context().username, groups=get_context().groups, filter=filter, @@ -285,14 +285,8 @@ def delete_sagemaker_studio_user(*, uri: str, delete_from_aws: bool): @staticmethod def _get_sagemaker_studio_user(session, uri): - user = SageMakerStudioRepository(session).find_sagemaker_studio_user(uri=uri) + user = SageMakerStudioRepository.find_sagemaker_studio_user(session=session, uri=uri) if not user: raise exceptions.ObjectNotFound('SagemakerStudioUser', uri) return user - @staticmethod - def _get_sagemaker_studio_domain(session, uri): - domain = SageMakerStudioRepository(session).find_sagemaker_studio_domain(uri=uri) - if not domain: - raise exceptions.ObjectNotFound('SagemakerStudioDomain', uri) - return domain diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index f3e979c9b..4d4aef376 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -1,101 +1,39 @@ -import { LoadingButton } from '@mui/lab'; -import { DeleteOutlined } from '@mui/icons-material'; import { Box, Card, CardHeader, - Chip, Divider, Grid, - IconButton, - Table, - TableBody, - TableCell, - TableHead, - TableRow + CardContent, + Typography, + CircularProgress } from '@mui/material'; -import CircularProgress from '@mui/material/CircularProgress'; -import { useSnackbar } from 'notistack'; + import PropTypes from 'prop-types'; import React, { useCallback, useEffect, useState } from 'react'; -import { Defaults, Pager, PlusIcon, RefreshTableMenu, Scrollbar } from 'design'; +import { RefreshTableMenu, ObjectMetadata } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; import { useClient } from 'services'; -import { - deleteMLStudioDomain, - listEnvironmentMLStudioDomains -} from '../services'; -import { MLStudioDomainCreateModal } from './MLStudioDomainCreateModal'; +import { getEnvironmentMLStudioDomain } from '../services'; -function DomainRow({ domain, deleteEnvironmentMLStudioDomain }) { - return ( - - {domain.label} - {domain.sagemakerStudioDomainName} - {domain.vpcId} - - {domain.subnetIds && ( - - {domain.subnetIds.map((subnet) => ( - - ))} - - )} - - - { - deleteEnvironmentMLStudioDomain(domain.sagemakerStudioUri); - }} - > - - - - - ); -} - -DomainRow.propTypes = { - domain: PropTypes.any, - deleteEnvironmentMLStudioDomain: PropTypes.func -}; export const EnvironmentMLStudio = ({ environment }) => { const client = useClient(); const dispatch = useDispatch(); - const { enqueueSnackbar } = useSnackbar(); - const [items, setItems] = useState(Defaults.pagedResponse); - const [filter, setFilter] = useState(Defaults.filter); + const [mlStudioDomain, setMLStudioDomain] = useState(null); const [loading, setLoading] = useState(true); - const [isStudioDomainCreateOpen, setStudioDomainCreateOpen] = useState(false); - const handleStudioDomainCreateModalOpen = () => { - setStudioDomainCreateOpen(true); - }; - const handleStudioDomainCreateModalClose = () => { - setStudioDomainCreateOpen(false); - }; - - const fetchItems = useCallback(async () => { + const fetchMLStudioDomain = useCallback(async () => { try { + setLoading(true); const response = await client.query( - listEnvironmentMLStudioDomains({ - environmentUri: environment.environmentUri, - filter + getEnvironmentMLStudioDomain({ + environmentUri: environment.environmentUri }) ); if (!response.errors) { - setItems({ ...response.data.listEnvironmentMLStudioDomains }); + if (response.data.getEnvironmentMLStudioDomain) { + setMLStudioDomain(response.data.getEnvironmentMLStudioDomain); + } } else { dispatch({ type: SET_ERROR, error: response.errors[0].message }); } @@ -104,63 +42,26 @@ export const EnvironmentMLStudio = ({ environment }) => { } finally { setLoading(false); } - }, [client, dispatch, filter, environment.environmentUri]); - - const deleteEnvironmentMLStudioDomain = async (sagemakerStudioUri) => { - const response = await client.mutate( - deleteMLStudioDomain({ sagemakerStudioUri }) - ); - if (!response.errors) { - enqueueSnackbar('ML Studio Domain deleted', { - anchorOrigin: { - horizontal: 'right', - vertical: 'top' - }, - variant: 'success' - }); - fetchItems().catch((e) => - dispatch({ type: SET_ERROR, error: e.message }) - ); - } else { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - }; + }, [client, dispatch, environment.environmentUri]); useEffect(() => { if (client) { - fetchItems().catch((e) => + fetchMLStudioDomain().catch((e) => dispatch({ type: SET_ERROR, error: e.message }) ); } - }, [client, filter.page, fetchItems, dispatch]); + }, [client, fetchMLStudioDomain, dispatch]); - const handlePageChange = async (event, value) => { - if (value <= items.pages && value !== items.page) { - await setFilter({ ...filter, page: value }); - } - }; + if (loading) { + return ; + } return ( } - title={ - - ML Studio Domains - {items.nodes.length === 0 && ( - } - sx={{ m: 1 }} - variant="outlined" - > - Add ML Studio Domain - - )} - - } + action={} + title={ML Studio Domain} /> { > - + {mlStudioDomain === null ? ( + + + No ML Studio Domain - To Create a ML Studio Domain for this + Environment: `{environment.label}`, edit the Environment and + enable the ML Studio Environment Feature + + + ) : ( + + + + + + + + SageMaker ML Studio Domain Name + + + {mlStudioDomain.sagemakerStudioDomainName} + + + + + SageMaker ML Studio Default Execution Role + + + arn:aws:s3::: + {mlStudioDomain.DefaultDomainRoleName} + + + + + Domain VPC Type + + + {mlStudioDomain.vpcType} + + + {mlStudioDomain.vpcType === 'imported' && ( + <> + + + Domain VPC Id + + + {mlStudioDomain.vpcId} + + + + + Domain Subnet Ids + + + {mlStudioDomain.subnetIds} + + + + )} + + + + + + + )} + {/* @@ -183,7 +160,6 @@ export const EnvironmentMLStudio = ({ environment }) => { Domain Name VPC Subnets - Actions {loading ? ( @@ -196,9 +172,6 @@ export const EnvironmentMLStudio = ({ environment }) => { domain={domain} environment={environment} fetchItems={fetchItems} - deleteEnvironmentMLStudioDomain={ - deleteEnvironmentMLStudioDomain - } /> )) ) : ( @@ -218,17 +191,8 @@ export const EnvironmentMLStudio = ({ environment }) => { /> )} - + */} - {isStudioDomainCreateOpen && ( - - )} ); }; diff --git a/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js b/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js deleted file mode 100644 index d592a4d67..000000000 --- a/frontend/src/modules/Environments/components/MLStudioDomainCreateModal.js +++ /dev/null @@ -1,204 +0,0 @@ -import { LoadingButton } from '@mui/lab'; -import { - Box, - CardContent, - CardHeader, - Dialog, - FormHelperText, - Grid, - TextField, - Typography -} from '@mui/material'; -import { Formik } from 'formik'; -import { useSnackbar } from 'notistack'; -import PropTypes from 'prop-types'; -import * as Yup from 'yup'; -import { ChipInput } from 'design'; -import { SET_ERROR, useDispatch } from 'globalErrors'; -import { useClient } from 'services'; -import { createMLStudioDomain } from '../services'; - -export const MLStudioDomainCreateModal = (props) => { - const { environment, onApply, onClose, open, reloadStudioDomains, ...other } = - props; - const { enqueueSnackbar } = useSnackbar(); - const dispatch = useDispatch(); - const client = useClient(); - - async function submit(values, setStatus, setSubmitting, setErrors) { - try { - const response = await client.mutate( - createMLStudioDomain({ - environmentUri: environment.environmentUri, - label: values.label, - vpcId: values.mlStudioVPCId, - subnetIds: values.mlStudioSubnetIds - }) - ); - if (!response.errors) { - setStatus({ success: true }); - setSubmitting(false); - enqueueSnackbar('ML Studio Domain Added', { - anchorOrigin: { - horizontal: 'right', - vertical: 'top' - }, - variant: 'success' - }); - if (reloadStudioDomains) { - reloadStudioDomains(); - } - if (onApply) { - onApply(); - } - } else { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - } catch (err) { - setStatus({ success: false }); - setErrors({ submit: err.message }); - setSubmitting(false); - dispatch({ type: SET_ERROR, error: err.message }); - } - } - - if (!environment) { - return null; - } - - return ( - - - - Create a SageMaker ML Studio Domain for your Environment - - - !!value, - then: Yup.array() - .min(1) - .required('At least 1 Subnet Id required if VPC Id specified') - }) - })} - onSubmit={async ( - values, - { setErrors, setStatus, setSubmitting } - ) => { - await submit(values, setStatus, setSubmitting, setErrors); - }} - > - {({ - errors, - handleBlur, - handleChange, - handleSubmit, - isSubmitting, - setFieldValue, - touched, - values - }) => ( -
- - - - - - - - - - - - { - setFieldValue('mlStudioSubnetIds', [...chip]); - }} - /> - - - {errors.submit && ( - - {errors.submit} - - )} - - - Create - - - - - - )} -
-
-
-
- ); -}; - -MLStudioDomainCreateModal.propTypes = { - environment: PropTypes.object.isRequired, - onApply: PropTypes.func, - onClose: PropTypes.func, - reloadStudioDomains: PropTypes.func, - open: PropTypes.bool.isRequired -}; diff --git a/frontend/src/modules/Environments/components/index.js b/frontend/src/modules/Environments/components/index.js index e8e41c362..7aecd51fa 100644 --- a/frontend/src/modules/Environments/components/index.js +++ b/frontend/src/modules/Environments/components/index.js @@ -13,4 +13,3 @@ export * from './EnvironmentTeamInviteForm'; export * from './EnvironmentTeams'; export * from './NetworkCreateModal'; export * from './EnvironmentMLStudio'; -export * from './MLStudioDomainCreateModal'; diff --git a/frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js b/frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js new file mode 100644 index 000000000..7abdc7e9e --- /dev/null +++ b/frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js @@ -0,0 +1,12 @@ +import { gql } from 'apollo-boost'; + +export const deleteEnvironmentMLStudioDomain = ({ environmentUri }) => ({ + variables: { + environmentUri + }, + mutation: gql` + mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { + deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) + } + ` +}); diff --git a/frontend/src/modules/Environments/services/deleteMLStudioDomain.js b/frontend/src/modules/Environments/services/deleteMLStudioDomain.js deleted file mode 100644 index 2a0c6e7d7..000000000 --- a/frontend/src/modules/Environments/services/deleteMLStudioDomain.js +++ /dev/null @@ -1,12 +0,0 @@ -import { gql } from 'apollo-boost'; - -export const deleteMLStudioDomain = ({ sagemakerStudioUri }) => ({ - variables: { - sagemakerStudioUri - }, - mutation: gql` - mutation deleteMLStudioDomain($sagemakerStudioUri: String!) { - deleteMLStudioDomain(sagemakerStudioUri: $sagemakerStudioUri) - } - ` -}); diff --git a/frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js b/frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js new file mode 100644 index 000000000..9dc34d630 --- /dev/null +++ b/frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js @@ -0,0 +1,23 @@ +import { gql } from 'apollo-boost'; + +export const getEnvironmentMLStudioDomain = ({ environmentUri }) => ({ + variables: { + environmentUri + }, + query: gql` + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + ` +}); diff --git a/frontend/src/modules/Environments/services/index.js b/frontend/src/modules/Environments/services/index.js index dbdd41431..acb89e0fe 100644 --- a/frontend/src/modules/Environments/services/index.js +++ b/frontend/src/modules/Environments/services/index.js @@ -23,5 +23,6 @@ export * from './removeGroup'; export * from './updateEnvironment'; export * from './updateGroupEnvironmentPermissions'; export * from './createMLStudioDomain'; -export * from './deleteMLStudioDomain'; +export * from './deleteEnvironmentMLStudioDomain'; export * from './listEnvironmentMLStudioDomains'; +export * from './getEnvironmentMLStudioDomain'; diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index 4a97500d6..467a2cd42 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -31,6 +31,14 @@ import { CopyToClipboard } from 'react-copy-to-clipboard/lib/Component'; import { Helmet } from 'react-helmet-async'; import { Link as RouterLink, useNavigate, useParams } from 'react-router-dom'; import * as Yup from 'yup'; +import { + createMLStudioDomain, + createEnvironment, + getPivotRoleExternalId, + getPivotRoleName, + getPivotRolePresignedUrl, + getCDKExecPolicyPresignedUrl +} from '../services'; import { ArrowLeftIcon, ChevronRightIcon, @@ -44,13 +52,6 @@ import { useClient, useGroups } from 'services'; -import { - createEnvironment, - getPivotRoleExternalId, - getPivotRoleName, - getPivotRolePresignedUrl, - getCDKExecPolicyPresignedUrl -} from '../services'; import { AwsRegions, isAnyEnvironmentModuleEnabled, @@ -179,6 +180,8 @@ const EnvironmentCreateForm = (props) => { region: values.region, EnvironmentDefaultIAMRoleArn: values.EnvironmentDefaultIAMRoleArn, resourcePrefix: values.resourcePrefix, + mlStudioVPCId: values.mlStudioVPCId, + mlStudioSubnetIds: values.mlStudioSubnetIds, parameters: [ { key: 'notebooksEnabled', @@ -200,6 +203,19 @@ const EnvironmentCreateForm = (props) => { }) ); if (!response.errors) { + if (values.mlStudiosEnabled === true) { + const response2 = await client.mutate( + createMLStudioDomain({ + environmentUri: response.data.createEnvironment.environmentUri, + label: values.label, + vpcId: values.mlStudioVPCId, + subnetIds: values.mlStudioSubnetIds + }) + ); + if (response2.errors) { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } setStatus({ success: true }); setSubmitting(false); enqueueSnackbar('Environment Created', { @@ -484,7 +500,9 @@ const EnvironmentCreateForm = (props) => { mlStudiosEnabled: isModuleEnabled(ModuleNames.MLSTUDIO), pipelinesEnabled: isModuleEnabled(ModuleNames.DATAPIPELINES), EnvironmentDefaultIAMRoleArn: '', - resourcePrefix: 'dataall' + resourcePrefix: 'dataall', + mlStudioVPCId: '', + mlStudioSubnetIds: [] }} validationSchema={Yup.object().shape({ label: Yup.string() @@ -508,6 +526,15 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), + mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { + is: (value) => !!value, + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) + }), + mlStudioVPCId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() @@ -859,6 +886,51 @@ const EnvironmentCreateForm = (props) => { + {values.mlStudiosEnabled && ( + + + + + + + + { + setFieldValue('mlStudioSubnetIds', [...chip]); + }} + /> + + + + )} {errors.submit && ( {errors.submit} diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index caa5d8441..3d38b2196 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -31,7 +31,13 @@ import { } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; import { useClient } from 'services'; -import { getEnvironment, updateEnvironment } from '../services'; +import { + getEnvironment, + updateEnvironment, + getEnvironmentMLStudioDomain, + createMLStudioDomain, + deleteEnvironmentMLStudioDomain +} from '../services'; import { isAnyEnvironmentModuleEnabled, isModuleEnabled, @@ -47,6 +53,9 @@ const EnvironmentEditForm = (props) => { const { settings } = useSettings(); const [loading, setLoading] = useState(true); const [env, setEnv] = useState(''); + const [envMLStudioDomain, setEnvMLStudioDomain] = useState(''); + const [previousEnvMLStudioEnabled, setPreviousEnvMLStudioEnabled] = + useState(false); const fetchItem = useCallback(async () => { const response = await client.query( @@ -58,6 +67,20 @@ const EnvironmentEditForm = (props) => { environment.parameters.map((x) => [x.key, x.value]) ); setEnv(environment); + if (environment.parameters['mlStudiosEnabled'] === 'true') { + setPreviousEnvMLStudioEnabled(true); + const response2 = await client.query( + getEnvironmentMLStudioDomain({ environmentUri: params.uri }) + ); + if (!response2.errors && response2.data.getEnvironmentMLStudioDomain) { + setEnvMLStudioDomain(response2.data.getEnvironmentMLStudioDomain); + } else { + const error = response2.errors + ? response2.errors[0].message + : 'Environment ML Studio Domain not found'; + dispatch({ type: SET_ERROR, error }); + } + } } else { const error = response.errors ? response.errors[0].message @@ -81,6 +104,8 @@ const EnvironmentEditForm = (props) => { tags: values.tags, description: values.description, resourcePrefix: values.resourcePrefix, + mlStudioVPCId: values.mlStudioVPCId, + mlStudioSubnetIds: values.mlStudioSubnetIds, parameters: [ { key: 'notebooksEnabled', @@ -103,6 +128,36 @@ const EnvironmentEditForm = (props) => { }) ); if (!response.errors) { + if ( + values.mlStudiosEnabled !== previousEnvMLStudioEnabled && + values.mlStudiosEnabled === true + ) { + const response2 = await client.mutate( + createMLStudioDomain({ + environmentUri: env.environmentUri, + label: values.label, + vpcId: values.mlStudioVPCId, + subnetIds: values.mlStudioSubnetIds + }) + ); + if (response2.errors) { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } + if ( + values.mlStudiosEnabled !== previousEnvMLStudioEnabled && + values.mlStudiosEnabled === false + ) { + console.error(envMLStudioDomain.sagemakerStudioUri); + const response2 = await client.mutate( + deleteEnvironmentMLStudioDomain({ + environmentUri: envMLStudioDomain.environmentUri + }) + ); + if (response2.errors) { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } setStatus({ success: true }); setSubmitting(false); enqueueSnackbar('Environment updated', { @@ -213,6 +268,8 @@ const EnvironmentEditForm = (props) => { label: env.label, description: env.description, tags: env.tags || [], + mlStudioVPCId: envMLStudioDomain.vpcId, + mlStudioSubnetIds: envMLStudioDomain.subnetIds, notebooksEnabled: env.parameters['notebooksEnabled'] === 'true', mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true', pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true', @@ -226,6 +283,15 @@ const EnvironmentEditForm = (props) => { .required('*Environment name is required'), description: Yup.string().max(5000), tags: Yup.array().nullable(), + mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { + is: (value) => !!value, + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) + }), + mlStudioVPCId: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() .matches( @@ -383,6 +449,58 @@ const EnvironmentEditForm = (props) => { + {!previousEnvMLStudioEnabled && + values.mlStudiosEnabled && ( + + + + + + + + { + setFieldValue('mlStudioSubnetIds', [ + ...chip + ]); + }} + /> + + + + )} {isAnyEnvironmentModuleEnabled() && ( From 6909b9d94372f375213ef04a1791107d2428e47f Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Fri, 1 Dec 2023 16:26:04 -0500 Subject: [PATCH 13/38] Remove unused API list domains --- .../dataall/modules/mlstudio/api/queries.py | 11 ----- .../dataall/modules/mlstudio/api/resolvers.py | 7 ---- .../mlstudio/db/mlstudio_repositories.py | 9 ---- .../mlstudio/services/mlstudio_service.py | 10 ----- .../components/EnvironmentMLStudio.js | 41 ------------------- .../modules/Environments/services/index.js | 1 - .../listEnvironmentMLStudioDomains.js | 34 --------------- 7 files changed, 113 deletions(-) delete mode 100644 frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js diff --git a/backend/dataall/modules/mlstudio/api/queries.py b/backend/dataall/modules/mlstudio/api/queries.py index 41c4e5cd1..ee014839f 100644 --- a/backend/dataall/modules/mlstudio/api/queries.py +++ b/backend/dataall/modules/mlstudio/api/queries.py @@ -4,7 +4,6 @@ get_sagemaker_studio_user, list_sagemaker_studio_users, get_sagemaker_studio_user_presigned_url, - list_environment_sagemaker_studio_domains, get_environment_sagemaker_studio_domain ) @@ -45,13 +44,3 @@ type=gql.Ref('SagemakerStudioDomain'), resolver=get_environment_sagemaker_studio_domain, ) - -listEnvironmentMLStudioDomains = gql.QueryField( - name='listEnvironmentMLStudioDomains', - args=[ - gql.Argument('filter', gql.Ref('SagemakerStudioDomainFilter')), - gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), - ], - type=gql.Ref('SagemakerStudioDomainSearchResult'), - resolver=list_environment_sagemaker_studio_domains, -) diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py index 2d20d41ff..e4a41a73c 100644 --- a/backend/dataall/modules/mlstudio/api/resolvers.py +++ b/backend/dataall/modules/mlstudio/api/resolvers.py @@ -125,13 +125,6 @@ def get_environment_sagemaker_studio_domain(context, source, environmentUri: str return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri) -def list_environment_sagemaker_studio_domains(context, source, filter: dict = None, environmentUri: str = None): - RequestValidator.required_uri(environmentUri) - if not filter: - filter = {} - return SagemakerStudioService.list_environment_sagemaker_studio_domains(filter=filter, environment_uri=environmentUri) - - def resolve_user_role(context: Context, source: SagemakerStudioUser): """ Resolves the role of the current user in reference with the SageMaker Studio User diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 7fa24b995..49d44c2c9 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -121,15 +121,6 @@ def create_sagemaker_studio_domain(session, username, environment, data): return domain - @staticmethod - def paginated_environment_sagemaker_studio_domains(session, uri, filter={}) -> dict: - """Returns a page of sagemaker studio users for a data.all user""" - return paginate( - query=SageMakerStudioRepository._query_environment_sagemaker_studio_domains(session, uri, filter), - page=filter.get('page', SageMakerStudioRepository._DEFAULT_PAGE), - page_size=filter.get('pageSize', SageMakerStudioRepository._DEFAULT_PAGE_SIZE), - ).to_dict() - @staticmethod def _query_environment_sagemaker_studio_domains(session, uri, filter) -> Query: query = session.query(SagemakerStudioDomain).filter( diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index ae6152142..ab3794201 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -213,15 +213,6 @@ def get_environment_sagemaker_studio_domain(*, environment_uri: str): with _session() as session: return SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment_uri) - @staticmethod - def list_environment_sagemaker_studio_domains(*, filter: dict, environment_uri: str) -> dict: - with _session() as session: - return SageMakerStudioRepository.paginated_environment_sagemaker_studio_domains( - session=session, - uri=environment_uri, - filter=filter, - ) - @staticmethod def list_sagemaker_studio_users(*, filter: dict) -> dict: with _session() as session: @@ -289,4 +280,3 @@ def _get_sagemaker_studio_user(session, uri): if not user: raise exceptions.ObjectNotFound('SagemakerStudioUser', uri) return user - diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 4d4aef376..158adf3d9 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -151,47 +151,6 @@ export const EnvironmentMLStudio = ({ environment }) => { )} - {/* - -
- - - Name - Domain Name - VPC - Subnets - - - {loading ? ( - - ) : ( - - {items.nodes.length > 0 ? ( - items.nodes.map((domain) => ( - - )) - ) : ( - - No SageMaker Studio Domain Found - - )} - - )} -
- {!loading && items.nodes.length > 0 && ( - - )} -
-
*/}
); diff --git a/frontend/src/modules/Environments/services/index.js b/frontend/src/modules/Environments/services/index.js index acb89e0fe..8955122b0 100644 --- a/frontend/src/modules/Environments/services/index.js +++ b/frontend/src/modules/Environments/services/index.js @@ -24,5 +24,4 @@ export * from './updateEnvironment'; export * from './updateGroupEnvironmentPermissions'; export * from './createMLStudioDomain'; export * from './deleteEnvironmentMLStudioDomain'; -export * from './listEnvironmentMLStudioDomains'; export * from './getEnvironmentMLStudioDomain'; diff --git a/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js b/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js deleted file mode 100644 index a98dba4b2..000000000 --- a/frontend/src/modules/Environments/services/listEnvironmentMLStudioDomains.js +++ /dev/null @@ -1,34 +0,0 @@ -import { gql } from 'apollo-boost'; - -export const listEnvironmentMLStudioDomains = ({ filter, environmentUri }) => ({ - variables: { - environmentUri, - filter - }, - query: gql` - query listEnvironmentMLStudioDomains( - $filter: SagemakerStudioDomainFilter - $environmentUri: String! - ) { - listEnvironmentMLStudioDomains( - environmentUri: $environmentUri - filter: $filter - ) { - count - page - pages - hasNext - hasPrevious - nodes { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - sagemakerStudioDomainName - } - } - } - ` -}); From 50a90b328ac11e7aa14e9ab1871a9005e1e3a509 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Fri, 1 Dec 2023 18:21:50 -0500 Subject: [PATCH 14/38] add tests mlstudio domain apis --- .../modules/mlstudio/db/mlstudio_models.py | 2 +- tests/core/conftest.py | 1 - tests/core/environments/test_environment.py | 29 +---- tests/core/vpc/test_vpc.py | 4 +- tests/modules/mlstudio/cdk/conftest.py | 20 ++- .../cdk/test_sagemaker_studio_stack.py | 9 +- tests/modules/mlstudio/conftest.py | 78 ++++++++++- .../modules/mlstudio/test_sagemaker_studio.py | 123 ++++++++++++++++++ 8 files changed, 229 insertions(+), 37 deletions(-) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index 89742b584..28ff62e9e 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -15,7 +15,7 @@ class SagemakerStudioDomain(Resource, Base): sagemakerStudioUri = Column( String, primary_key=True, default=utils.uuid('sagemakerstudio') ) - sagemakerStudioDomainID = Column(String, nullable=False) + sagemakerStudioDomainID = Column(String, nullable=True) SagemakerStudioStatus = Column(String, nullable=True) sagemakerStudioDomainName = Column(String, nullable=False) AWSAccountId = Column(String, nullable=False) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 6d8a449e4..738ab4d06 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -44,7 +44,6 @@ def factory(org, envname, owner, group, account, region, desc='test', parameters 'tags': ['a', 'b', 'c'], 'region': f'{region}', 'SamlGroupName': f'{group}', - 'vpcId': 'vpc-123456', 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()] }, ) diff --git a/tests/core/environments/test_environment.py b/tests/core/environments/test_environment.py index 31ba18e57..e806e07b2 100644 --- a/tests/core/environments/test_environment.py +++ b/tests/core/environments/test_environment.py @@ -221,26 +221,6 @@ def test_list_environments_no_filter(org_fixture, env_fixture, client, group): assert response.data.listEnvironments.count == 1 - response = client.query( - """ - query ListEnvironmentNetworks($environmentUri: String!,$filter:VpcFilter){ - listEnvironmentNetworks(environmentUri:$environmentUri,filter:$filter){ - count - nodes{ - VpcId - SamlGroupName - } - } - } - """, - environmentUri=env_fixture.environmentUri, - username='alice', - groups=[group.name], - ) - print(response) - - assert response.data.listEnvironmentNetworks.count == 1 - def test_list_environment_role_filter_as_creator(org_fixture, env_fixture, client, group): response = client.query( @@ -656,23 +636,16 @@ def test_create_environment(db, client, org_fixture, env_fixture, user, group): 'tags': ['a', 'b', 'c'], 'region': f'{env_fixture.region}', 'SamlGroupName': group.name, - 'vpcId': 'vpc-1234567', - 'privateSubnetIds': 'subnet-1', - 'publicSubnetIds': 'subnet-21', 'resourcePrefix': 'customer-prefix', }, ) body = response.data.createEnvironment - assert body.networks + assert len(body.networks) == 0 assert body.EnvironmentDefaultIAMRoleName == 'myOwnIamRole' assert body.EnvironmentDefaultIAMRoleImported assert body.resourcePrefix == 'customer-prefix' - for vpc in body.networks: - assert vpc.privateSubnetIds - assert vpc.publicSubnetIds - assert vpc.default with db.scoped_session() as session: env = EnvironmentService.get_environment_by_uri( diff --git a/tests/core/vpc/test_vpc.py b/tests/core/vpc/test_vpc.py index a55196d32..8f2391220 100644 --- a/tests/core/vpc/test_vpc.py +++ b/tests/core/vpc/test_vpc.py @@ -60,7 +60,7 @@ def test_list_networks(client, env_fixture, db, user, group, vpc): ) print(response) - assert response.data.listEnvironmentNetworks.count == 2 + assert response.data.listEnvironmentNetworks.count == 1 def test_list_networks_nopermissions(client, env_fixture, db, user, group2, vpc): @@ -119,4 +119,4 @@ def test_delete_network(client, env_fixture, db, user, group, module_mocker, vpc username='alice', groups=[group.name], ) - assert len(response.data.listEnvironmentNetworks['nodes']) == 1 + assert len(response.data.listEnvironmentNetworks['nodes']) == 0 diff --git a/tests/modules/mlstudio/cdk/conftest.py b/tests/modules/mlstudio/cdk/conftest.py index 4b3327838..718f2023f 100644 --- a/tests/modules/mlstudio/cdk/conftest.py +++ b/tests/modules/mlstudio/cdk/conftest.py @@ -2,7 +2,7 @@ from dataall.core.environment.db.environment_models import Environment from dataall.core.organizations.db.organization_models import Organization -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain @pytest.fixture(scope='module', autouse=True) @@ -23,3 +23,21 @@ def sgm_studio(db, env_fixture: Environment) -> SagemakerStudioUser: ) session.add(sm_user) yield sm_user + +@pytest.fixture(scope='module', autouse=True) +def sgm_studio_domain(db, env_fixture: Environment) -> SagemakerStudioDomain: + with db.scoped_session() as session: + sm_domain = SagemakerStudioDomain( + label='sagemaker-domain', + owner='me', + environmentUri=env_fixture.environmentUri, + AWSAccountId=env_fixture.AwsAccountId, + region=env_fixture.region, + SagemakerStudioStatus="PENDING", + DefaultDomainRoleName="DefaultMLStudioRole", + sagemakerStudioDomainName="DomainName", + vpcType="created" + ) + session.add(sm_domain) + yield sm_domain + \ No newline at end of file diff --git a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py index a2c1752e2..4ea1f34fa 100644 --- a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py +++ b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py @@ -66,16 +66,19 @@ def patch_methods_sagemaker_studio(mocker, db, sgm_studio, env_fixture, org_fixt @pytest.fixture(scope='function', autouse=True) -def patch_methods_sagemaker_studio_extension(mocker): +def patch_methods_sagemaker_studio_extension(mocker, sgm_studio_domain): mocker.patch( 'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn', return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1", ) mocker.patch( - 'dataall.modules.mlstudio.aws.ec2_client.EC2.check_default_vpc_exists', + 'dataall.core.environment.aws.ec2_client.EC2.check_default_vpc_exists', return_value=False, ) - + mocker.patch( + 'dataall.modules.mlstudio.db.mlstudio_repositories.SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri', + return_value=sgm_studio_domain, + ) def test_resources_sgmstudio_stack_created(sgm_studio): app = App() diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py index 433048894..8f105be36 100644 --- a/tests/modules/mlstudio/conftest.py +++ b/tests/modules/mlstudio/conftest.py @@ -1,6 +1,6 @@ import pytest -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain @pytest.fixture(scope='module', autouse=True) @@ -16,6 +16,29 @@ def env_params(): yield {'mlStudiosEnabled': 'True'} +@pytest.fixture(scope='module', autouse=True) +def get_cdk_look_up_role_arn(module_mocker): + module_mocker.patch( + 'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn', + return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1", + ) + +@pytest.fixture(scope='module', autouse=True) +def check_default_vpc(module_mocker): + module_mocker.patch( + 'dataall.core.environment.aws.ec2_client.EC2.check_default_vpc_exists', + return_value=False, + ) + + +@pytest.fixture(scope='module', autouse=True) +def check_vpc_exists(module_mocker): + module_mocker.patch( + 'dataall.core.environment.aws.ec2_client.EC2.check_vpc_exists', + return_value=True, + ) + + @pytest.fixture(scope='module') def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudioUser: response = client.query( @@ -79,3 +102,56 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): response.data.createSagemakerStudioUser.environmentUri == env_fixture.environmentUri ) + + +@pytest.fixture(scope='module') +def sagemaker_studio_domain(client, group, env_fixture) -> SagemakerStudioDomain: + response = client.query( + """ + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + sagemakerStudioUri + environmentUri + label + vpcType + vpcId + subnetIds + } + } + """, + input={ + 'label': 'testcreate', + 'environmentUri': env_fixture.environmentUri, + }, + username='alice', + groups=[group.name], + ) + yield response.data.createMLStudioDomain + + +@pytest.fixture(scope='module') +def sagemaker_studio_domain_with_vpc(client, group, env_fixture) -> SagemakerStudioDomain: + response = client.query( + """ + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + sagemakerStudioUri + environmentUri + label + vpcType + vpcId + subnetIds + } + } + """, + input={ + 'label': 'testcreate', + 'environmentUri': env_fixture.environmentUri, + 'vpcId': 'vpc-12345', + 'subnetIds': ['subnet-12345', 'subnet-67890'] + }, + username='alice', + groups=[group.name], + ) + + yield response.data.createMLStudioDomain \ No newline at end of file diff --git a/tests/modules/mlstudio/test_sagemaker_studio.py b/tests/modules/mlstudio/test_sagemaker_studio.py index c55762522..bbdf80796 100644 --- a/tests/modules/mlstudio/test_sagemaker_studio.py +++ b/tests/modules/mlstudio/test_sagemaker_studio.py @@ -67,3 +67,126 @@ def test_delete_sagemaker_studio_user( sagemaker_studio_user.sagemakerStudioUserUri ) assert not n + +# @pytest.fixture +# def mock_check_domain_vpc(mocker): +# return mocker.patch( +# 'dataall.modules.mlstudio.services.mlstudio_service.SagemakerStudioService.check_mlstudio_domain_vpc', +# return_value=False, +# ) + + +def test_create_sagemaker_studio_domain(sagemaker_studio_domain, env_fixture): + """Testing that the conftest sagemaker studio domain has been created correctly""" + assert sagemaker_studio_domain.label == 'testcreate-domain' + assert sagemaker_studio_domain.vpcType == 'created' + assert sagemaker_studio_domain.vpcId is None + assert len(sagemaker_studio_domain.subnetIds) == 0 + assert sagemaker_studio_domain.environmentUri == env_fixture.environmentUri + + +def test_create_sagemaker_studio_domain_unauthorized(client, env_fixture, group2): + response = client.query( + """ + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + sagemakerStudioUri + environmentUri + label + vpcType + vpcId + subnetIds + } + } + """, + input={ + 'label': 'testcreate', + 'environmentUri': env_fixture.environmentUri, + }, + username='anonymoususer', + groups=[group2.name], + ) + assert 'Unauthorized' in response.errors[0].message + + +def test_get_sagemaker_studio_domain(client, env_fixture, sagemaker_studio_domain): + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_fixture.environmentUri, + ) + print(response.data) + assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri == sagemaker_studio_domain.sagemakerStudioUri + + +def test_delete_sagemaker_studio_domain(client, env_fixture, group): + response = client.query( + """ + mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { + deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) + } + """, + environmentUri=env_fixture.environmentUri, + username='alice', + groups=[group.name], + ) + assert response.data.deleteEnvironmentMLStudioDomain + + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_fixture.environmentUri + ) + assert response.data.getEnvironmentMLStudioDomain is None + + +def test_create_sagemaker_studio_domain_with_vpc(sagemaker_studio_domain_with_vpc, env_fixture): + """Testing that the conftest sagemaker studio domain has been created correctly""" + assert sagemaker_studio_domain_with_vpc.label == 'testcreate-domain' + assert sagemaker_studio_domain_with_vpc.vpcType == 'imported' + assert sagemaker_studio_domain_with_vpc.vpcId == 'vpc-12345' + assert sagemaker_studio_domain_with_vpc.subnetIds == ['subnet-12345', 'subnet-67890'] + assert sagemaker_studio_domain_with_vpc.environmentUri == env_fixture.environmentUri + + +def test_delete_sagemaker_studio_domain_unauthorized(client, env_fixture, group2): + response = client.query( + """ + mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { + deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) + } + """, + environmentUri=env_fixture.environmentUri, + username='anonymoususer', + groups=[group2.name], + ) + + assert 'Unauthorized' in response.errors[0].message From 4eb4858a368926ecd27b9b454f55965fe883e152 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Fri, 1 Dec 2023 18:24:13 -0500 Subject: [PATCH 15/38] Clean Up --- .../dataall/base/cdkproxy/cdk.context.json | 96 ------------------- 1 file changed, 96 deletions(-) delete mode 100644 backend/dataall/base/cdkproxy/cdk.context.json diff --git a/backend/dataall/base/cdkproxy/cdk.context.json b/backend/dataall/base/cdkproxy/cdk.context.json deleted file mode 100644 index 6a02b4fe1..000000000 --- a/backend/dataall/base/cdkproxy/cdk.context.json +++ /dev/null @@ -1,96 +0,0 @@ -{ - "vpc-provider:account=139956106467:filter.isDefault=true:region=us-east-1:returnAsymmetricSubnets=true": { - "vpcId": "vpc-47a2473a", - "vpcCidrBlock": "172.31.0.0/16", - "ownerAccountId": "139956106467", - "availabilityZones": [], - "subnetGroups": [ - { - "name": "Public", - "type": "Public", - "subnets": [ - { - "subnetId": "subnet-ce854ca8", - "cidr": "172.31.0.0/20", - "availabilityZone": "us-east-1a", - "routeTableId": "rtb-eb234395" - }, - { - "subnetId": "subnet-dd2df9fc", - "cidr": "172.31.80.0/20", - "availabilityZone": "us-east-1b", - "routeTableId": "rtb-eb234395" - }, - { - "subnetId": "subnet-e357ceae", - "cidr": "172.31.16.0/20", - "availabilityZone": "us-east-1c", - "routeTableId": "rtb-eb234395" - }, - { - "subnetId": "subnet-9af53fc5", - "cidr": "172.31.32.0/20", - "availabilityZone": "us-east-1d", - "routeTableId": "rtb-eb234395" - }, - { - "subnetId": "subnet-95968bab", - "cidr": "172.31.48.0/20", - "availabilityZone": "us-east-1e", - "routeTableId": "rtb-eb234395" - }, - { - "subnetId": "subnet-6ba22165", - "cidr": "172.31.64.0/20", - "availabilityZone": "us-east-1f", - "routeTableId": "rtb-eb234395" - } - ] - } - ] - }, - "vpc-provider:account=139956106467:filter.vpc-id=vpc-09ddf78440e5c6d5d:region=us-east-1:returnAsymmetricSubnets=true": { - "vpcId": "vpc-09ddf78440e5c6d5d", - "vpcCidrBlock": "10.0.0.0/24", - "ownerAccountId": "139956106467", - "availabilityZones": [], - "subnetGroups": [ - { - "name": "Private", - "type": "Private", - "subnets": [ - { - "subnetId": "subnet-0f2c957fec49cc5b6", - "cidr": "10.0.0.128/28", - "availabilityZone": "us-east-1a", - "routeTableId": "rtb-0f5763c9bce96a6c3" - }, - { - "subnetId": "subnet-06d0ac5a5cdc3e842", - "cidr": "10.0.0.144/28", - "availabilityZone": "us-east-1b", - "routeTableId": "rtb-0e668f02b8963de94" - } - ] - }, - { - "name": "Public", - "type": "Public", - "subnets": [ - { - "subnetId": "subnet-0b83c25e072255092", - "cidr": "10.0.0.0/28", - "availabilityZone": "us-east-1a", - "routeTableId": "rtb-084b8fa8b6c24a230" - }, - { - "subnetId": "subnet-002ff94f2876021d5", - "cidr": "10.0.0.16/28", - "availabilityZone": "us-east-1b", - "routeTableId": "rtb-084b8fa8b6c24a230" - } - ] - } - ] - } -} From 2acfdfe03ab44f2c3985b7fac497bd61f8e7d8a9 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Mon, 4 Dec 2023 15:48:34 -0500 Subject: [PATCH 16/38] Clean up get studio domain and update migration script --- .../mlstudio/aws/sagemaker_studio_client.py | 20 ++-- .../mlstudio/cdk/mlstudio_extension.py | 3 - .../mlstudio/services/mlstudio_service.py | 12 +- ...f5de322f_update_sagemaker_studio_domain.py | 113 +++++++++--------- .../components/EnvironmentMLStudio.js | 14 ++- 5 files changed, 79 insertions(+), 83 deletions(-) diff --git a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py index 2a82806ea..2ee872b1c 100644 --- a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py +++ b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py @@ -12,28 +12,22 @@ def get_client(AwsAccountId, region): return session.client('sagemaker', region_name=region) -def get_sagemaker_studio_domain(AwsAccountId, region): +def get_sagemaker_studio_domain(AwsAccountId, region, domain_name): """ Sagemaker studio domain is limited to 5 per account/region RETURN: an existing domain or None if no domain is in the AWS account """ client = get_client(AwsAccountId=AwsAccountId, region=region) - existing_domain = dict() try: domain_id_paginator = client.get_paginator('list_domains') - domains = domain_id_paginator.paginate() - for _domain in domains: - print(_domain) - for _domain in _domain.get('Domains'): - # Get the domain name created by dataall - if 'dataall' in _domain: - return _domain - else: - existing_domain = _domain - return existing_domain + for page in domain_id_paginator.paginate(): + for domain in page.get('Domains', []): + if domain.get("DomainName") == domain_name: + return domain + return dict() except ClientError as e: print(e) - return 'NotFound' + return dict() class SagemakerStudioClient: diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 37131390a..9686b61cc 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -12,15 +12,12 @@ aws_ssm as ssm, RemovalPolicy, ) -from botocore.exceptions import ClientError from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository -from dataall.base.aws.parameter_store import ParameterStoreManager from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService from dataall.core.environment.aws.ec2_client import EC2 -from dataall.modules.mlstudio.aws.sagemaker_studio_client import get_sagemaker_studio_domain logger = logging.getLogger(__name__) diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index ab3794201..8e8fa7e01 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -81,16 +81,13 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak action=CREATE_SGMSTUDIO_USER, message=f'ML Studio feature is disabled for the environment {env.label}', ) - # FOR OLD ONES + + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=env.environmentUri) response = get_sagemaker_studio_domain( AwsAccountId=env.AwsAccountId, - region=env.region + region=env.region, + domain_name=domain.sagemakerStudioDomainName ) - - # FOR NEW ONES (default, created, imported) - # - CHECK RDS FIRST - ONLY GET DOMAIN NAME - # - IF NOT BOTO3 - existing_domain = response.get('DomainId', False) if not existing_domain: @@ -204,7 +201,6 @@ def check_mlstudio_domain_vpc(account_id: str, region: str, cdk_look_up_role_arn def delete_environment_sagemaker_studio_domain(*, uri: str): with _session() as session: domain = SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=uri) - # TODO: CHECK NUMBER OF USERS BEFORE DELETE session.delete(domain) return True diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 9c31cc7a2..9e05ab2aa 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -97,34 +97,61 @@ def upgrade(): "sagemaker_studio_domain", "environment", ["environmentUri"], ["environmentUri"], ) + else: + print("No sagemaker_studio_domain table found, creating...") + op.create_table( + 'sagemaker_studio_domain', + sa.Column('label', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('owner', sa.String(), nullable=False), + sa.Column('created', sa.DateTime(), nullable=True), + sa.Column('updated', sa.DateTime(), nullable=True), + sa.Column('deleted', sa.DateTime(), nullable=True), + sa.Column('description', sa.String(), nullable=True), + sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True), + sa.Column('environmentUri', sa.String(), nullable=False), + sa.Column('sagemakerStudioUri', sa.String(), nullable=False), + sa.Column('sagemakerStudioDomainID', sa.String(), nullable=True), + sa.Column('SagemakerStudioStatus', sa.String(), nullable=True), + sa.Column('AWSAccountId', sa.String(), nullable=False), + sa.Column('DefaultDomainRoleName', sa.String(), nullable=False), + sa.Column('sagemakerStudioDomainName', sa.String(), nullable=False), + sa.Column('vpcType', sa.String(), nullable=True), + sa.Column('vpcId', sa.String(), nullable=True), + sa.Column('subnetIds', postgresql.ARRAY(sa.String()), nullable=True), + sa.Column('region', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('sagemakerStudioUri'), + sa.ForeignKeyConstraint(columns=['environmentUri'], refcolumns=['environment.environmentUri']), + ) + + print("Update sagemaker_studio_domain table done.") + print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") + + env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( + and_( + EnvironmentParameter.key == "mlStudiosEnabled", + EnvironmentParameter.value == "true" + ) + ).all() + for param in env_mlstudio_parameters: + env: Environment = session.query(Environment).filter( + Environment.environmentUri == param.environmentUri + ).first() + + domain = SagemakerStudioDomain( + label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + owner=env.owner, + description='No description provided', + environmentUri=env.environmentUri, + AWSAccountId=env.AwsAccountId, + region=env.region, + DefaultDomainRoleName="RoleSagemakerStudioUsers", + sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + vpcType="unknown" + ) + session.add(domain) + session.flush() - print("Update sagemaker_studio_domain table done.") - print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") - - env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( - and_( - EnvironmentParameter.key == "mlStudiosEnabled", - EnvironmentParameter.value == "true" - ) - ).all() - for param in env_mlstudio_parameters: - env: Environment = session.query(Environment).filter( - Environment.environmentUri == param.environmentUri - ).first() - - domain = SagemakerStudioDomain( - label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - owner=env.owner, - description='No description provided', - environmentUri=env.environmentUri, - AWSAccountId=env.AwsAccountId, - region=env.region, - DefaultDomainRoleName="RoleSagemakerStudioUsers", - sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - vpcType="unknown" - ) - session.add(domain) - session.flush() session.commit() print("Fill of sagemaker_studio_domain table is done") @@ -142,36 +169,10 @@ def downgrade(): session = orm.Session(bind=bind) if has_table('sagemaker_studio_domain', engine): - print("Updating of sagemaker_studio_domain table...") - op.alter_column( - 'sagemaker_studio_domain', - 'sagemakerStudioDomainID', - nullable=False, - existing_type=sa.String() - ) - op.alter_column( - 'sagemaker_studio_domain', - 'SagemakerStudioStatus', - nullable=False, - existing_type=sa.String() - ) - op.alter_column( - 'sagemaker_studio_domain', - 'DefaultDomainRoleName', - new_column_name='RoleArn', - nullable=False, - existing_type=sa.String() - ) - - op.drop_column("sagemaker_studio_domain", "sagemakerStudioDomainName") - op.drop_column("sagemaker_studio_domain", "vpcType") - op.drop_column("sagemaker_studio_domain", "vpcId") - op.drop_column("sagemaker_studio_domain", "subnetIds") - - op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain") - + print("Dropping sagemaker_studio_domain table...") + op.drop_table("sagemaker_studio_domain") session.commit() - print("Update of sagemaker_studio_domain table is done") + print("Dropped of sagemaker_studio_domain table") except Exception as exception: print('Failed to downgrade due to:', exception) diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 158adf3d9..270bc414e 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -6,7 +6,8 @@ import { Grid, CardContent, Typography, - CircularProgress + CircularProgress, + Chip } from '@mui/material'; import PropTypes from 'prop-types'; @@ -106,7 +107,7 @@ export const EnvironmentMLStudio = ({ environment }) => { SageMaker ML Studio Default Execution Role - arn:aws:s3::: + arn:aws:iam::{environment.AwsAccountId}:role/ {mlStudioDomain.DefaultDomainRoleName} @@ -133,7 +134,14 @@ export const EnvironmentMLStudio = ({ environment }) => { Domain Subnet Ids - {mlStudioDomain.subnetIds} + {mlStudioDomain.subnetIds?.map((subnet) => ( + + ))} From 4215ff43a0606b8997dbc5e74697f31e62a3fcc1 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Mon, 4 Dec 2023 15:52:17 -0500 Subject: [PATCH 17/38] Edit text when ML Studio disabled --- .../modules/Environments/components/EnvironmentMLStudio.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 270bc414e..7f367ecb2 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -84,8 +84,8 @@ export const EnvironmentMLStudio = ({ environment }) => { variant="subtitle2" > No ML Studio Domain - To Create a ML Studio Domain for this - Environment: `{environment.label}`, edit the Environment and - enable the ML Studio Environment Feature + Environment: {environment.label}, edit the Environment and enable + the ML Studio Environment Feature ) : ( From 95db11b02d19c957b532549a1dd29fadbaf95d6c Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Mon, 4 Dec 2023 16:44:54 -0500 Subject: [PATCH 18/38] Fix coverage tests --- tests/modules/mlstudio/conftest.py | 108 +++++++++--------- .../modules/mlstudio/test_sagemaker_studio.py | 73 ++++++------ 2 files changed, 87 insertions(+), 94 deletions(-) diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py index 8f105be36..f37dbf886 100644 --- a/tests/modules/mlstudio/conftest.py +++ b/tests/modules/mlstudio/conftest.py @@ -40,7 +40,60 @@ def check_vpc_exists(module_mocker): @pytest.fixture(scope='module') -def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudioUser: +def sagemaker_studio_domain(client, group, env_fixture) -> SagemakerStudioDomain: + response = client.query( + """ + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + sagemakerStudioUri + environmentUri + label + vpcType + vpcId + subnetIds + } + } + """, + input={ + 'label': 'testcreate', + 'environmentUri': env_fixture.environmentUri, + }, + username='alice', + groups=[group.name], + ) + yield response.data.createMLStudioDomain + + +@pytest.fixture(scope='module') +def sagemaker_studio_domain_with_vpc(client, group, env_fixture) -> SagemakerStudioDomain: + response = client.query( + """ + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + sagemakerStudioUri + environmentUri + label + vpcType + vpcId + subnetIds + } + } + """, + input={ + 'label': 'testcreate', + 'environmentUri': env_fixture.environmentUri, + 'vpcId': 'vpc-12345', + 'subnetIds': ['subnet-12345', 'subnet-67890'] + }, + username='alice', + groups=[group.name], + ) + + yield response.data.createMLStudioDomain + + +@pytest.fixture(scope='module') +def sagemaker_studio_user(client, tenant, group, env_fixture, sagemaker_studio_domain) -> SagemakerStudioUser: response = client.query( """ mutation createSagemakerStudioUser($input:NewSagemakerStudioUserInput){ @@ -102,56 +155,3 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): response.data.createSagemakerStudioUser.environmentUri == env_fixture.environmentUri ) - - -@pytest.fixture(scope='module') -def sagemaker_studio_domain(client, group, env_fixture) -> SagemakerStudioDomain: - response = client.query( - """ - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - """, - input={ - 'label': 'testcreate', - 'environmentUri': env_fixture.environmentUri, - }, - username='alice', - groups=[group.name], - ) - yield response.data.createMLStudioDomain - - -@pytest.fixture(scope='module') -def sagemaker_studio_domain_with_vpc(client, group, env_fixture) -> SagemakerStudioDomain: - response = client.query( - """ - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - """, - input={ - 'label': 'testcreate', - 'environmentUri': env_fixture.environmentUri, - 'vpcId': 'vpc-12345', - 'subnetIds': ['subnet-12345', 'subnet-67890'] - }, - username='alice', - groups=[group.name], - ) - - yield response.data.createMLStudioDomain \ No newline at end of file diff --git a/tests/modules/mlstudio/test_sagemaker_studio.py b/tests/modules/mlstudio/test_sagemaker_studio.py index bbdf80796..87ef441b3 100644 --- a/tests/modules/mlstudio/test_sagemaker_studio.py +++ b/tests/modules/mlstudio/test_sagemaker_studio.py @@ -1,6 +1,39 @@ from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +def test_create_sagemaker_studio_domain(sagemaker_studio_domain, env_fixture): + """Testing that the conftest sagemaker studio domain has been created correctly""" + assert sagemaker_studio_domain.label == 'testcreate-domain' + assert sagemaker_studio_domain.vpcType == 'created' + assert sagemaker_studio_domain.vpcId is None + assert len(sagemaker_studio_domain.subnetIds) == 0 + assert sagemaker_studio_domain.environmentUri == env_fixture.environmentUri + + +def test_create_sagemaker_studio_domain_unauthorized(client, env_fixture, group2): + response = client.query( + """ + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + sagemakerStudioUri + environmentUri + label + vpcType + vpcId + subnetIds + } + } + """, + input={ + 'label': 'testcreate', + 'environmentUri': env_fixture.environmentUri, + }, + username='anonymoususer', + groups=[group2.name], + ) + assert 'Unauthorized' in response.errors[0].message + + def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_fixture): """Testing that the conftest sagemaker studio user has been created correctly""" assert sagemaker_studio_user.label == 'testcreate' @@ -68,46 +101,6 @@ def test_delete_sagemaker_studio_user( ) assert not n -# @pytest.fixture -# def mock_check_domain_vpc(mocker): -# return mocker.patch( -# 'dataall.modules.mlstudio.services.mlstudio_service.SagemakerStudioService.check_mlstudio_domain_vpc', -# return_value=False, -# ) - - -def test_create_sagemaker_studio_domain(sagemaker_studio_domain, env_fixture): - """Testing that the conftest sagemaker studio domain has been created correctly""" - assert sagemaker_studio_domain.label == 'testcreate-domain' - assert sagemaker_studio_domain.vpcType == 'created' - assert sagemaker_studio_domain.vpcId is None - assert len(sagemaker_studio_domain.subnetIds) == 0 - assert sagemaker_studio_domain.environmentUri == env_fixture.environmentUri - - -def test_create_sagemaker_studio_domain_unauthorized(client, env_fixture, group2): - response = client.query( - """ - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - """, - input={ - 'label': 'testcreate', - 'environmentUri': env_fixture.environmentUri, - }, - username='anonymoususer', - groups=[group2.name], - ) - assert 'Unauthorized' in response.errors[0].message - def test_get_sagemaker_studio_domain(client, env_fixture, sagemaker_studio_domain): response = client.query( From 2dc0e1a11c011cb1ed925231784427bce2038628 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Mon, 4 Dec 2023 17:10:45 -0500 Subject: [PATCH 19/38] Revert migration script --- ...f5de322f_update_sagemaker_studio_domain.py | 113 +++++++++--------- 1 file changed, 56 insertions(+), 57 deletions(-) diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 9e05ab2aa..9c31cc7a2 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -97,61 +97,34 @@ def upgrade(): "sagemaker_studio_domain", "environment", ["environmentUri"], ["environmentUri"], ) - else: - print("No sagemaker_studio_domain table found, creating...") - op.create_table( - 'sagemaker_studio_domain', - sa.Column('label', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('owner', sa.String(), nullable=False), - sa.Column('created', sa.DateTime(), nullable=True), - sa.Column('updated', sa.DateTime(), nullable=True), - sa.Column('deleted', sa.DateTime(), nullable=True), - sa.Column('description', sa.String(), nullable=True), - sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True), - sa.Column('environmentUri', sa.String(), nullable=False), - sa.Column('sagemakerStudioUri', sa.String(), nullable=False), - sa.Column('sagemakerStudioDomainID', sa.String(), nullable=True), - sa.Column('SagemakerStudioStatus', sa.String(), nullable=True), - sa.Column('AWSAccountId', sa.String(), nullable=False), - sa.Column('DefaultDomainRoleName', sa.String(), nullable=False), - sa.Column('sagemakerStudioDomainName', sa.String(), nullable=False), - sa.Column('vpcType', sa.String(), nullable=True), - sa.Column('vpcId', sa.String(), nullable=True), - sa.Column('subnetIds', postgresql.ARRAY(sa.String()), nullable=True), - sa.Column('region', sa.String(), nullable=True), - sa.PrimaryKeyConstraint('sagemakerStudioUri'), - sa.ForeignKeyConstraint(columns=['environmentUri'], refcolumns=['environment.environmentUri']), - ) - - print("Update sagemaker_studio_domain table done.") - print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") - - env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( - and_( - EnvironmentParameter.key == "mlStudiosEnabled", - EnvironmentParameter.value == "true" - ) - ).all() - for param in env_mlstudio_parameters: - env: Environment = session.query(Environment).filter( - Environment.environmentUri == param.environmentUri - ).first() - - domain = SagemakerStudioDomain( - label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - owner=env.owner, - description='No description provided', - environmentUri=env.environmentUri, - AWSAccountId=env.AwsAccountId, - region=env.region, - DefaultDomainRoleName="RoleSagemakerStudioUsers", - sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - vpcType="unknown" - ) - session.add(domain) - session.flush() + print("Update sagemaker_studio_domain table done.") + print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") + + env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( + and_( + EnvironmentParameter.key == "mlStudiosEnabled", + EnvironmentParameter.value == "true" + ) + ).all() + for param in env_mlstudio_parameters: + env: Environment = session.query(Environment).filter( + Environment.environmentUri == param.environmentUri + ).first() + + domain = SagemakerStudioDomain( + label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + owner=env.owner, + description='No description provided', + environmentUri=env.environmentUri, + AWSAccountId=env.AwsAccountId, + region=env.region, + DefaultDomainRoleName="RoleSagemakerStudioUsers", + sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + vpcType="unknown" + ) + session.add(domain) + session.flush() session.commit() print("Fill of sagemaker_studio_domain table is done") @@ -169,10 +142,36 @@ def downgrade(): session = orm.Session(bind=bind) if has_table('sagemaker_studio_domain', engine): - print("Dropping sagemaker_studio_domain table...") - op.drop_table("sagemaker_studio_domain") + print("Updating of sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'sagemakerStudioDomainID', + nullable=False, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'SagemakerStudioStatus', + nullable=False, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'DefaultDomainRoleName', + new_column_name='RoleArn', + nullable=False, + existing_type=sa.String() + ) + + op.drop_column("sagemaker_studio_domain", "sagemakerStudioDomainName") + op.drop_column("sagemaker_studio_domain", "vpcType") + op.drop_column("sagemaker_studio_domain", "vpcId") + op.drop_column("sagemaker_studio_domain", "subnetIds") + + op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain") + session.commit() - print("Dropped of sagemaker_studio_domain table") + print("Update of sagemaker_studio_domain table is done") except Exception as exception: print('Failed to downgrade due to:', exception) From 5cb1d8b0158c6df89a94143ec56ea85e51163400 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Tue, 5 Dec 2023 09:58:46 -0500 Subject: [PATCH 20/38] Handle null values Edit form --- .../src/modules/Environments/views/EnvironmentEditForm.js | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index 3d38b2196..5eca5e936 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -89,11 +89,13 @@ const EnvironmentEditForm = (props) => { } setLoading(false); }, [client, dispatch, params.uri]); + useEffect(() => { if (client) { fetchItem().catch((e) => dispatch({ type: SET_ERROR, error: e.message })); } }, [client, fetchItem, dispatch]); + async function submit(values, setStatus, setSubmitting, setErrors) { try { const response = await client.mutate( @@ -148,7 +150,6 @@ const EnvironmentEditForm = (props) => { values.mlStudiosEnabled !== previousEnvMLStudioEnabled && values.mlStudiosEnabled === false ) { - console.error(envMLStudioDomain.sagemakerStudioUri); const response2 = await client.mutate( deleteEnvironmentMLStudioDomain({ environmentUri: envMLStudioDomain.environmentUri @@ -268,8 +269,8 @@ const EnvironmentEditForm = (props) => { label: env.label, description: env.description, tags: env.tags || [], - mlStudioVPCId: envMLStudioDomain.vpcId, - mlStudioSubnetIds: envMLStudioDomain.subnetIds, + mlStudioVPCId: envMLStudioDomain.vpcId || '', + mlStudioSubnetIds: envMLStudioDomain.subnetIds || [], notebooksEnabled: env.parameters['notebooksEnabled'] === 'true', mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true', pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true', From 32bc04d56f375a4af13d2e0b4fdbba3d206dfbe3 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 10:32:25 -0500 Subject: [PATCH 21/38] Move EC2 to base, clean up unused code, move APIs to shared, add delete Domain on delete env --- .../dataall/core/environment/api/resolvers.py | 2 +- .../dataall/core/environment/aws/__init__.py | 0 .../core/environment/aws/ec2_client.py | 54 ------------------- .../mlstudio/cdk/mlstudio_extension.py | 2 +- .../mlstudio/db/mlstudio_repositories.py | 19 ------- .../mlstudio/services/mlstudio_service.py | 8 +-- .../components/EnvironmentMLStudio.js | 3 +- .../services/createMLStudioDomain.js | 19 ------- .../deleteEnvironmentMLStudioDomain.js | 12 ----- .../services/getEnvironmentMLStudioDomain.js | 23 -------- .../modules/Environments/services/index.js | 3 -- .../views/EnvironmentCreateForm.js | 2 +- .../Environments/views/EnvironmentEditForm.js | 9 ++-- .../Environments/views/EnvironmentView.js | 12 ++++- frontend/src/services/graphql/index.js | 1 + 15 files changed, 24 insertions(+), 145 deletions(-) delete mode 100644 backend/dataall/core/environment/aws/__init__.py delete mode 100644 backend/dataall/core/environment/aws/ec2_client.py delete mode 100644 frontend/src/modules/Environments/services/createMLStudioDomain.js delete mode 100644 frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js delete mode 100644 frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index 38bbbd32a..c050d228a 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -20,7 +20,7 @@ from dataall.core.stacks.aws.cloudformation import CloudFormation from dataall.core.stacks.db.stack_repositories import Stack from dataall.core.vpc.db.vpc_repositories import Vpc -from dataall.core.environment.aws.ec2_client import EC2 +from dataall.backend.dataall.base.aws.ec2_client import EC2 from dataall.base.db import exceptions from dataall.core.permissions import permissions from dataall.base.feature_toggle_checker import is_feature_enabled diff --git a/backend/dataall/core/environment/aws/__init__.py b/backend/dataall/core/environment/aws/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/dataall/core/environment/aws/ec2_client.py b/backend/dataall/core/environment/aws/ec2_client.py deleted file mode 100644 index 23d290a7e..000000000 --- a/backend/dataall/core/environment/aws/ec2_client.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging - -from dataall.base.aws.sts import SessionHelper -from botocore.exceptions import ClientError - -log = logging.getLogger(__name__) - - -class EC2: - - @staticmethod - def get_client(account_id: str, region: str, role=None): - session = SessionHelper.remote_session(accountid=account_id, role=role) - return session.client('ec2', region_name=region) - - @staticmethod - def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): - log.info("Check that default VPC exists..") - client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) - response = client.describe_vpcs( - Filters=[{'Name': 'isDefault', 'Values': ['true']}] - ) - vpcs = response['Vpcs'] - log.info(f"Default VPCs response: {vpcs}") - if vpcs: - return True - return False - - @staticmethod - def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): - try: - ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role) - response = ec2.describe_vpcs(VpcIds=[vpc_id]) - except ClientError as e: - log.exception(f'VPC Id {vpc_id} Not Found: {e}') - raise Exception(f'VPCNotFound: {vpc_id}') - - try: - if subnet_ids: - response = ec2.describe_subnets( - Filters=[ - { - 'Name': 'vpc-id', - 'Values': [vpc_id] - }, - ], - SubnetIds=subnet_ids - ) - except ClientError as e: - log.exception(f'Subnet Id {subnet_ids} Not Found: {e}') - raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') - - if not subnet_ids or len(response['Subnets']) != len(subnet_ids): - raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 9686b61cc..d99ee47b3 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -17,7 +17,7 @@ from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.core.environment.aws.ec2_client import EC2 +from dataall.backend.dataall.base.aws.ec2_client import EC2 logger = logging.getLogger(__name__) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 49d44c2c9..875fd95ce 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -2,7 +2,6 @@ DAO layer that encapsulates the logic and interaction with the database for ML Studio Provides the API to retrieve / update / delete ml studio """ -import stat from typing import Optional from sqlalchemy import or_ from sqlalchemy.sql import and_ @@ -121,24 +120,6 @@ def create_sagemaker_studio_domain(session, username, environment, data): return domain - @staticmethod - def _query_environment_sagemaker_studio_domains(session, uri, filter) -> Query: - query = session.query(SagemakerStudioDomain).filter( - SagemakerStudioDomain.environmentUri == uri, - ) - if filter and filter.get('term'): - query = query.filter( - or_( - SagemakerStudioDomain.description.ilike( - filter.get('term') + '%%' - ), - SagemakerStudioDomain.label.ilike( - filter.get('term') + '%%' - ), - ) - ) - return query - @staticmethod def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 8e8fa7e01..b75788ae6 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -19,7 +19,7 @@ from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -from dataall.core.environment.aws.ec2_client import EC2 +from dataall.backend.dataall.base.aws.ec2_client import EC2 from dataall.base.aws.sts import SessionHelper from dataall.modules.mlstudio.services.mlstudio_permissions import ( @@ -93,8 +93,7 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak if not existing_domain: raise exceptions.AWSResourceNotAvailable( action='Sagemaker Studio domain', - message='Update the environment stack ' - 'or create a Sagemaker studio domain on your AWS account.', + message='Update the environment stack and enable ML Studio Environment Feature' ) sagemaker_studio_user = SagemakerStudioUser( @@ -201,7 +200,8 @@ def check_mlstudio_domain_vpc(account_id: str, region: str, cdk_look_up_role_arn def delete_environment_sagemaker_studio_domain(*, uri: str): with _session() as session: domain = SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=uri) - session.delete(domain) + if domain: + session.delete(domain) return True @staticmethod diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 7f367ecb2..744a0350c 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -14,8 +14,7 @@ import PropTypes from 'prop-types'; import React, { useCallback, useEffect, useState } from 'react'; import { RefreshTableMenu, ObjectMetadata } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { useClient } from 'services'; -import { getEnvironmentMLStudioDomain } from '../services'; +import { getEnvironmentMLStudioDomain, useClient } from 'services'; export const EnvironmentMLStudio = ({ environment }) => { const client = useClient(); diff --git a/frontend/src/modules/Environments/services/createMLStudioDomain.js b/frontend/src/modules/Environments/services/createMLStudioDomain.js deleted file mode 100644 index 3940e6748..000000000 --- a/frontend/src/modules/Environments/services/createMLStudioDomain.js +++ /dev/null @@ -1,19 +0,0 @@ -import { gql } from 'apollo-boost'; - -export const createMLStudioDomain = (input) => ({ - variables: { - input - }, - mutation: gql` - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - ` -}); diff --git a/frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js b/frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js deleted file mode 100644 index 7abdc7e9e..000000000 --- a/frontend/src/modules/Environments/services/deleteEnvironmentMLStudioDomain.js +++ /dev/null @@ -1,12 +0,0 @@ -import { gql } from 'apollo-boost'; - -export const deleteEnvironmentMLStudioDomain = ({ environmentUri }) => ({ - variables: { - environmentUri - }, - mutation: gql` - mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { - deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) - } - ` -}); diff --git a/frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js b/frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js deleted file mode 100644 index 9dc34d630..000000000 --- a/frontend/src/modules/Environments/services/getEnvironmentMLStudioDomain.js +++ /dev/null @@ -1,23 +0,0 @@ -import { gql } from 'apollo-boost'; - -export const getEnvironmentMLStudioDomain = ({ environmentUri }) => ({ - variables: { - environmentUri - }, - query: gql` - query getEnvironmentMLStudioDomain($environmentUri: String) { - getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { - sagemakerStudioUri - environmentUri - label - sagemakerStudioDomainName - DefaultDomainRoleName - vpcType - vpcId - subnetIds - owner - created - } - } - ` -}); diff --git a/frontend/src/modules/Environments/services/index.js b/frontend/src/modules/Environments/services/index.js index 8955122b0..14f5b659f 100644 --- a/frontend/src/modules/Environments/services/index.js +++ b/frontend/src/modules/Environments/services/index.js @@ -22,6 +22,3 @@ export * from './removeConsumptionRole'; export * from './removeGroup'; export * from './updateEnvironment'; export * from './updateGroupEnvironmentPermissions'; -export * from './createMLStudioDomain'; -export * from './deleteEnvironmentMLStudioDomain'; -export * from './getEnvironmentMLStudioDomain'; diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index 467a2cd42..4954d1076 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -32,7 +32,6 @@ import { Helmet } from 'react-helmet-async'; import { Link as RouterLink, useNavigate, useParams } from 'react-router-dom'; import * as Yup from 'yup'; import { - createMLStudioDomain, createEnvironment, getPivotRoleExternalId, getPivotRoleName, @@ -47,6 +46,7 @@ import { } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; import { + createMLStudioDomain, getOrganization, getTrustAccount, useClient, diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index 5eca5e936..75eec8497 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -30,14 +30,13 @@ import { useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { useClient } from 'services'; import { - getEnvironment, - updateEnvironment, getEnvironmentMLStudioDomain, createMLStudioDomain, - deleteEnvironmentMLStudioDomain -} from '../services'; + deleteEnvironmentMLStudioDomain, + useClient +} from 'services'; +import { getEnvironment, updateEnvironment } from '../services'; import { isAnyEnvironmentModuleEnabled, isModuleEnabled, diff --git a/frontend/src/modules/Environments/views/EnvironmentView.js b/frontend/src/modules/Environments/views/EnvironmentView.js index 792918c13..87095af78 100644 --- a/frontend/src/modules/Environments/views/EnvironmentView.js +++ b/frontend/src/modules/Environments/views/EnvironmentView.js @@ -34,7 +34,7 @@ import { useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { useClient } from 'services'; +import { deleteEnvironmentMLStudioDomain, useClient } from 'services'; import { archiveEnvironment, getEnvironment } from '../services'; import { KeyValueTagList, Stack, StackStatus } from 'modules/Shared'; import { @@ -111,6 +111,16 @@ const EnvironmentView = () => { }) ); if (!response.errors) { + if (isModuleEnabled(ModuleNames.MLSTUDIO)) { + const response2 = await client.mutate( + deleteEnvironmentMLStudioDomain({ + environmentUri: env.environmentUri + }) + ); + if (response2.errors) { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } handleArchiveObjectModalClose(); enqueueSnackbar('Environment deleted', { anchorOrigin: { diff --git a/frontend/src/services/graphql/index.js b/frontend/src/services/graphql/index.js index 8d0e00804..ce1c3fba2 100644 --- a/frontend/src/services/graphql/index.js +++ b/frontend/src/services/graphql/index.js @@ -8,6 +8,7 @@ export * from './Glossary'; export * from './Groups'; export * from './KeyValueTags'; export * from './Metric'; +export * from './MLStudio'; export * from './Notification'; export * from './Organization'; export * from './Principal'; From f9e522d652949e2eaa692b2cea5f65e0b1ed998a Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 10:32:31 -0500 Subject: [PATCH 22/38] Move EC2 to base, clean up unused code, move APIs to shared, add delete Domain on delete env --- .../graphql/MLStudio/createMLStudioDomain.js | 19 +++++++++++++++ .../deleteEnvironmentMLStudioDomain.js | 12 ++++++++++ .../MLStudio/getEnvironmentMLStudioDomain.js | 23 +++++++++++++++++++ .../src/services/graphql/MLStudio/index.js | 3 +++ 4 files changed, 57 insertions(+) create mode 100644 frontend/src/services/graphql/MLStudio/createMLStudioDomain.js create mode 100644 frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js create mode 100644 frontend/src/services/graphql/MLStudio/getEnvironmentMLStudioDomain.js create mode 100644 frontend/src/services/graphql/MLStudio/index.js diff --git a/frontend/src/services/graphql/MLStudio/createMLStudioDomain.js b/frontend/src/services/graphql/MLStudio/createMLStudioDomain.js new file mode 100644 index 000000000..3940e6748 --- /dev/null +++ b/frontend/src/services/graphql/MLStudio/createMLStudioDomain.js @@ -0,0 +1,19 @@ +import { gql } from 'apollo-boost'; + +export const createMLStudioDomain = (input) => ({ + variables: { + input + }, + mutation: gql` + mutation createMLStudioDomain($input: NewStudioDomainInput) { + createMLStudioDomain(input: $input) { + sagemakerStudioUri + environmentUri + label + vpcType + vpcId + subnetIds + } + } + ` +}); diff --git a/frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js b/frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js new file mode 100644 index 000000000..7abdc7e9e --- /dev/null +++ b/frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js @@ -0,0 +1,12 @@ +import { gql } from 'apollo-boost'; + +export const deleteEnvironmentMLStudioDomain = ({ environmentUri }) => ({ + variables: { + environmentUri + }, + mutation: gql` + mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { + deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) + } + ` +}); diff --git a/frontend/src/services/graphql/MLStudio/getEnvironmentMLStudioDomain.js b/frontend/src/services/graphql/MLStudio/getEnvironmentMLStudioDomain.js new file mode 100644 index 000000000..9dc34d630 --- /dev/null +++ b/frontend/src/services/graphql/MLStudio/getEnvironmentMLStudioDomain.js @@ -0,0 +1,23 @@ +import { gql } from 'apollo-boost'; + +export const getEnvironmentMLStudioDomain = ({ environmentUri }) => ({ + variables: { + environmentUri + }, + query: gql` + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + ` +}); diff --git a/frontend/src/services/graphql/MLStudio/index.js b/frontend/src/services/graphql/MLStudio/index.js new file mode 100644 index 000000000..5ae789bc8 --- /dev/null +++ b/frontend/src/services/graphql/MLStudio/index.js @@ -0,0 +1,3 @@ +export * from './createMLStudioDomain'; +export * from './deleteEnvironmentMLStudioDomain'; +export * from './getEnvironmentMLStudioDomain'; From cec2afe0caba3031815025c548cf65a47fb6348d Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 10:33:14 -0500 Subject: [PATCH 23/38] Move EC2 to base, clean up unused code, move APIs to shared, add delete Domain on delete env --- backend/dataall/base/aws/ec2_client.py | 54 ++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 backend/dataall/base/aws/ec2_client.py diff --git a/backend/dataall/base/aws/ec2_client.py b/backend/dataall/base/aws/ec2_client.py new file mode 100644 index 000000000..23d290a7e --- /dev/null +++ b/backend/dataall/base/aws/ec2_client.py @@ -0,0 +1,54 @@ +import logging + +from dataall.base.aws.sts import SessionHelper +from botocore.exceptions import ClientError + +log = logging.getLogger(__name__) + + +class EC2: + + @staticmethod + def get_client(account_id: str, region: str, role=None): + session = SessionHelper.remote_session(accountid=account_id, role=role) + return session.client('ec2', region_name=region) + + @staticmethod + def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): + log.info("Check that default VPC exists..") + client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = client.describe_vpcs( + Filters=[{'Name': 'isDefault', 'Values': ['true']}] + ) + vpcs = response['Vpcs'] + log.info(f"Default VPCs response: {vpcs}") + if vpcs: + return True + return False + + @staticmethod + def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): + try: + ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = ec2.describe_vpcs(VpcIds=[vpc_id]) + except ClientError as e: + log.exception(f'VPC Id {vpc_id} Not Found: {e}') + raise Exception(f'VPCNotFound: {vpc_id}') + + try: + if subnet_ids: + response = ec2.describe_subnets( + Filters=[ + { + 'Name': 'vpc-id', + 'Values': [vpc_id] + }, + ], + SubnetIds=subnet_ids + ) + except ClientError as e: + log.exception(f'Subnet Id {subnet_ids} Not Found: {e}') + raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') + + if not subnet_ids or len(response['Subnets']) != len(subnet_ids): + raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') From 47241c62c2f431b2e06d22c160dc0a705f427e47 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 10:57:12 -0500 Subject: [PATCH 24/38] fix import paths --- backend/dataall/core/environment/api/resolvers.py | 2 +- backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py | 2 +- backend/dataall/modules/mlstudio/services/mlstudio_service.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index c050d228a..ffa6b0dc6 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -20,7 +20,7 @@ from dataall.core.stacks.aws.cloudformation import CloudFormation from dataall.core.stacks.db.stack_repositories import Stack from dataall.core.vpc.db.vpc_repositories import Vpc -from dataall.backend.dataall.base.aws.ec2_client import EC2 +from dataall.base.aws.ec2_client import EC2 from dataall.base.db import exceptions from dataall.core.permissions import permissions from dataall.base.feature_toggle_checker import is_feature_enabled diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index d99ee47b3..1a0da209b 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -17,7 +17,7 @@ from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.backend.dataall.base.aws.ec2_client import EC2 +from dataall.base.aws.ec2_client import EC2 logger = logging.getLogger(__name__) diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index b75788ae6..12e70ca29 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -19,7 +19,7 @@ from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -from dataall.backend.dataall.base.aws.ec2_client import EC2 +from dataall.base.aws.ec2_client import EC2 from dataall.base.aws.sts import SessionHelper from dataall.modules.mlstudio.services.mlstudio_permissions import ( From c50e34573666f3d0d6724b26f76a257303a781c2 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 11:14:52 -0500 Subject: [PATCH 25/38] Fix tests patch --- tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py | 2 +- tests/modules/mlstudio/conftest.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py index 4ea1f34fa..8e0cd6166 100644 --- a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py +++ b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py @@ -72,7 +72,7 @@ def patch_methods_sagemaker_studio_extension(mocker, sgm_studio_domain): return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1", ) mocker.patch( - 'dataall.core.environment.aws.ec2_client.EC2.check_default_vpc_exists', + 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists', return_value=False, ) mocker.patch( diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py index f37dbf886..4b91e24cd 100644 --- a/tests/modules/mlstudio/conftest.py +++ b/tests/modules/mlstudio/conftest.py @@ -26,7 +26,7 @@ def get_cdk_look_up_role_arn(module_mocker): @pytest.fixture(scope='module', autouse=True) def check_default_vpc(module_mocker): module_mocker.patch( - 'dataall.core.environment.aws.ec2_client.EC2.check_default_vpc_exists', + 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists', return_value=False, ) @@ -34,7 +34,7 @@ def check_default_vpc(module_mocker): @pytest.fixture(scope='module', autouse=True) def check_vpc_exists(module_mocker): module_mocker.patch( - 'dataall.core.environment.aws.ec2_client.EC2.check_vpc_exists', + 'dataall.base.aws.ec2_client.EC2.check_vpc_exists', return_value=True, ) From f2a8350271469f1a9c2471ef7cf16d9969fe52b7 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 11:54:05 -0500 Subject: [PATCH 26/38] delete domain before deleting env --- .../Environments/views/EnvironmentView.js | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/frontend/src/modules/Environments/views/EnvironmentView.js b/frontend/src/modules/Environments/views/EnvironmentView.js index 87095af78..7e07a0023 100644 --- a/frontend/src/modules/Environments/views/EnvironmentView.js +++ b/frontend/src/modules/Environments/views/EnvironmentView.js @@ -104,6 +104,16 @@ const EnvironmentView = () => { }; const archiveEnv = async () => { + if (isModuleEnabled(ModuleNames.MLSTUDIO)) { + const response2 = await client.mutate( + deleteEnvironmentMLStudioDomain({ + environmentUri: env.environmentUri + }) + ); + if (response2.errors) { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } const response = await client.mutate( archiveEnvironment({ environmentUri: env.environmentUri, @@ -111,16 +121,6 @@ const EnvironmentView = () => { }) ); if (!response.errors) { - if (isModuleEnabled(ModuleNames.MLSTUDIO)) { - const response2 = await client.mutate( - deleteEnvironmentMLStudioDomain({ - environmentUri: env.environmentUri - }) - ); - if (response2.errors) { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - } handleArchiveObjectModalClose(); enqueueSnackbar('Environment deleted', { anchorOrigin: { From e5cc43fe9b07c86a1a06f405701ec195bb7ffca9 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 12:16:21 -0500 Subject: [PATCH 27/38] Add delete ml studio domain as part of delete env backend env resource --- .../modules/mlstudio/db/mlstudio_repositories.py | 7 +++++++ .../modules/Environments/views/EnvironmentView.js | 12 +----------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 875fd95ce..e6700d871 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -30,6 +30,13 @@ def update_env(session, environment): previous_mlstudio_enabled = True if domain else False return current_mlstudio_enabled != previous_mlstudio_enabled + @staticmethod + def delete_env(session, environment): + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment.environment_uri) + if domain: + session.delete(domain) + return True + @staticmethod def save_sagemaker_studio_user(session, user): """Save SageMaker Studio user to the database""" diff --git a/frontend/src/modules/Environments/views/EnvironmentView.js b/frontend/src/modules/Environments/views/EnvironmentView.js index 7e07a0023..792918c13 100644 --- a/frontend/src/modules/Environments/views/EnvironmentView.js +++ b/frontend/src/modules/Environments/views/EnvironmentView.js @@ -34,7 +34,7 @@ import { useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { deleteEnvironmentMLStudioDomain, useClient } from 'services'; +import { useClient } from 'services'; import { archiveEnvironment, getEnvironment } from '../services'; import { KeyValueTagList, Stack, StackStatus } from 'modules/Shared'; import { @@ -104,16 +104,6 @@ const EnvironmentView = () => { }; const archiveEnv = async () => { - if (isModuleEnabled(ModuleNames.MLSTUDIO)) { - const response2 = await client.mutate( - deleteEnvironmentMLStudioDomain({ - environmentUri: env.environmentUri - }) - ); - if (response2.errors) { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - } const response = await client.mutate( archiveEnvironment({ environmentUri: env.environmentUri, From 3aa03b32d013ac093c56848e69e109fa495870a4 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 12:38:30 -0500 Subject: [PATCH 28/38] Fix delete_env --- backend/dataall/modules/mlstudio/db/mlstudio_repositories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index e6700d871..a124e164d 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -32,7 +32,7 @@ def update_env(session, environment): @staticmethod def delete_env(session, environment): - domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment.environment_uri) + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment.environmentUri) if domain: session.delete(domain) return True From 57f0e8b7c78d2bcfb218092618a7be4b1d5a2c1c Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 17:41:53 -0500 Subject: [PATCH 29/38] Change method of create / delete ml studio to not call 2 APIs on environment create / delete / update --- .../core/environment/api/input_types.py | 8 +- .../dataall/core/environment/api/resolvers.py | 8 +- .../services/environment_resource_manager.py | 15 +- .../services/environment_service.py | 1 + .../dashboards/db/dashboard_repositories.py | 2 +- backend/dataall/modules/mlstudio/__init__.py | 4 +- .../modules/mlstudio/api/input_types.py | 19 -- .../dataall/modules/mlstudio/api/mutations.py | 23 -- .../dataall/modules/mlstudio/api/resolvers.py | 30 --- backend/dataall/modules/mlstudio/api/types.py | 12 -- .../mlstudio/db/mlstudio_repositories.py | 26 +-- .../mlstudio/services/mlstudio_service.py | 111 +++++----- .../views/EnvironmentCreateForm.js | 46 ++-- .../Environments/views/EnvironmentEditForm.js | 72 ++----- .../graphql/MLStudio/createMLStudioDomain.js | 19 -- .../deleteEnvironmentMLStudioDomain.js | 12 -- .../src/services/graphql/MLStudio/index.js | 2 - tests/modules/mlstudio/conftest.py | 200 +++++++++++++----- .../modules/mlstudio/test_sagemaker_studio.py | 158 ++++++++------ 19 files changed, 344 insertions(+), 424 deletions(-) delete mode 100644 frontend/src/services/graphql/MLStudio/createMLStudioDomain.js delete mode 100644 frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py index 15786955b..27188f4ed 100644 --- a/backend/dataall/core/environment/api/input_types.py +++ b/backend/dataall/core/environment/api/input_types.py @@ -31,8 +31,8 @@ gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), - gql.Argument('mlStudioVPCId', gql.String), - gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)) + gql.Argument('vpcId', gql.String), + gql.Argument('subnetIds', gql.ArrayType(gql.String)) ], ) @@ -45,8 +45,8 @@ gql.Argument('SamlGroupName', gql.String), gql.Argument('resourcePrefix', gql.String), gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), - gql.Argument('mlStudioVPCId', gql.String), - gql.Argument('mlStudioSubnetIds', gql.ArrayType(gql.String)) + gql.Argument('vpcId', gql.String), + gql.Argument('subnetIds', gql.ArrayType(gql.String)) ], ) diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index ffa6b0dc6..06878cdfc 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -77,14 +77,14 @@ def check_environment(context: Context, source, account_id, region, data): if parameter['key'] == 'mlStudiosEnabled': mlStudioEnabled = parameter['value'] - if mlStudioEnabled and data.get("mlStudioVPCId", None) and data.get("mlStudioSubnetIds", []): + if mlStudioEnabled and data.get("vpcId", None) and data.get("subnetIds", []): log.info("Check if ML Studio VPC Exists in the Account") EC2.check_vpc_exists( AwsAccountId=account_id, region=region, role=cdk_look_up_role_arn, - vpc_id=data.get("mlStudioVPCId", None), - subnet_ids=data.get('mlStudioSubnetIds', []), + vpc_id=data.get("vpcId", None), + subnet_ids=data.get('subnetIds', []), ) return cdk_role_name @@ -148,7 +148,7 @@ def update_environment( data=input, ) - if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment): + if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment, data=input): stack_helper.deploy_stack(targetUri=environment.environmentUri) return environment diff --git a/backend/dataall/core/environment/services/environment_resource_manager.py b/backend/dataall/core/environment/services/environment_resource_manager.py index bc74f01bf..f5c2551fa 100644 --- a/backend/dataall/core/environment/services/environment_resource_manager.py +++ b/backend/dataall/core/environment/services/environment_resource_manager.py @@ -12,7 +12,11 @@ def delete_env(session, environment): pass @staticmethod - def update_env(session, environment): + def create_env(session, environment, **kwargs): + pass + + @staticmethod + def update_env(session, environment, **kwargs): return False @staticmethod @@ -39,10 +43,10 @@ def count_group_resources(cls, session, environment, group_uri) -> int: return counter @classmethod - def deploy_updated_stack(cls, session, prev_prefix, environment): + def deploy_updated_stack(cls, session, prev_prefix, environment, **kwargs): deploy_stack = prev_prefix != environment.resourcePrefix for resource in cls._resources: - deploy_stack |= resource.update_env(session, environment) + deploy_stack |= resource.update_env(session, environment, **kwargs) return deploy_stack @@ -51,6 +55,11 @@ def delete_env(cls, session, environment): for resource in cls._resources: resource.delete_env(session, environment) + @classmethod + def create_env(cls, session, environment, **kwargs): + for resource in cls._resources: + resource.create_env(session, environment, **kwargs) + @classmethod def count_consumption_role_resources(cls, session, role_uri): counter = 0 diff --git a/backend/dataall/core/environment/services/environment_service.py b/backend/dataall/core/environment/services/environment_service.py index 7c718631f..1b2dbec07 100644 --- a/backend/dataall/core/environment/services/environment_service.py +++ b/backend/dataall/core/environment/services/environment_service.py @@ -66,6 +66,7 @@ def create_environment(session, uri, data=None): session.commit() EnvironmentService._update_env_parameters(session, env, data) + EnvironmentResourceManager.create_env(session, env, data=data) env.EnvironmentDefaultBucketName = NamingConventionService( target_uri=env.environmentUri, diff --git a/backend/dataall/modules/dashboards/db/dashboard_repositories.py b/backend/dataall/modules/dashboards/db/dashboard_repositories.py index 91916f8ff..a8d9d6a2f 100644 --- a/backend/dataall/modules/dashboards/db/dashboard_repositories.py +++ b/backend/dataall/modules/dashboards/db/dashboard_repositories.py @@ -26,7 +26,7 @@ def count_resources(session, environment, group_uri) -> int: ) @staticmethod - def update_env(session, environment): + def update_env(session, environment, **kwargs): return EnvironmentService.get_boolean_env_param(session, environment, "dashboardsEnabled") @staticmethod diff --git a/backend/dataall/modules/mlstudio/__init__.py b/backend/dataall/modules/mlstudio/__init__.py index 190267430..a6ca73917 100644 --- a/backend/dataall/modules/mlstudio/__init__.py +++ b/backend/dataall/modules/mlstudio/__init__.py @@ -3,7 +3,7 @@ from dataall.base.loader import ImportMode, ModuleInterface from dataall.core.stacks.db.target_type_repositories import TargetType -from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.modules.mlstudio.services.mlstudio_service import SagemakerStudioEnvironmentResource from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager log = logging.getLogger(__name__) @@ -21,7 +21,7 @@ def __init__(self): from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER TargetType("mlstudio", GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER) - EnvironmentResourceManager.register(SageMakerStudioRepository()) + EnvironmentResourceManager.register(SagemakerStudioEnvironmentResource()) log.info("API of sagemaker mlstudio has been imported") diff --git a/backend/dataall/modules/mlstudio/api/input_types.py b/backend/dataall/modules/mlstudio/api/input_types.py index e19c3eb1c..f05fd53f6 100644 --- a/backend/dataall/modules/mlstudio/api/input_types.py +++ b/backend/dataall/modules/mlstudio/api/input_types.py @@ -33,22 +33,3 @@ gql.Argument('offset', gql.Integer), ], ) - -SagemakerStudioDomainFilter = gql.InputType( - name='SagemakerStudioDomainFilter', - arguments=[ - gql.Argument('term', gql.String), - gql.Argument(name='page', type=gql.Integer), - gql.Argument(name='pageSize', type=gql.Integer), - ], -) - -NewStudioDomainInput = gql.InputType( - name='NewStudioDomainInput', - arguments=[ - gql.Argument('label', gql.NonNullableType(gql.String)), - gql.Argument('environmentUri', gql.NonNullableType(gql.String)), - gql.Argument('subnetIds', gql.ArrayType(gql.String)), - gql.Argument('vpcId', gql.String), - ], -) diff --git a/backend/dataall/modules/mlstudio/api/mutations.py b/backend/dataall/modules/mlstudio/api/mutations.py index b14195f70..abcc3cc99 100644 --- a/backend/dataall/modules/mlstudio/api/mutations.py +++ b/backend/dataall/modules/mlstudio/api/mutations.py @@ -3,8 +3,6 @@ from dataall.modules.mlstudio.api.resolvers import ( create_sagemaker_studio_user, delete_sagemaker_studio_user, - create_sagemaker_studio_domain, - delete_environment_sagemaker_studio_domain ) createSagemakerStudioUser = gql.MutationField( @@ -31,24 +29,3 @@ type=gql.String, resolver=delete_sagemaker_studio_user, ) - -createMLStudioDomain = gql.MutationField( - name='createMLStudioDomain', - args=[ - gql.Argument( - name='input', - type=gql.NonNullableType(gql.Ref('NewStudioDomainInput')), - ) - ], - type=gql.Ref('SagemakerStudioDomain'), - resolver=create_sagemaker_studio_domain, -) - -deleteEnvironmentMLStudioDomain = gql.MutationField( - name='deleteEnvironmentMLStudioDomain', - args=[ - gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), - ], - type=gql.Boolean, - resolver=delete_environment_sagemaker_studio_domain, -) diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py index e4a41a73c..48c9350fa 100644 --- a/backend/dataall/modules/mlstudio/api/resolvers.py +++ b/backend/dataall/modules/mlstudio/api/resolvers.py @@ -28,16 +28,6 @@ def validate_user_creation_request(data): required(data, "environmentUri") required(data, "SamlAdminGroupName") - @staticmethod - def validate_domain_creation_request(data): - required = RequestValidator._required - if not data: - raise exceptions.RequiredParameter('data') - if not data.get('label'): - raise exceptions.RequiredParameter('name') - - required(data, "environmentUri") - @staticmethod def _required(data: dict, name: str): if not data.get(name): @@ -100,26 +90,6 @@ def delete_sagemaker_studio_user( ) -def create_sagemaker_studio_domain(context: Context, source, input: dict = None): - """Creates a SageMaker Studio user. Deploys the SageMaker Studio user stack into AWS""" - RequestValidator.validate_domain_creation_request(input) - return SagemakerStudioService.create_sagemaker_studio_domain( - uri=input["environmentUri"], - data=input - ) - - -def delete_environment_sagemaker_studio_domain( - context, - source: SagemakerStudioUser, - environmentUri: str = None -): - RequestValidator.required_uri(environmentUri) - return SagemakerStudioService.delete_environment_sagemaker_studio_domain( - uri=environmentUri - ) - - def get_environment_sagemaker_studio_domain(context, source, environmentUri: str = None): RequestValidator.required_uri(environmentUri) return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri) diff --git a/backend/dataall/modules/mlstudio/api/types.py b/backend/dataall/modules/mlstudio/api/types.py index 7446ef290..ca21df81d 100644 --- a/backend/dataall/modules/mlstudio/api/types.py +++ b/backend/dataall/modules/mlstudio/api/types.py @@ -103,15 +103,3 @@ ) ], ) - -SagemakerStudioDomainSearchResult = gql.ObjectType( - name='SagemakerStudioDomainSearchResult', - fields=[ - gql.Field(name='count', type=gql.Integer), - gql.Field(name='page', type=gql.Integer), - gql.Field(name='pages', type=gql.Integer), - gql.Field(name='hasNext', type=gql.Boolean), - gql.Field(name='hasPrevious', type=gql.Boolean), - gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioDomain)), - ], -) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index a124e164d..6beee7d49 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -10,33 +10,17 @@ from dataall.base.utils import slugify from dataall.base.db import paginate from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioDomain, SagemakerStudioUser -from dataall.core.environment.services.environment_resource_manager import EnvironmentResource -from dataall.core.environment.services.environment_service import EnvironmentService from dataall.base.utils.naming_convention import ( NamingConventionService, NamingConventionPattern, ) -class SageMakerStudioRepository(EnvironmentResource): +class SageMakerStudioRepository: """DAO layer for ML Studio""" _DEFAULT_PAGE = 1 _DEFAULT_PAGE_SIZE = 10 - @staticmethod - def update_env(session, environment): - current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") - domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) - previous_mlstudio_enabled = True if domain else False - return current_mlstudio_enabled != previous_mlstudio_enabled - - @staticmethod - def delete_env(session, environment): - domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment.environmentUri) - if domain: - session.delete(domain) - return True - @staticmethod def save_sagemaker_studio_user(session, user): """Save SageMaker Studio user to the database""" @@ -135,3 +119,11 @@ def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[Sagemak if not domain: return None return domain + + @staticmethod + def delete_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env_uri, + ).first() + if domain: + session.delete(domain) diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 12e70ca29..7b66031b2 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -18,6 +18,7 @@ from dataall.base.db import exceptions from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.core.environment.services.environment_resource_manager import EnvironmentResource from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser from dataall.base.aws.ec2_client import EC2 from dataall.base.aws.sts import SessionHelper @@ -58,6 +59,35 @@ def _session(): return get_context().db_engine.scoped_session() +class SagemakerStudioEnvironmentResource(EnvironmentResource): + @staticmethod + def count_resources(session, environment, group_uri) -> int: + return SageMakerStudioRepository.count_resources(session, environment, group_uri) + + @staticmethod + def create_env(session, environment, **kwargs): + enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") + if enabled: + SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) + + @staticmethod + def update_env(session, environment, **kwargs): + current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) + previous_mlstudio_enabled = True if domain else False + if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled): + SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) + return True + elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled): + SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) + return True + return False + + @staticmethod + def delete_env(session, environment): + SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) + + class SagemakerStudioService: """ Encapsulate the logic of interactions with sagemaker ml studio. @@ -142,68 +172,29 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak return sagemaker_studio_user @staticmethod - @has_tenant_permission(permissions.MANAGE_ENVIRONMENTS) - @has_resource_permission(permissions.UPDATE_ENVIRONMENT) - def create_sagemaker_studio_domain(*, uri: str, data: dict): - context = get_context() - with context.db_engine.scoped_session() as session: - environment = EnvironmentService.get_environment_by_uri(session, uri) - enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") - if not enabled: - raise exceptions.UnauthorizedOperation( - action=permissions.UPDATE_ENVIRONMENT, - message=f'ML Studio feature is disabled for the environment {environment.label}', - ) - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=environment.AwsAccountId, region=environment.region - ) - if data.get("vpcId", None): - SagemakerStudioService.check_mlstudio_domain_vpc( - account_id=environment.AwsAccountId, - region=environment.region, - cdk_look_up_role_arn=cdk_look_up_role_arn, - data=data - ) - data["vpcType"] = "imported" - elif EC2.check_default_vpc_exists( - AwsAccountId=environment.AwsAccountId, - region=environment.region, - role=cdk_look_up_role_arn, - ): - data["vpcType"] = "default" - else: - data["vpcType"] = "created" - - domain = SageMakerStudioRepository.create_sagemaker_studio_domain( - session=session, - username=get_context().username, - environment=environment, - data=data, - ) + def create_sagemaker_studio_domain(session, environment, data: dict = {}): + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=environment.AwsAccountId, region=environment.region + ) + if data.get("vpcId", None): + data["vpcType"] = "imported" + elif EC2.check_default_vpc_exists( + AwsAccountId=environment.AwsAccountId, + region=environment.region, + role=cdk_look_up_role_arn, + ): + data["vpcType"] = "default" + else: + data["vpcType"] = "created" + + domain = SageMakerStudioRepository.create_sagemaker_studio_domain( + session=session, + username=get_context().username, + environment=environment, + data=data, + ) return domain - @staticmethod - def check_mlstudio_domain_vpc(account_id: str, region: str, cdk_look_up_role_arn: str, data: dict): - if data.get("vpcId", None) and data.get("subnetIds", None): - EC2.check_vpc_exists( - AwsAccountId=account_id, - region=region, - role=cdk_look_up_role_arn, - vpc_id=data.get("vpcId", None), - subnet_ids=data.get('subnetIds', []), - ) - return True - - @staticmethod - @has_tenant_permission(permissions.MANAGE_ENVIRONMENTS) - @has_resource_permission(permissions.UPDATE_ENVIRONMENT) - def delete_environment_sagemaker_studio_domain(*, uri: str): - with _session() as session: - domain = SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=uri) - if domain: - session.delete(domain) - return True - @staticmethod def get_environment_sagemaker_studio_domain(*, environment_uri: str): with _session() as session: diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index 4954d1076..a16cdfa7e 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -46,7 +46,6 @@ import { } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; import { - createMLStudioDomain, getOrganization, getTrustAccount, useClient, @@ -180,8 +179,8 @@ const EnvironmentCreateForm = (props) => { region: values.region, EnvironmentDefaultIAMRoleArn: values.EnvironmentDefaultIAMRoleArn, resourcePrefix: values.resourcePrefix, - mlStudioVPCId: values.mlStudioVPCId, - mlStudioSubnetIds: values.mlStudioSubnetIds, + vpcId: values.vpcId, + subnetIds: values.subnetIds, parameters: [ { key: 'notebooksEnabled', @@ -203,19 +202,6 @@ const EnvironmentCreateForm = (props) => { }) ); if (!response.errors) { - if (values.mlStudiosEnabled === true) { - const response2 = await client.mutate( - createMLStudioDomain({ - environmentUri: response.data.createEnvironment.environmentUri, - label: values.label, - vpcId: values.mlStudioVPCId, - subnetIds: values.mlStudioSubnetIds - }) - ); - if (response2.errors) { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - } setStatus({ success: true }); setSubmitting(false); enqueueSnackbar('Environment Created', { @@ -501,8 +487,8 @@ const EnvironmentCreateForm = (props) => { pipelinesEnabled: isModuleEnabled(ModuleNames.DATAPIPELINES), EnvironmentDefaultIAMRoleArn: '', resourcePrefix: 'dataall', - mlStudioVPCId: '', - mlStudioSubnetIds: [] + vpcId: '', + subnetIds: [] }} validationSchema={Yup.object().shape({ label: Yup.string() @@ -526,7 +512,7 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), - mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { + subnetIds: Yup.array().when('vpcId', { is: (value) => !!value, then: Yup.array() .min(1) @@ -534,7 +520,7 @@ const EnvironmentCreateForm = (props) => { 'At least 1 Subnet Id required if VPC Id specified' ) }), - mlStudioVPCId: Yup.string().nullable(), + vpcId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() @@ -895,17 +881,13 @@ const EnvironmentCreateForm = (props) => { {...params} label="(Optional) ML Studio VPC ID" placeholder="(Optional) Bring your own VPC - Specify VPC ID" - name="mlStudioVPCId" + name="vpcId" fullWidth - error={Boolean( - touched.mlStudioVPCId && errors.mlStudioVPCId - )} - helperText={ - touched.mlStudioVPCId && errors.mlStudioVPCId - } + error={Boolean(touched.vpcId && errors.vpcId)} + helperText={touched.vpcId && errors.vpcId} onBlur={handleBlur} onChange={handleChange} - value={values.mlStudioVPCId} + value={values.vpcId} variant="outlined" /> @@ -913,18 +895,16 @@ const EnvironmentCreateForm = (props) => { { - setFieldValue('mlStudioSubnetIds', [...chip]); + setFieldValue('subnetIds', [...chip]); }} /> diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index 75eec8497..382575920 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -30,12 +30,7 @@ import { useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { - getEnvironmentMLStudioDomain, - createMLStudioDomain, - deleteEnvironmentMLStudioDomain, - useClient -} from 'services'; +import { getEnvironmentMLStudioDomain, useClient } from 'services'; import { getEnvironment, updateEnvironment } from '../services'; import { isAnyEnvironmentModuleEnabled, @@ -105,8 +100,8 @@ const EnvironmentEditForm = (props) => { tags: values.tags, description: values.description, resourcePrefix: values.resourcePrefix, - mlStudioVPCId: values.mlStudioVPCId, - mlStudioSubnetIds: values.mlStudioSubnetIds, + vpcId: values.vpcId, + subnetIds: values.subnetIds, parameters: [ { key: 'notebooksEnabled', @@ -129,35 +124,6 @@ const EnvironmentEditForm = (props) => { }) ); if (!response.errors) { - if ( - values.mlStudiosEnabled !== previousEnvMLStudioEnabled && - values.mlStudiosEnabled === true - ) { - const response2 = await client.mutate( - createMLStudioDomain({ - environmentUri: env.environmentUri, - label: values.label, - vpcId: values.mlStudioVPCId, - subnetIds: values.mlStudioSubnetIds - }) - ); - if (response2.errors) { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - } - if ( - values.mlStudiosEnabled !== previousEnvMLStudioEnabled && - values.mlStudiosEnabled === false - ) { - const response2 = await client.mutate( - deleteEnvironmentMLStudioDomain({ - environmentUri: envMLStudioDomain.environmentUri - }) - ); - if (response2.errors) { - dispatch({ type: SET_ERROR, error: response.errors[0].message }); - } - } setStatus({ success: true }); setSubmitting(false); enqueueSnackbar('Environment updated', { @@ -268,8 +234,8 @@ const EnvironmentEditForm = (props) => { label: env.label, description: env.description, tags: env.tags || [], - mlStudioVPCId: envMLStudioDomain.vpcId || '', - mlStudioSubnetIds: envMLStudioDomain.subnetIds || [], + vpcId: envMLStudioDomain.vpcId || '', + subnetIds: envMLStudioDomain.subnetIds || [], notebooksEnabled: env.parameters['notebooksEnabled'] === 'true', mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true', pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true', @@ -283,7 +249,7 @@ const EnvironmentEditForm = (props) => { .required('*Environment name is required'), description: Yup.string().max(5000), tags: Yup.array().nullable(), - mlStudioSubnetIds: Yup.array().when('mlStudioVPCId', { + subnetIds: Yup.array().when('vpcId', { is: (value) => !!value, then: Yup.array() .min(1) @@ -291,7 +257,7 @@ const EnvironmentEditForm = (props) => { 'At least 1 Subnet Id required if VPC Id specified' ) }), - mlStudioVPCId: Yup.string().nullable(), + vpcId: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() .matches( @@ -460,19 +426,13 @@ const EnvironmentEditForm = (props) => { disabled={previousEnvMLStudioEnabled} label="(Optional) ML Studio VPC ID" placeholder="(Optional) Bring your own VPC - Specify VPC ID" - name="mlStudioVPCId" + name="vpcId" fullWidth - error={Boolean( - touched.mlStudioVPCId && - errors.mlStudioVPCId - )} - helperText={ - touched.mlStudioVPCId && - errors.mlStudioVPCId - } + error={Boolean(touched.vpcId && errors.vpcId)} + helperText={touched.vpcId && errors.vpcId} onBlur={handleBlur} onChange={handleChange} - value={values.mlStudioVPCId} + value={values.vpcId} variant="outlined" /> @@ -481,20 +441,16 @@ const EnvironmentEditForm = (props) => { disabled={previousEnvMLStudioEnabled} fullWidth error={Boolean( - touched.mlStudioSubnetIds && - errors.mlStudioSubnetIds + touched.subnetIds && errors.subnetIds )} helperText={ - touched.mlStudioSubnetIds && - errors.mlStudioSubnetIds + touched.subnetIds && errors.subnetIds } variant="outlined" label="(Optional) ML Studio Subnet ID(s)" placeholder="(Optional) Bring your own VPC - Specify Subnet ID (Hit enter after typing value)" onChange={(chip) => { - setFieldValue('mlStudioSubnetIds', [ - ...chip - ]); + setFieldValue('subnetIds', [...chip]); }} /> diff --git a/frontend/src/services/graphql/MLStudio/createMLStudioDomain.js b/frontend/src/services/graphql/MLStudio/createMLStudioDomain.js deleted file mode 100644 index 3940e6748..000000000 --- a/frontend/src/services/graphql/MLStudio/createMLStudioDomain.js +++ /dev/null @@ -1,19 +0,0 @@ -import { gql } from 'apollo-boost'; - -export const createMLStudioDomain = (input) => ({ - variables: { - input - }, - mutation: gql` - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - ` -}); diff --git a/frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js b/frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js deleted file mode 100644 index 7abdc7e9e..000000000 --- a/frontend/src/services/graphql/MLStudio/deleteEnvironmentMLStudioDomain.js +++ /dev/null @@ -1,12 +0,0 @@ -import { gql } from 'apollo-boost'; - -export const deleteEnvironmentMLStudioDomain = ({ environmentUri }) => ({ - variables: { - environmentUri - }, - mutation: gql` - mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { - deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) - } - ` -}); diff --git a/frontend/src/services/graphql/MLStudio/index.js b/frontend/src/services/graphql/MLStudio/index.js index 5ae789bc8..97d3de110 100644 --- a/frontend/src/services/graphql/MLStudio/index.js +++ b/frontend/src/services/graphql/MLStudio/index.js @@ -1,3 +1 @@ -export * from './createMLStudioDomain'; -export * from './deleteEnvironmentMLStudioDomain'; export * from './getEnvironmentMLStudioDomain'; diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py index 4b91e24cd..344d7e724 100644 --- a/tests/modules/mlstudio/conftest.py +++ b/tests/modules/mlstudio/conftest.py @@ -40,60 +40,7 @@ def check_vpc_exists(module_mocker): @pytest.fixture(scope='module') -def sagemaker_studio_domain(client, group, env_fixture) -> SagemakerStudioDomain: - response = client.query( - """ - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - """, - input={ - 'label': 'testcreate', - 'environmentUri': env_fixture.environmentUri, - }, - username='alice', - groups=[group.name], - ) - yield response.data.createMLStudioDomain - - -@pytest.fixture(scope='module') -def sagemaker_studio_domain_with_vpc(client, group, env_fixture) -> SagemakerStudioDomain: - response = client.query( - """ - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - """, - input={ - 'label': 'testcreate', - 'environmentUri': env_fixture.environmentUri, - 'vpcId': 'vpc-12345', - 'subnetIds': ['subnet-12345', 'subnet-67890'] - }, - username='alice', - groups=[group.name], - ) - - yield response.data.createMLStudioDomain - - -@pytest.fixture(scope='module') -def sagemaker_studio_user(client, tenant, group, env_fixture, sagemaker_studio_domain) -> SagemakerStudioUser: +def sagemaker_studio_user(client, tenant, group, env_with_mlstudio) -> SagemakerStudioUser: response = client.query( """ mutation createSagemakerStudioUser($input:NewSagemakerStudioUserInput){ @@ -112,7 +59,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture, sagemaker_studio_d input={ 'label': 'testcreate', 'SamlAdminGroupName': group.name, - 'environmentUri': env_fixture.environmentUri, + 'environmentUri': env_with_mlstudio.environmentUri, }, username='alice', groups=[group.name], @@ -121,7 +68,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture, sagemaker_studio_d @pytest.fixture(scope='module') -def multiple_sagemaker_studio_users(client, db, env_fixture, group): +def multiple_sagemaker_studio_users(client, db, env_with_mlstudio, group): for i in range(0, 10): response = client.query( """ @@ -141,7 +88,7 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): input={ 'label': f'test{i}', 'SamlAdminGroupName': group.name, - 'environmentUri': env_fixture.environmentUri, + 'environmentUri': env_with_mlstudio.environmentUri, }, username='alice', groups=[group.name], @@ -153,5 +100,142 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): ) assert ( response.data.createSagemakerStudioUser.environmentUri - == env_fixture.environmentUri + == env_with_mlstudio.environmentUri ) + +@pytest.fixture(scope='module') +def env_with_mlstudio(client, org_fixture, user, group, parameters=None, vpcId='', subnetIds=[]): + if not parameters: + parameters = {'mlStudiosEnabled': 'True'} + response = client.query( + """mutation CreateEnv($input:NewEnvironmentInput){ + createEnvironment(input:$input){ + organization{ + organizationUri + } + environmentUri + label + AwsAccountId + SamlGroupName + region + name + owner + parameters { + key + value + } + } + }""", + username=f'{user.username}', + groups=['testadmins'], + input={ + 'label': f'dev', + 'description': '', + 'organizationUri': org_fixture.organizationUri, + 'AwsAccountId': '111111111111', + 'tags': [], + 'region': 'us-east-1', + 'SamlGroupName': 'testadmins', + 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()], + 'vpcId': vpcId, + 'subnetIds': subnetIds + }, + ) + yield response.data.createEnvironment + + +# @pytest.fixture(scope='module') +# def env(client): +# cache = {} + +# def factory(org, envname, owner, group, account, region, vpcId='', subnetIds=[]): +# parameters = {"mlStudiosEnabled": "true"} + +# key = f"{org.organizationUri}{envname}{owner}{''.join(group or '-')}{account}{region}{vpcId}" +# if cache.get(key): +# return cache[key] +# response = client.query( +# """mutation CreateEnv($input:NewEnvironmentInput){ +# createEnvironment(input:$input){ +# organization{ +# organizationUri +# } +# environmentUri +# label +# AwsAccountId +# SamlGroupName +# region +# name +# owner +# parameters { +# key +# value +# } +# } +# }""", +# username=f'{owner}', +# groups=[group], +# input={ +# 'label': f'{envname}', +# 'description': 'test', +# 'organizationUri': org.organizationUri, +# 'AwsAccountId': account, +# 'tags': ['a', 'b', 'c'], +# 'region': f'{region}', +# 'SamlGroupName': f'{group}', +# 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()], +# 'vpcId': vpcId, +# 'subnetIds': subnetIds +# }, +# ) +# cache[key] = response.data.createEnvironment +# return cache[key] + +# yield factory + + +@pytest.fixture(scope='module', autouse=True) +def org(client): + cache = {} + + def factory(orgname, owner, group): + key = orgname + owner + group + if cache.get(key): + print(f'returning item from cached key {key}') + return cache.get(key) + response = client.query( + """mutation CreateOrganization($input:NewOrganizationInput){ + createOrganization(input:$input){ + organizationUri + label + name + owner + SamlGroupName + } + }""", + username=f'{owner}', + groups=[group], + input={ + 'label': f'{orgname}', + 'description': f'test', + 'tags': ['a', 'b', 'c'], + 'SamlGroupName': f'{group}', + }, + ) + cache[key] = response.data.createOrganization + return cache[key] + + yield factory + + +@pytest.fixture(scope='module') +def org_fixture(org, user, group): + org1 = org('testorg', user.username, group.name) + yield org1 + + +@pytest.fixture(scope='module') +def env_mlstudio_fixture(env, org_fixture, user, group, tenant): + env1 = env_with_mlstudio(org_fixture, 'dev', 'alice', 'testadmins', '111111111111', 'eu-west-1') + yield env1 + diff --git a/tests/modules/mlstudio/test_sagemaker_studio.py b/tests/modules/mlstudio/test_sagemaker_studio.py index 87ef441b3..3d90b405a 100644 --- a/tests/modules/mlstudio/test_sagemaker_studio.py +++ b/tests/modules/mlstudio/test_sagemaker_studio.py @@ -1,47 +1,43 @@ from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -def test_create_sagemaker_studio_domain(sagemaker_studio_domain, env_fixture): - """Testing that the conftest sagemaker studio domain has been created correctly""" - assert sagemaker_studio_domain.label == 'testcreate-domain' - assert sagemaker_studio_domain.vpcType == 'created' - assert sagemaker_studio_domain.vpcId is None - assert len(sagemaker_studio_domain.subnetIds) == 0 - assert sagemaker_studio_domain.environmentUri == env_fixture.environmentUri - - -def test_create_sagemaker_studio_domain_unauthorized(client, env_fixture, group2): +def test_create_sagemaker_studio_domain(db, client, org_fixture, env_with_mlstudio, user, group, vpcId="vpc-1234", subnetIds=["subnet"]): response = client.query( """ - mutation createMLStudioDomain($input: NewStudioDomainInput) { - createMLStudioDomain(input: $input) { - sagemakerStudioUri - environmentUri - label - vpcType - vpcId - subnetIds - } - } - """, - input={ - 'label': 'testcreate', - 'environmentUri': env_fixture.environmentUri, - }, - username='anonymoususer', - groups=[group2.name], + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_with_mlstudio.environmentUri, ) - assert 'Unauthorized' in response.errors[0].message + + assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri + assert response.data.getEnvironmentMLStudioDomain.label == f'{env_with_mlstudio.label}-domain' + assert response.data.getEnvironmentMLStudioDomain.vpcType == 'created' + assert len(response.data.getEnvironmentMLStudioDomain.vpcId) == 0 + assert len(response.data.getEnvironmentMLStudioDomain.subnetIds) == 0 + assert response.data.getEnvironmentMLStudioDomain.environmentUri == env_with_mlstudio.environmentUri -def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_fixture): +def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_with_mlstudio): """Testing that the conftest sagemaker studio user has been created correctly""" assert sagemaker_studio_user.label == 'testcreate' assert sagemaker_studio_user.SamlAdminGroupName == group.name - assert sagemaker_studio_user.environmentUri == env_fixture.environmentUri + assert sagemaker_studio_user.environmentUri == env_with_mlstudio.environmentUri -def test_list_sagemaker_studio_users(client, env_fixture, db, group, multiple_sagemaker_studio_users): +def test_list_sagemaker_studio_users(client, db, group, multiple_sagemaker_studio_users): response = client.query( """ query listSagemakerStudioUsers($filter:SagemakerStudioUserFilter!){ @@ -101,8 +97,47 @@ def test_delete_sagemaker_studio_user( ) assert not n +def update_env_query(): + query = """ + mutation UpdateEnv($environmentUri:String!,$input:ModifyEnvironmentInput){ + updateEnvironment(environmentUri:$environmentUri,input:$input){ + organization{ + organizationUri + } + label + AwsAccountId + region + SamlGroupName + owner + tags + resourcePrefix + parameters { + key + value + } + } + } + """ + return query + +def test_update_env_delete_domain(client, org_fixture, env_with_mlstudio, group, group2): + response = client.query( + update_env_query(), + username='alice', + environmentUri=env_with_mlstudio.environmentUri, + input={ + 'label': 'DEV', + 'tags': [], + 'parameters': [ + { + 'key': 'mlStudiosEnabled', + 'value': 'False' + } + ], + }, + groups=[group.name], + ) -def test_get_sagemaker_studio_domain(client, env_fixture, sagemaker_studio_domain): response = client.query( """ query getEnvironmentMLStudioDomain($environmentUri: String) { @@ -120,24 +155,30 @@ def test_get_sagemaker_studio_domain(client, env_fixture, sagemaker_studio_domai } } """, - environmentUri=env_fixture.environmentUri, + environmentUri=env_with_mlstudio.environmentUri, ) - print(response.data) - assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri == sagemaker_studio_domain.sagemakerStudioUri + assert response.data.getEnvironmentMLStudioDomain is None -def test_delete_sagemaker_studio_domain(client, env_fixture, group): +def test_update_env_create_domain_with_vpc(db, client, org_fixture, env_with_mlstudio, user, group): response = client.query( - """ - mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { - deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) - } - """, - environmentUri=env_fixture.environmentUri, + update_env_query(), username='alice', + environmentUri=env_with_mlstudio.environmentUri, + input={ + 'label': 'dev', + 'tags': [], + 'vpcId': "vpc-12345", + 'subnetIds': ['subnet-12345', 'subnet-67890'], + 'parameters': [ + { + 'key': 'mlStudiosEnabled', + 'value': 'True' + } + ], + }, groups=[group.name], ) - assert response.data.deleteEnvironmentMLStudioDomain response = client.query( """ @@ -156,30 +197,13 @@ def test_delete_sagemaker_studio_domain(client, env_fixture, group): } } """, - environmentUri=env_fixture.environmentUri + environmentUri=env_with_mlstudio.environmentUri, ) - assert response.data.getEnvironmentMLStudioDomain is None - - -def test_create_sagemaker_studio_domain_with_vpc(sagemaker_studio_domain_with_vpc, env_fixture): - """Testing that the conftest sagemaker studio domain has been created correctly""" - assert sagemaker_studio_domain_with_vpc.label == 'testcreate-domain' - assert sagemaker_studio_domain_with_vpc.vpcType == 'imported' - assert sagemaker_studio_domain_with_vpc.vpcId == 'vpc-12345' - assert sagemaker_studio_domain_with_vpc.subnetIds == ['subnet-12345', 'subnet-67890'] - assert sagemaker_studio_domain_with_vpc.environmentUri == env_fixture.environmentUri + assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri + assert response.data.getEnvironmentMLStudioDomain.label == f'{env_with_mlstudio.label}-domain' + assert response.data.getEnvironmentMLStudioDomain.vpcType == 'imported' + assert response.data.getEnvironmentMLStudioDomain.vpcId == 'vpc-12345' + assert len(response.data.getEnvironmentMLStudioDomain.subnetIds) == 2 + assert response.data.getEnvironmentMLStudioDomain.environmentUri == env_with_mlstudio.environmentUri -def test_delete_sagemaker_studio_domain_unauthorized(client, env_fixture, group2): - response = client.query( - """ - mutation deleteEnvironmentMLStudioDomain($environmentUri: String!) { - deleteEnvironmentMLStudioDomain(environmentUri: $environmentUri) - } - """, - environmentUri=env_fixture.environmentUri, - username='anonymoususer', - groups=[group2.name], - ) - - assert 'Unauthorized' in response.errors[0].message From 9287d163434648abfc6e865fa780d8ef74d2a547 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Wed, 6 Dec 2023 21:10:28 -0500 Subject: [PATCH 30/38] Clean up tests --- tests/modules/mlstudio/conftest.py | 50 ------------------------------ 1 file changed, 50 deletions(-) diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py index 344d7e724..d1fffb2cf 100644 --- a/tests/modules/mlstudio/conftest.py +++ b/tests/modules/mlstudio/conftest.py @@ -144,56 +144,6 @@ def env_with_mlstudio(client, org_fixture, user, group, parameters=None, vpcId=' yield response.data.createEnvironment -# @pytest.fixture(scope='module') -# def env(client): -# cache = {} - -# def factory(org, envname, owner, group, account, region, vpcId='', subnetIds=[]): -# parameters = {"mlStudiosEnabled": "true"} - -# key = f"{org.organizationUri}{envname}{owner}{''.join(group or '-')}{account}{region}{vpcId}" -# if cache.get(key): -# return cache[key] -# response = client.query( -# """mutation CreateEnv($input:NewEnvironmentInput){ -# createEnvironment(input:$input){ -# organization{ -# organizationUri -# } -# environmentUri -# label -# AwsAccountId -# SamlGroupName -# region -# name -# owner -# parameters { -# key -# value -# } -# } -# }""", -# username=f'{owner}', -# groups=[group], -# input={ -# 'label': f'{envname}', -# 'description': 'test', -# 'organizationUri': org.organizationUri, -# 'AwsAccountId': account, -# 'tags': ['a', 'b', 'c'], -# 'region': f'{region}', -# 'SamlGroupName': f'{group}', -# 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()], -# 'vpcId': vpcId, -# 'subnetIds': subnetIds -# }, -# ) -# cache[key] = response.data.createEnvironment -# return cache[key] - -# yield factory - - @pytest.fixture(scope='module', autouse=True) def org(client): cache = {} From 513e69171bf26730584cba6211678a0c0bfd47a8 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 10:30:35 -0500 Subject: [PATCH 31/38] Add default vpc info, SAML Group Name to domain, and Fix EnvironmentMLStudio View --- backend/dataall/base/aws/ec2_client.py | 13 +- .../mlstudio/cdk/mlstudio_extension.py | 116 ++++++++---------- .../mlstudio/db/mlstudio_repositories.py | 1 + .../mlstudio/services/mlstudio_service.py | 39 ++++-- ...f5de322f_update_sagemaker_studio_domain.py | 31 +++-- .../components/EnvironmentMLStudio.js | 15 +-- 6 files changed, 114 insertions(+), 101 deletions(-) diff --git a/backend/dataall/base/aws/ec2_client.py b/backend/dataall/base/aws/ec2_client.py index 23d290a7e..06bd62c7a 100644 --- a/backend/dataall/base/aws/ec2_client.py +++ b/backend/dataall/base/aws/ec2_client.py @@ -23,9 +23,20 @@ def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): vpcs = response['Vpcs'] log.info(f"Default VPCs response: {vpcs}") if vpcs: - return True + vpc_id = vpcs[0]['VpcId'] + subnetIds = EC2._get_vpc_subnets(AwsAccountId=AwsAccountId, region=region, vpc_id=vpc_id, role=role) + if subnetIds: + return vpc_id, subnetIds return False + @staticmethod + def _get_vpc_subnets(AwsAccountId: str, region: str, vpc_id: str, role=None): + client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = client.describe_subnets( + Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}] + ) + return [subnet['SubnetId'] for subnet in response['Subnets']] + @staticmethod def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): try: diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 1a0da209b..495cad19b 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -14,10 +14,8 @@ ) from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository -from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.base.aws.ec2_client import EC2 logger = logging.getLogger(__name__) @@ -38,77 +36,61 @@ def extent(setup: EnvironmentSetup): if domain.vpcId and domain.subnetIds: logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') - vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) + vpc_id = domain.vpcId subnet_ids = domain.subnetIds security_groups = [] else: - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=_environment.AwsAccountId, region=_environment.region + logger.info("VPC not specified or not found, Exception. Creating a VPC for SageMaker resources...") + # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways + log_group = logs.LogGroup( + setup, + f'SageMakerStudio{_environment.name}', + log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', + retention=logs.RetentionDays.ONE_MONTH, + removal_policy=RemovalPolicy.DESTROY, ) - existing_default_vpc = EC2.check_default_vpc_exists( - AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn + vpc_flow_role = iam.Role( + setup, 'FlowLog', + assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') + ) + vpc = ec2.Vpc( + setup, + "SageMakerVPC", + max_azs=3, + cidr="10.10.0.0/16", + subnet_configuration=[ + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PUBLIC, + name="Public", + cidr_mask=24 + ), + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, + name="Private", + cidr_mask=24 + ), + ], + enable_dns_hostnames=True, + enable_dns_support=True, + ) + ec2.FlowLog( + setup, "StudioVPCFlowLog", + resource_type=ec2.FlowLogResourceType.from_vpc(vpc), + destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + ) + # setup security group to be used for sagemaker studio domain + sagemaker_sg = ec2.SecurityGroup( + setup, + "SecurityGroup", + vpc=vpc, + description="Security Group for SageMaker Studio", + security_group_name=domain.sagemakerStudioDomainName, ) - if existing_default_vpc: - logger.info("Using default VPC for Sagemaker Studio domain") - # Use default VPC - initial configuration (to be migrated) - vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] - subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] - security_groups = [] - else: - logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") - # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways - log_group = logs.LogGroup( - setup, - f'SageMakerStudio{_environment.name}', - log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', - retention=logs.RetentionDays.ONE_MONTH, - removal_policy=RemovalPolicy.DESTROY, - ) - vpc_flow_role = iam.Role( - setup, 'FlowLog', - assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') - ) - vpc = ec2.Vpc( - setup, - "SageMakerVPC", - max_azs=3, - cidr="10.10.0.0/16", - subnet_configuration=[ - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PUBLIC, - name="Public", - cidr_mask=24 - ), - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, - name="Private", - cidr_mask=24 - ), - ], - enable_dns_hostnames=True, - enable_dns_support=True, - ) - ec2.FlowLog( - setup, "StudioVPCFlowLog", - resource_type=ec2.FlowLogResourceType.from_vpc(vpc), - destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) - ) - # setup security group to be used for sagemaker studio domain - sagemaker_sg = ec2.SecurityGroup( - setup, - "SecurityGroup", - vpc=vpc, - description="Security Group for SageMaker Studio", - security_group_name=domain.sagemakerStudioDomainName, - ) - - sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) - security_groups = [sagemaker_sg.security_group_id] - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - vpc_id = vpc.vpc_id + sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) + security_groups = [sagemaker_sg.security_group_id] + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + vpc_id = vpc.vpc_id sagemaker_domain_role = iam.Role( setup, diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 6beee7d49..21847b6ef 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -82,6 +82,7 @@ def create_sagemaker_studio_domain(session, username, environment, data): owner=username, description=data.get('description', 'No description provided'), tags=data.get('tags', []), + SamlGroupName=environment.SamlGroupName, environmentUri=environment.environmentUri, AWSAccountId=environment.AwsAccountId, region=environment.region, diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 7b66031b2..729c3971e 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -81,6 +81,9 @@ def update_env(session, environment, **kwargs): elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled): SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) return True + elif current_mlstudio_enabled: + SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs) + return True return False @staticmethod @@ -172,20 +175,38 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak return sagemaker_studio_user @staticmethod - def create_sagemaker_studio_domain(session, environment, data: dict = {}): + def update_sagemaker_studio_domain(environment, domain, data): + SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data) + domain.vpcType = data.get('vpcType') + if data.get('vpcId'): + domain.vpcId = data.get('vpcId') + if data.get('subnetIds'): + domain.subnetIds = data.get('subnetIds') + + @staticmethod + def _update_sagemaker_studio_domain_vpc(account_id, region, data={}): cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=environment.AwsAccountId, region=environment.region + accountid=account_id, region=region ) if data.get("vpcId", None): data["vpcType"] = "imported" - elif EC2.check_default_vpc_exists( - AwsAccountId=environment.AwsAccountId, - region=environment.region, - role=cdk_look_up_role_arn, - ): - data["vpcType"] = "default" else: - data["vpcType"] = "created" + response = EC2.check_default_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + ) + if response: + vpcId, subnetIds = response + data["vpcType"] = "default" + data["vpcId"] = vpcId + data["subnetIds"] = subnetIds + else: + data["vpcType"] = "created" + + @staticmethod + def create_sagemaker_studio_domain(session, environment, data: dict = {}): + SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data) domain = SageMakerStudioRepository.create_sagemaker_studio_domain( session=session, diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 9c31cc7a2..6eae4b073 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -91,6 +91,7 @@ def upgrade(): op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True)) op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True)) op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False)) op.create_foreign_key( "fk_sagemaker_studio_domain_env_uri", @@ -112,18 +113,23 @@ def upgrade(): Environment.environmentUri == param.environmentUri ).first() - domain = SagemakerStudioDomain( - label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - owner=env.owner, - description='No description provided', - environmentUri=env.environmentUri, - AWSAccountId=env.AwsAccountId, - region=env.region, - DefaultDomainRoleName="RoleSagemakerStudioUsers", - sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - vpcType="unknown" - ) - session.add(domain) + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env.environmentUri + ).first() + if not domain: + domain = SagemakerStudioDomain( + label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + owner=env.owner, + description='No description provided', + environmentUri=env.environmentUri, + AWSAccountId=env.AwsAccountId, + region=env.region, + DefaultDomainRoleName="RoleSagemakerStudioUsers", + sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + vpcType="unknown", + SamlGroupName=env.SamlGroupName + ) + session.add(domain) session.flush() session.commit() print("Fill of sagemaker_studio_domain table is done") @@ -167,6 +173,7 @@ def downgrade(): op.drop_column("sagemaker_studio_domain", "vpcType") op.drop_column("sagemaker_studio_domain", "vpcId") op.drop_column("sagemaker_studio_domain", "subnetIds") + op.drop_column("sagemaker_studio_domain", "SamlGroupName") op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain") diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 744a0350c..79427116f 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -12,7 +12,7 @@ import { import PropTypes from 'prop-types'; import React, { useCallback, useEffect, useState } from 'react'; -import { RefreshTableMenu, ObjectMetadata } from 'design'; +import { RefreshTableMenu } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; import { getEnvironmentMLStudioDomain, useClient } from 'services'; @@ -92,7 +92,6 @@ export const EnvironmentMLStudio = ({ environment }) => { - SageMaker ML Studio Domain Name @@ -118,7 +117,8 @@ export const EnvironmentMLStudio = ({ environment }) => { {mlStudioDomain.vpcType} - {mlStudioDomain.vpcType === 'imported' && ( + {(mlStudioDomain.vpcId === 'imported' || + mlStudioDomain.vpcId === 'default') && ( <> @@ -147,15 +147,6 @@ export const EnvironmentMLStudio = ({ environment }) => { )} - - - )} From cc77b4e031bb1a33b9f7d4ee8b4e695d1a234b7c Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 11:40:19 -0500 Subject: [PATCH 32/38] Fix downgrade migration and add saml group name to models --- .../dataall/modules/mlstudio/db/mlstudio_models.py | 1 + .../71a5f5de322f_update_sagemaker_studio_domain.py | 12 ------------ 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index 28ff62e9e..a4c93a2fa 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -21,6 +21,7 @@ class SagemakerStudioDomain(Resource, Base): AWSAccountId = Column(String, nullable=False) DefaultDomainRoleName = Column(String, nullable=False) region = Column(String, default='eu-west-1') + SamlGroupName = Column(String, nullable=False) vpcType = Column(String, nullable=True) vpcId = Column(String, nullable=True) subnetIds = Column(ARRAY(String), nullable=True) diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 6eae4b073..18c6e2ac1 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -149,18 +149,6 @@ def downgrade(): if has_table('sagemaker_studio_domain', engine): print("Updating of sagemaker_studio_domain table...") - op.alter_column( - 'sagemaker_studio_domain', - 'sagemakerStudioDomainID', - nullable=False, - existing_type=sa.String() - ) - op.alter_column( - 'sagemaker_studio_domain', - 'SagemakerStudioStatus', - nullable=False, - existing_type=sa.String() - ) op.alter_column( 'sagemaker_studio_domain', 'DefaultDomainRoleName', From 4bbba28dfb04c0145c8fd3f444c6500805e89495 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 12:00:23 -0500 Subject: [PATCH 33/38] Fix tests to add samlgroupname --- tests/modules/mlstudio/cdk/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modules/mlstudio/cdk/conftest.py b/tests/modules/mlstudio/cdk/conftest.py index 718f2023f..2c6f1eddd 100644 --- a/tests/modules/mlstudio/cdk/conftest.py +++ b/tests/modules/mlstudio/cdk/conftest.py @@ -36,8 +36,8 @@ def sgm_studio_domain(db, env_fixture: Environment) -> SagemakerStudioDomain: SagemakerStudioStatus="PENDING", DefaultDomainRoleName="DefaultMLStudioRole", sagemakerStudioDomainName="DomainName", - vpcType="created" + vpcType="created", + SamlGroupName=env_fixture.SamlGroupName, ) session.add(sm_domain) yield sm_domain - \ No newline at end of file From cd4fe7f4f970ab591854cc62ef7e7a6a594cb868 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 12:38:54 -0500 Subject: [PATCH 34/38] fix migration script downgrade then upgrade --- .../versions/71a5f5de322f_update_sagemaker_studio_domain.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 18c6e2ac1..298a481ce 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -50,6 +50,7 @@ class SagemakerStudioDomain(Resource, Base): AWSAccountId = Column(String, nullable=False) DefaultDomainRoleName = Column(String, nullable=False) region = Column(String, default='eu-west-1') + SamlGroupName = Column(String, nullable=False) vpcType = Column(String, nullable=True) @@ -148,6 +149,9 @@ def downgrade(): session = orm.Session(bind=bind) if has_table('sagemaker_studio_domain', engine): + print("deleting sagemaker studio domain entries...") + session.query(SagemakerStudioDomain).delete() + print("Updating of sagemaker_studio_domain table...") op.alter_column( 'sagemaker_studio_domain', From 9417ab9d020becf133e5a804b27e73ef66355f62 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 14:35:58 -0500 Subject: [PATCH 35/38] Final fixes --- .../mlstudio/cdk/mlstudio_extension.py | 116 ++++++++++-------- .../mlstudio/services/mlstudio_service.py | 3 +- ...f5de322f_update_sagemaker_studio_domain.py | 1 + .../components/EnvironmentMLStudio.js | 111 ++++++++--------- 4 files changed, 124 insertions(+), 107 deletions(-) diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 495cad19b..1a0da209b 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -14,8 +14,10 @@ ) from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService +from dataall.base.aws.ec2_client import EC2 logger = logging.getLogger(__name__) @@ -36,61 +38,77 @@ def extent(setup: EnvironmentSetup): if domain.vpcId and domain.subnetIds: logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') - vpc_id = domain.vpcId + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) subnet_ids = domain.subnetIds security_groups = [] else: - logger.info("VPC not specified or not found, Exception. Creating a VPC for SageMaker resources...") - # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways - log_group = logs.LogGroup( - setup, - f'SageMakerStudio{_environment.name}', - log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', - retention=logs.RetentionDays.ONE_MONTH, - removal_policy=RemovalPolicy.DESTROY, + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=_environment.AwsAccountId, region=_environment.region ) - vpc_flow_role = iam.Role( - setup, 'FlowLog', - assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') - ) - vpc = ec2.Vpc( - setup, - "SageMakerVPC", - max_azs=3, - cidr="10.10.0.0/16", - subnet_configuration=[ - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PUBLIC, - name="Public", - cidr_mask=24 - ), - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, - name="Private", - cidr_mask=24 - ), - ], - enable_dns_hostnames=True, - enable_dns_support=True, - ) - ec2.FlowLog( - setup, "StudioVPCFlowLog", - resource_type=ec2.FlowLogResourceType.from_vpc(vpc), - destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) - ) - # setup security group to be used for sagemaker studio domain - sagemaker_sg = ec2.SecurityGroup( - setup, - "SecurityGroup", - vpc=vpc, - description="Security Group for SageMaker Studio", - security_group_name=domain.sagemakerStudioDomainName, + existing_default_vpc = EC2.check_default_vpc_exists( + AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn ) + if existing_default_vpc: + logger.info("Using default VPC for Sagemaker Studio domain") + # Use default VPC - initial configuration (to be migrated) + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] + subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + security_groups = [] + else: + logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") + # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways + log_group = logs.LogGroup( + setup, + f'SageMakerStudio{_environment.name}', + log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', + retention=logs.RetentionDays.ONE_MONTH, + removal_policy=RemovalPolicy.DESTROY, + ) + vpc_flow_role = iam.Role( + setup, 'FlowLog', + assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') + ) + vpc = ec2.Vpc( + setup, + "SageMakerVPC", + max_azs=3, + cidr="10.10.0.0/16", + subnet_configuration=[ + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PUBLIC, + name="Public", + cidr_mask=24 + ), + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, + name="Private", + cidr_mask=24 + ), + ], + enable_dns_hostnames=True, + enable_dns_support=True, + ) + ec2.FlowLog( + setup, "StudioVPCFlowLog", + resource_type=ec2.FlowLogResourceType.from_vpc(vpc), + destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + ) + # setup security group to be used for sagemaker studio domain + sagemaker_sg = ec2.SecurityGroup( + setup, + "SecurityGroup", + vpc=vpc, + description="Security Group for SageMaker Studio", + security_group_name=domain.sagemakerStudioDomainName, + ) + + sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) + security_groups = [sagemaker_sg.security_group_id] + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) - security_groups = [sagemaker_sg.security_group_id] - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - vpc_id = vpc.vpc_id + vpc_id = vpc.vpc_id sagemaker_domain_role = iam.Role( setup, diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 729c3971e..75328d530 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -75,13 +75,14 @@ def update_env(session, environment, **kwargs): current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) previous_mlstudio_enabled = True if domain else False + vpcType = domain.vpcType if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled): SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) return True elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled): SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) return True - elif current_mlstudio_enabled: + elif current_mlstudio_enabled and vpcType == "unknown": SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs) return True return False diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 298a481ce..28bf39957 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -29,6 +29,7 @@ class Environment(Resource, Base): environmentUri = Column(String, primary_key=True) AwsAccountId = Column(Boolean) region = Column(Boolean) + SamlGroupName = Column(String) class EnvironmentParameter(Base): diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js index 79427116f..44dac97b9 100644 --- a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -61,7 +61,7 @@ export const EnvironmentMLStudio = ({ environment }) => { } - title={ML Studio Domain} + title={ML Studio Domain Information} /> { ) : ( - - - - - SageMaker ML Studio Domain Name - - - {mlStudioDomain.sagemakerStudioDomainName} - - - - - SageMaker ML Studio Default Execution Role - - - arn:aws:iam::{environment.AwsAccountId}:role/ - {mlStudioDomain.DefaultDomainRoleName} - - - - - Domain VPC Type - - - {mlStudioDomain.vpcType} - - - {(mlStudioDomain.vpcId === 'imported' || - mlStudioDomain.vpcId === 'default') && ( - <> - - - Domain VPC Id - - - {mlStudioDomain.vpcId} - - - - - Domain Subnet Ids - - - {mlStudioDomain.subnetIds?.map((subnet) => ( - - ))} - - - - )} - + + + SageMaker ML Studio Domain Name + + + {mlStudioDomain.sagemakerStudioDomainName} + + + + + SageMaker ML Studio Default Execution Role + + + arn:aws:iam::{environment.AwsAccountId}:role/ + {mlStudioDomain.DefaultDomainRoleName} + + + + + Domain VPC Type + + + {mlStudioDomain.vpcType} + + + {(mlStudioDomain.vpcType === 'imported' || + mlStudioDomain.vpcType === 'default') && ( + <> + + + Domain VPC Id + + + {mlStudioDomain.vpcId} + + + + + Domain Subnet Ids + + + {mlStudioDomain.subnetIds?.map((subnet) => ( + + ))} + + + + )} )} From 33b035f8c48095c93a282b5b7ea281c7f5357f2b Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 14:53:50 -0500 Subject: [PATCH 36/38] Fix unknown vpc type integration tests --- backend/dataall/modules/mlstudio/services/mlstudio_service.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 75328d530..3738c118d 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -75,14 +75,13 @@ def update_env(session, environment, **kwargs): current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) previous_mlstudio_enabled = True if domain else False - vpcType = domain.vpcType if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled): SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) return True elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled): SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) return True - elif current_mlstudio_enabled and vpcType == "unknown": + elif current_mlstudio_enabled and domain and domain.vpcType == "unknown": SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs) return True return False From 19c194e316154b0ae937f830d821662904cf5707 Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 14:54:22 -0500 Subject: [PATCH 37/38] lint --- .../versions/71a5f5de322f_update_sagemaker_studio_domain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index 28bf39957..a3ac794f3 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -152,7 +152,7 @@ def downgrade(): if has_table('sagemaker_studio_domain', engine): print("deleting sagemaker studio domain entries...") session.query(SagemakerStudioDomain).delete() - + print("Updating of sagemaker_studio_domain table...") op.alter_column( 'sagemaker_studio_domain', From 31f250ab03565253a50739dd1242562aadf64d0c Mon Sep 17 00:00:00 2001 From: Noah Paige Date: Thu, 7 Dec 2023 15:57:33 -0500 Subject: [PATCH 38/38] only use domain RDS record on imported vpc for mlstudio extension stack --- backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 1a0da209b..49082ccfb 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -36,7 +36,7 @@ def extent(setup: EnvironmentSetup): sagemaker_principals = [setup.default_role] + setup.group_roles logger.info(f'Creating SageMaker base resources for sagemaker_principals = {sagemaker_principals}..') - if domain.vpcId and domain.subnetIds: + if domain.vpcId and domain.subnetIds and domain.vpcType == 'imported': logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) subnet_ids = domain.subnetIds