Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Byo vpc mlstudio #894

Merged
merged 39 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
045f701
Start work on byo vpc ml studio
noah-paige Nov 21, 2023
86de123
Add Payload
noah-paige Nov 24, 2023
866548c
Add backend logic deploy mlstudio custom vpc
noah-paige Nov 28, 2023
b717f25
lint
noah-paige Nov 28, 2023
f110930
Merge branch 'os-main' into byo-vpc-mlstudio
noah-paige Nov 29, 2023
0b2d6a7
Test new ML Studio Domain Views
noah-paige Nov 29, 2023
dc8a412
Test new ML Studio Domain Views - add missing files
noah-paige Nov 29, 2023
973386a
Test new ML Studio Domain Views - add missing files
noah-paige Nov 29, 2023
2de32a4
Write backend resolvers for ml studio domain apis
noah-paige Nov 30, 2023
9212012
Fixes to API params and permissions checks
noah-paige Nov 30, 2023
56099d6
Add Stack Deploy on create/delete studio domain
noah-paige Nov 30, 2023
ecb9be0
Fix Migration script and clean up naming and lint checks
noah-paige Nov 30, 2023
3280a05
Create Studio Domain on Env Create/Update, rework Frontend Views
noah-paige Dec 1, 2023
6909b9d
Remove unused API list domains
noah-paige Dec 1, 2023
50a90b3
add tests mlstudio domain apis
noah-paige Dec 1, 2023
4eb4858
Clean Up
noah-paige Dec 1, 2023
2acfdfe
Clean up get studio domain and update migration script
noah-paige Dec 4, 2023
4215ff4
Edit text when ML Studio disabled
noah-paige Dec 4, 2023
95db11b
Fix coverage tests
noah-paige Dec 4, 2023
2dc0e1a
Revert migration script
noah-paige Dec 4, 2023
5cb1d8b
Handle null values Edit form
noah-paige Dec 5, 2023
32bc04d
Move EC2 to base, clean up unused code, move APIs to shared, add dele…
noah-paige Dec 6, 2023
f9e522d
Move EC2 to base, clean up unused code, move APIs to shared, add dele…
noah-paige Dec 6, 2023
cec2afe
Move EC2 to base, clean up unused code, move APIs to shared, add dele…
noah-paige Dec 6, 2023
47241c6
fix import paths
noah-paige Dec 6, 2023
c50e345
Fix tests patch
noah-paige Dec 6, 2023
f2a8350
delete domain before deleting env
noah-paige Dec 6, 2023
e5cc43f
Add delete ml studio domain as part of delete env backend env resource
noah-paige Dec 6, 2023
3aa03b3
Fix delete_env
noah-paige Dec 6, 2023
57f0e8b
Change method of create / delete ml studio to not call 2 APIs on envi…
noah-paige Dec 6, 2023
9287d16
Clean up tests
noah-paige Dec 7, 2023
513e691
Add default vpc info, SAML Group Name to domain, and Fix EnvironmentM…
noah-paige Dec 7, 2023
cc77b4e
Fix downgrade migration and add saml group name to models
noah-paige Dec 7, 2023
4bbba28
Fix tests to add samlgroupname
noah-paige Dec 7, 2023
cd4fe7f
fix migration script downgrade then upgrade
noah-paige Dec 7, 2023
9417ab9
Final fixes
noah-paige Dec 7, 2023
33b035f
Fix unknown vpc type integration tests
noah-paige Dec 7, 2023
19c194e
lint
noah-paige Dec 7, 2023
31f250a
only use domain RDS record on imported vpc for mlstudio extension stack
noah-paige Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions backend/dataall/base/aws/ec2_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging

from dataall.base.aws.sts import SessionHelper
from botocore.exceptions import ClientError

log = logging.getLogger(__name__)


class EC2:
    """Thin wrapper around the AWS EC2 API used to validate VPC/subnet input
    for environments (e.g. a bring-your-own-VPC ML Studio domain)."""

    @staticmethod
    def get_client(account_id: str, region: str, role=None):
        """Return a boto3 EC2 client in *account_id*/*region*, assuming *role* if given."""
        session = SessionHelper.remote_session(accountid=account_id, role=role)
        return session.client('ec2', region_name=region)

    @staticmethod
    def check_default_vpc_exists(AwsAccountId: str, region: str, role=None):
        """Look up the account's default VPC.

        Returns:
            ``(vpc_id, subnet_ids)`` when a default VPC with at least one subnet
            exists, ``False`` otherwise.
            NOTE(review): the mixed tuple/bool return type is kept for backward
            compatibility with existing callers.
        """
        log.info("Check that default VPC exists..")
        client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
        response = client.describe_vpcs(
            Filters=[{'Name': 'isDefault', 'Values': ['true']}]
        )
        vpcs = response['Vpcs']
        log.info(f"Default VPCs response: {vpcs}")
        if vpcs:
            vpc_id = vpcs[0]['VpcId']
            subnet_ids = EC2._get_vpc_subnets(AwsAccountId=AwsAccountId, region=region, vpc_id=vpc_id, role=role)
            if subnet_ids:
                return vpc_id, subnet_ids
        return False

    @staticmethod
    def _get_vpc_subnets(AwsAccountId: str, region: str, vpc_id: str, role=None):
        """Return the ids of all subnets attached to *vpc_id*."""
        client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
        response = client.describe_subnets(
            Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]
        )
        return [subnet['SubnetId'] for subnet in response['Subnets']]

    @staticmethod
    def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=None):
        """Validate that *vpc_id* exists and that every id in *subnet_ids* belongs to it.

        Raises:
            Exception: ``VPCNotFound`` if the VPC cannot be described,
                ``VPCSubnetsNotFound`` if the subnets cannot be described, or
                ``Not All Subnets ...`` if *subnet_ids* is empty/None or some
                subnet is not in the VPC (callers must pass at least one subnet).
        """
        # Fix: use None instead of a mutable [] default argument (shared across calls).
        subnet_ids = subnet_ids or []
        try:
            ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
            ec2.describe_vpcs(VpcIds=[vpc_id])
        except ClientError as e:
            log.exception(f'VPC Id {vpc_id} Not Found: {e}')
            # Chain the cause so the original AWS error is preserved in tracebacks.
            raise Exception(f'VPCNotFound: {vpc_id}') from e

        try:
            if subnet_ids:
                response = ec2.describe_subnets(
                    Filters=[
                        {
                            'Name': 'vpc-id',
                            'Values': [vpc_id]
                        },
                    ],
                    SubnetIds=subnet_ids
                )
        except ClientError as e:
            log.exception(f'Subnet Id {subnet_ids} Not Found: {e}')
            raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') from e

        # Empty subnet_ids is rejected here; otherwise every requested subnet
        # must have been returned by the vpc-id-filtered describe_subnets call.
        if not subnet_ids or len(response['Subnets']) != len(subnet_ids):
            raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}')
1 change: 1 addition & 0 deletions backend/dataall/base/utils/naming_convention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class NamingConventionPattern(Enum):
GLUE = {'regex': '[^a-zA-Z0-9_]', 'separator': '_', 'max_length': 63}
GLUE_ETL = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 52}
NOTEBOOK = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
MLSTUDIO_DOMAIN = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
DEFAULT = {'regex': '[^a-zA-Z0-9-_]', 'separator': '-', 'max_length': 63}
OPENSEARCH = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 27}
OPENSEARCH_SERVERLESS = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 31}
Expand Down
15 changes: 6 additions & 9 deletions backend/dataall/core/environment/api/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@
gql.Argument('description', gql.String),
gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)),
gql.Argument('region', gql.NonNullableType(gql.String)),
gql.Argument('vpcId', gql.String),
gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String),
gql.Argument('resourcePrefix', gql.String),
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))

gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
gql.Argument('vpcId', gql.String),
gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)

Expand All @@ -45,11 +43,10 @@
gql.Argument('description', gql.String),
gql.Argument('tags', gql.ArrayType(gql.String)),
gql.Argument('SamlGroupName', gql.String),
gql.Argument('vpcId', gql.String),
gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('resourcePrefix', gql.String),
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
gql.Argument('vpcId', gql.String),
gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)

Expand Down
28 changes: 23 additions & 5 deletions backend/dataall/core/environment/api/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dataall.core.stacks.aws.cloudformation import CloudFormation
from dataall.core.stacks.db.stack_repositories import Stack
from dataall.core.vpc.db.vpc_repositories import Vpc
from dataall.base.aws.ec2_client import EC2
from dataall.base.db import exceptions
from dataall.core.permissions import permissions
from dataall.base.feature_toggle_checker import is_feature_enabled
Expand All @@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs):
return True if ssm_param == "True" else False


def check_environment(context: Context, source, account_id, region):
def check_environment(context: Context, source, account_id, region, data):
""" Checks necessary resources for environment deployment.
- Check CDKToolkit exists in Account assuming cdk_look_up_role
- Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False
Expand Down Expand Up @@ -71,11 +72,25 @@ def check_environment(context: Context, source, account_id, region):
action='CHECK_PIVOT_ROLE',
message='Pivot Role has not been created in the Environment AWS Account',
)
mlStudioEnabled = None
for parameter in data.get("parameters", []):
if parameter['key'] == 'mlStudiosEnabled':
mlStudioEnabled = parameter['value']

if mlStudioEnabled and data.get("vpcId", None) and data.get("subnetIds", []):
log.info("Check if ML Studio VPC Exists in the Account")
EC2.check_vpc_exists(
AwsAccountId=account_id,
region=region,
role=cdk_look_up_role_arn,
vpc_id=data.get("vpcId", None),
subnet_ids=data.get('subnetIds', []),
)
noah-paige marked this conversation as resolved.
Show resolved Hide resolved

return cdk_role_name


def create_environment(context: Context, source, input=None):
def create_environment(context: Context, source, input={}):
if input.get('SamlGroupName') and input.get('SamlGroupName') not in context.groups:
raise exceptions.UnauthorizedOperation(
action=permissions.LINK_ENVIRONMENT,
Expand All @@ -85,8 +100,10 @@ def create_environment(context: Context, source, input=None):
with context.engine.scoped_session() as session:
cdk_role_name = check_environment(context, source,
account_id=input.get('AwsAccountId'),
region=input.get('region')
region=input.get('region'),
data=input
)

input['cdk_role_name'] = cdk_role_name
env = EnvironmentService.create_environment(
session=session,
Expand Down Expand Up @@ -119,7 +136,8 @@ def update_environment(
environment = EnvironmentService.get_environment_by_uri(session, environmentUri)
cdk_role_name = check_environment(context, source,
account_id=environment.AwsAccountId,
region=environment.region
region=environment.region,
data=input
)

previous_resource_prefix = environment.resourcePrefix
Expand All @@ -130,7 +148,7 @@ def update_environment(
data=input,
)

if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment):
if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment, data=input):
stack_helper.deploy_stack(targetUri=environment.environmentUri)

return environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def delete_env(session, environment):
pass

@staticmethod
def update_env(session, environment):
def create_env(session, environment, **kwargs):
pass

@staticmethod
def update_env(session, environment, **kwargs):
return False

@staticmethod
Expand All @@ -39,10 +43,10 @@ def count_group_resources(cls, session, environment, group_uri) -> int:
return counter

@classmethod
def deploy_updated_stack(cls, session, prev_prefix, environment):
def deploy_updated_stack(cls, session, prev_prefix, environment, **kwargs):
deploy_stack = prev_prefix != environment.resourcePrefix
for resource in cls._resources:
deploy_stack |= resource.update_env(session, environment)
deploy_stack |= resource.update_env(session, environment, **kwargs)

return deploy_stack

Expand All @@ -51,6 +55,11 @@ def delete_env(cls, session, environment):
for resource in cls._resources:
resource.delete_env(session, environment)

@classmethod
def create_env(cls, session, environment, **kwargs):
for resource in cls._resources:
resource.create_env(session, environment, **kwargs)

@classmethod
def count_consumption_role_resources(cls, session, role_uri):
counter = 0
Expand Down
24 changes: 1 addition & 23 deletions backend/dataall/core/environment/services/environment_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def create_environment(session, uri, data=None):
session.commit()

EnvironmentService._update_env_parameters(session, env, data)
EnvironmentResourceManager.create_env(session, env, data=data)

env.EnvironmentDefaultBucketName = NamingConventionService(
target_uri=env.environmentUri,
Expand Down Expand Up @@ -98,29 +99,6 @@ def create_environment(session, uri, data=None):
env.EnvironmentDefaultIAMRoleArn = data['EnvironmentDefaultIAMRoleArn']
env.EnvironmentDefaultIAMRoleImported = True

if data.get('vpcId'):
vpc = Vpc(
environmentUri=env.environmentUri,
region=env.region,
AwsAccountId=env.AwsAccountId,
VpcId=data.get('vpcId'),
privateSubnetIds=data.get('privateSubnetIds', []),
publicSubnetIds=data.get('publicSubnetIds', []),
SamlGroupName=data['SamlGroupName'],
owner=context.username,
label=f"{env.name}-{data.get('vpcId')}",
name=f"{env.name}-{data.get('vpcId')}",
default=True,
)
session.add(vpc)
session.commit()
ResourcePolicy.attach_resource_policy(
session=session,
group=data['SamlGroupName'],
permissions=permissions.NETWORK_ALL,
resource_uri=vpc.vpcUri,
resource_type=Vpc.__name__,
)
env_group = EnvironmentGroup(
environmentUri=env.environmentUri,
groupUri=data['SamlGroupName'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def count_resources(session, environment, group_uri) -> int:
)

@staticmethod
def update_env(session, environment):
def update_env(session, environment, **kwargs):
return EnvironmentService.get_boolean_env_param(session, environment, "dashboardsEnabled")

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion backend/dataall/modules/mlstudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from dataall.base.loader import ImportMode, ModuleInterface
from dataall.core.stacks.db.target_type_repositories import TargetType
from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository
from dataall.modules.mlstudio.services.mlstudio_service import SagemakerStudioEnvironmentResource
from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager

log = logging.getLogger(__name__)

Expand All @@ -20,6 +21,8 @@ def __init__(self):
from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER
TargetType("mlstudio", GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER)

EnvironmentResourceManager.register(SagemakerStudioEnvironmentResource())

log.info("API of sagemaker mlstudio has been imported")


Expand Down
10 changes: 10 additions & 0 deletions backend/dataall/modules/mlstudio/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
get_sagemaker_studio_user,
list_sagemaker_studio_users,
get_sagemaker_studio_user_presigned_url,
get_environment_sagemaker_studio_domain
)

getSagemakerStudioUser = gql.QueryField(
Expand Down Expand Up @@ -34,3 +35,12 @@
type=gql.String,
resolver=get_sagemaker_studio_user_presigned_url,
)

# GraphQL query: return the SageMaker Studio domain record linked to an environment.
# Resolved by get_environment_sagemaker_studio_domain; result type is the
# 'SagemakerStudioDomain' ObjectType (referenced lazily via gql.Ref).
getEnvironmentMLStudioDomain = gql.QueryField(
    name='getEnvironmentMLStudioDomain',
    args=[
        gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)),
    ],
    type=gql.Ref('SagemakerStudioDomain'),
    resolver=get_environment_sagemaker_studio_domain,
)
9 changes: 7 additions & 2 deletions backend/dataall/modules/mlstudio/api/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def required_uri(uri):
raise exceptions.RequiredParameter('URI')

@staticmethod
def validate_creation_request(data):
def validate_user_creation_request(data):
required = RequestValidator._required
if not data:
raise exceptions.RequiredParameter('data')
Expand All @@ -36,7 +36,7 @@ def _required(data: dict, name: str):

def create_sagemaker_studio_user(context: Context, source, input: dict = None):
"""Creates a SageMaker Studio user. Deploys the SageMaker Studio user stack into AWS"""
RequestValidator.validate_creation_request(input)
RequestValidator.validate_user_creation_request(input)
request = SagemakerStudioCreationRequest.from_dict(input)
return SagemakerStudioService.create_sagemaker_studio_user(
uri=input["environmentUri"],
Expand Down Expand Up @@ -90,6 +90,11 @@ def delete_sagemaker_studio_user(
)


def get_environment_sagemaker_studio_domain(context, source, environmentUri: str = None):
    """Resolver for getEnvironmentMLStudioDomain: validates that environmentUri is
    present, then delegates to the service layer to fetch the domain record."""
    RequestValidator.required_uri(environmentUri)
    return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri)


def resolve_user_role(context: Context, source: SagemakerStudioUser):
"""
Resolves the role of the current user in reference with the SageMaker Studio User
Expand Down
24 changes: 24 additions & 0 deletions backend/dataall/modules/mlstudio/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,27 @@
gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioUser)),
],
)

# GraphQL type describing a SageMaker Studio domain attached to an environment,
# including its VPC configuration (vpcType/vpcId/subnetIds) and audit timestamps.
SagemakerStudioDomain = gql.ObjectType(
    name='SagemakerStudioDomain',
    fields=[
        gql.Field(name='sagemakerStudioUri', type=gql.ID),
        gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)),
        gql.Field(name='sagemakerStudioDomainName', type=gql.String),
        gql.Field(name='DefaultDomainRoleName', type=gql.String),
        gql.Field(name='label', type=gql.String),
        gql.Field(name='name', type=gql.String),
        gql.Field(name='vpcType', type=gql.String),
        gql.Field(name='vpcId', type=gql.String),
        gql.Field(name='subnetIds', type=gql.ArrayType(gql.String)),
        gql.Field(name='owner', type=gql.String),
        gql.Field(name='created', type=gql.String),
        gql.Field(name='updated', type=gql.String),
        gql.Field(name='deleted', type=gql.String),
        # Nested environment object, resolved lazily from environmentUri.
        gql.Field(
            name='environment',
            type=gql.Ref('Environment'),
            resolver=resolve_environment,
        )
    ],
)
27 changes: 0 additions & 27 deletions backend/dataall/modules/mlstudio/aws/ec2_client.py

This file was deleted.

Loading
Loading