Merge branch 'main' into feature/github-security-checks-frontend
# Conflicts:
#	frontend/yarn.lock
dlpzx committed Sep 4, 2023
2 parents 6fb86df + 56c5835 commit 9a51757
Showing 121 changed files with 16,331 additions and 2,771 deletions.
20 changes: 20 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -11,4 +11,24 @@
### Relates
- <URL or Ticket>

### Security
Please answer the questions below briefly where applicable, or write `N/A`. The questions are based on the
[OWASP Top 10](https://owasp.org/Top10/en/).

- Does this PR introduce or modify any input fields or queries? This includes
  fetching data from storage outside the application (e.g. a database or an S3 bucket).
- Is the input sanitized?
- What precautions are you taking before deserializing the data you consume?
- Is injection prevented by parametrizing queries?
- Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
- Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
  - Do you use standard, proven implementations?
  - Are the keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
- Have you used the least-privilege principle? How?


By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
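
The injection item in the checklist above asks about query parameterization; for reference, here is a minimal sketch of that pattern (illustrative Python, not part of this diff):

```python
import sqlite3  # any DB-API driver works the same way; sqlite3 keeps the sketch self-contained

def find_user(conn: sqlite3.Connection, username: str):
    # Parameterized: the driver binds `username` safely, so crafted input
    # cannot change the structure of the query.
    return conn.execute(
        "SELECT * FROM users WHERE username = ?", (username,)
    ).fetchall()

# Vulnerable alternative the checklist guards against (never do this):
#   conn.execute(f"SELECT * FROM users WHERE username = '{username}'")
```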
Binary file modified UserGuide.pdf
9 changes: 7 additions & 2 deletions backend/cdkproxymain.py
@@ -52,7 +52,12 @@ def up(response: Response):
def check_creds(response: Response):
logger.info('GET /awscreds')
try:
sts = boto3.client('sts', region_name=os.getenv('AWS_REGION', 'eu-west-1'))
region = os.getenv('AWS_REGION', 'eu-west-1')
sts = boto3.client(
'sts',
region_name=region,
endpoint_url=f"https://sts.{region}.amazonaws.com"
)
data = sts.get_caller_identity()
return {
'DH_DOCKER_VERSION': os.environ.get('DH_DOCKER_VERSION'),
@@ -84,7 +89,7 @@ def check_connect(response: Response):
return {
'DH_DOCKER_VERSION': os.environ.get('DH_DOCKER_VERSION'),
'_ts': datetime.now().isoformat(),
'message': f"Connected to database for environment {ENVNAME}({engine.dbconfig.params['host']}:{engine.dbconfig.params['port']})",
'message': f"Connected to database for environment {ENVNAME}({engine.dbconfig.host})",
}
except Exception as e:
logger.exception('DBCONNECTIONERROR')
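The change above pins the STS client to its regional endpoint instead of the legacy global `sts.amazonaws.com` — the usual motivation being lower latency and compatibility with VPC interface endpoints. A self-contained sketch of the same pattern, with the `eu-west-1` fallback taken from the diff and the printout purely illustrative:

```python
import os

import boto3

# Regional STS endpoint pattern from cdkproxymain.py above.
region = os.getenv('AWS_REGION', 'eu-west-1')
sts = boto3.client(
    'sts',
    region_name=region,
    endpoint_url=f'https://sts.{region}.amazonaws.com',
)
print(sts.get_caller_identity()['Account'])  # confirms credentials resolve in-region
```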
3 changes: 2 additions & 1 deletion backend/dataall/api/Objects/Dataset/input_types.py
@@ -42,6 +42,7 @@ class DatasetSortField(GraphQLEnumMapper):
gql.Argument('language', gql.Ref('Language')),
gql.Argument('confidentiality', gql.Ref('ConfidentialityClassification')),
gql.Argument(name='stewards', type=gql.String),
gql.Argument('KmsAlias', gql.NonNullableType(gql.String)),
],
)

@@ -94,7 +95,7 @@ class DatasetSortField(GraphQLEnumMapper):
gql.Argument('description', gql.String),
gql.Argument('bucketName', gql.NonNullableType(gql.String)),
gql.Argument('glueDatabaseName', gql.String),
gql.Argument('KmsKeyId', gql.String),
gql.Argument('KmsKeyAlias', gql.NonNullableType(gql.String)),
gql.Argument('adminRoleName', gql.String),
gql.Argument('tags', gql.ArrayType(gql.String)),
gql.Argument('owner', gql.NonNullableType(gql.String)),
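With `KmsAlias`/`KmsKeyAlias` now non-nullable, clients must always send the field; as the resolver change below shows, an empty string is treated as SSE-S3 encryption. A hypothetical import call — only the argument names come from this diff, the mutation shape and values are illustrative:

```python
# Hypothetical client-side call; argument names taken from input_types.py above.
IMPORT_DATASET = """
mutation ImportDataset($input: ImportDatasetInput) {
  importDataset(input: $input) { datasetUri KmsAlias }
}
"""

variables = {
    'input': {
        'environmentUri': 'env-uri-123',    # hypothetical
        'bucketName': 'my-existing-bucket',  # hypothetical
        'KmsKeyAlias': '',                   # empty string -> dataset recorded as SSE-S3
        'owner': 'alice',                    # non-nullable per the diff
    }
}
```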
24 changes: 21 additions & 3 deletions backend/dataall/api/Objects/Dataset/resolvers.py
@@ -13,7 +13,8 @@
from ....aws.handlers.glue import Glue
from ....aws.handlers.service_handlers import Worker
from ....aws.handlers.sts import SessionHelper
from ....aws.handlers.sns import Sns
from ....aws.handlers.kms import KMS

from ....aws.handlers.quicksight import Quicksight
from ....db import paginate, exceptions, permissions, models
from ....db.api import Dataset, Environment, ShareObject, ResourcePolicy
@@ -32,6 +33,21 @@ def check_dataset_account(environment):
return True


def check_imported_resources(environment, kmsAlias):
if kmsAlias not in ["Undefined", "", "SSE-S3"]:
key_id = KMS.get_key_id(
account_id=environment.AwsAccountId,
region=environment.region,
key_alias=f"alias/{kmsAlias}"
)
if not key_id:
raise exceptions.AWSResourceNotFound(
action=permissions.IMPORT_DATASET,
message=f'KMS key with alias={kmsAlias} cannot be found',
)
return True


def create_dataset(context: Context, source, input=None):
with context.engine.scoped_session() as session:
environment = Environment.get_environment_by_uri(session, input.get('environmentUri'))
@@ -71,6 +87,7 @@ def import_dataset(context: Context, source, input=None):
with context.engine.scoped_session() as session:
environment = Environment.get_environment_by_uri(session, input.get('environmentUri'))
check_dataset_account(environment=environment)
check_imported_resources(environment=environment, kmsAlias=input.get('KmsKeyAlias', ""))

dataset = Dataset.create_dataset(
session=session,
@@ -83,9 +100,9 @@
dataset.imported = True
dataset.importedS3Bucket = True if input['bucketName'] else False
dataset.importedGlueDatabase = True if input.get('glueDatabaseName') else False
dataset.importedKmsKey = True if input.get('KmsKeyId') else False
dataset.importedKmsKey = True if input.get('KmsKeyAlias') else False
dataset.importedAdminRole = True if input.get('adminRoleName') else False

dataset.KmsAlias = "SSE-S3" if input.get('KmsKeyAlias') == "" else input.get('KmsKeyAlias')
Dataset.create_dataset_stack(session, dataset)

indexers.upsert_dataset(
@@ -231,6 +248,7 @@ def update_dataset(context, source, datasetUri: str = None, input: dict = None):
dataset = Dataset.get_dataset_by_uri(session, datasetUri)
environment = Environment.get_environment_by_uri(session, dataset.environmentUri)
check_dataset_account(environment=environment)
check_imported_resources(environment=environment, kmsAlias=input.get('KmsAlias', ""))
updated_dataset = Dataset.update_dataset(
session=session,
username=context.username,
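The `KMS.get_key_id` handler called by `check_imported_resources` is not part of this diff; a minimal sketch of what such an alias lookup could look like in boto3 (assumed implementation — the real handler presumably assumes a cross-account role first):

```python
import boto3
from botocore.exceptions import ClientError

def get_key_id(account_id: str, region: str, key_alias: str):
    """Resolve a KMS alias (e.g. 'alias/my-key') to a key id, or None if absent.
    Sketch only: a real handler would obtain cross-account credentials for
    `account_id` (e.g. via SessionHelper) before creating the client."""
    kms = boto3.client('kms', region_name=region)
    try:
        # DescribeKey accepts 'alias/<name>' directly as the KeyId parameter.
        return kms.describe_key(KeyId=key_alias)['KeyMetadata']['KeyId']
    except ClientError as e:
        if e.response['Error']['Code'] == 'NotFoundException':
            return None
        raise
```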
10 changes: 0 additions & 10 deletions backend/dataall/api/Objects/DatasetProfiling/mutations.py
@@ -7,13 +7,3 @@
type=gql.Ref('DatasetProfilingRun'),
resolver=start_profiling_run,
)

updateDatasetProfilingRunResults = gql.MutationField(
name='updateDatasetProfilingRunResults',
args=[
gql.Argument(name='profilingRunUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='results', type=gql.NonNullableType(gql.String)),
],
type=gql.Ref('DatasetProfilingRun'),
resolver=update_profiling_run_results,
)
20 changes: 1 addition & 19 deletions backend/dataall/api/Objects/DatasetProfiling/queries.py
@@ -2,24 +2,6 @@
from .resolvers import *


getDatasetProfilingRun = gql.QueryField(
name='getDatasetProfilingRun',
args=[gql.Argument(name='profilingRunUri', type=gql.NonNullableType(gql.String))],
type=gql.Ref('DatasetProfilingRun'),
resolver=get_profiling_run,
)


listDatasetProfilingRuns = gql.QueryField(
name='listDatasetProfilingRuns',
args=[
gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='filter', type=gql.Ref('DatasetProfilingRunFilter')),
],
type=gql.Ref('DatasetProfilingRunSearchResults'),
resolver=list_profiling_runs,
)

listDatasetTableProfilingRuns = gql.QueryField(
name='listDatasetTableProfilingRuns',
args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
@@ -31,5 +13,5 @@
name='getDatasetTableProfilingRun',
args=[gql.Argument(name='tableUri', type=gql.NonNullableType(gql.String))],
type=gql.Ref('DatasetProfilingRun'),
resolver=get_last_table_profiling_run,
resolver=get_dataset_table_profiling_run,
)
100 changes: 58 additions & 42 deletions backend/dataall/api/Objects/DatasetProfiling/resolvers.py
@@ -1,6 +1,7 @@
import json
import logging

from .... import db
from ....api.context import Context
from ....aws.handlers.service_handlers import Worker
from ....aws.handlers.sts import SessionHelper
@@ -19,7 +20,30 @@ def resolve_dataset(context, source: models.DatasetProfilingRun):
)


def resolve_profiling_run_status(context: Context, source: models.DatasetProfilingRun):
if not source:
return None
with context.engine.scoped_session() as session:
task = models.Task(
targetUri=source.profilingRunUri, action='glue.job.profiling_run_status'
)
session.add(task)
Worker.queue(engine=context.engine, task_ids=[task.taskUri])
return source.status


def resolve_profiling_results(context: Context, source: models.DatasetProfilingRun):
if not source or source.results == {}:
return None
else:
return json.dumps(source.results)


def start_profiling_run(context: Context, source, input: dict = None):
"""
Triggers profiling jobs on a Table.
Only Dataset owners with PROFILE_DATASET_TABLE can perform this action
"""
with context.engine.scoped_session() as session:

ResourcePolicy.check_user_resource_permission(
@@ -48,47 +72,14 @@ def start_profiling_run(context: Context, source, input: dict = None):
return run


def get_profiling_run_status(context: Context, source: models.DatasetProfilingRun):
if not source:
return None
with context.engine.scoped_session() as session:
task = models.Task(
targetUri=source.profilingRunUri, action='glue.job.profiling_run_status'
)
session.add(task)
Worker.queue(engine=context.engine, task_ids=[task.taskUri])
return source.status


def get_profiling_results(context: Context, source: models.DatasetProfilingRun):
if not source or source.results == {}:
return None
else:
return json.dumps(source.results)


def update_profiling_run_results(context: Context, source, profilingRunUri, results):
with context.engine.scoped_session() as session:
run = api.DatasetProfilingRun.update_run(
session=session, profilingRunUri=profilingRunUri, results=results
)
return run


def list_profiling_runs(context: Context, source, datasetUri=None):
with context.engine.scoped_session() as session:
return api.DatasetProfilingRun.list_profiling_runs(session, datasetUri)


def get_profiling_run(context: Context, source, profilingRunUri=None):
with context.engine.scoped_session() as session:
return api.DatasetProfilingRun.get_profiling_run(
session=session, profilingRunUri=profilingRunUri
)


def get_last_table_profiling_run(context: Context, source, tableUri=None):
def get_dataset_table_profiling_run(context: Context, source, tableUri=None):
"""
Shows the results of the last profiling job on a Table.
For datasets "Unclassified" all users can perform this action.
For datasets "Secret" or "Official", only users with PREVIEW_DATASET_TABLE permissions can perform this action.
"""
with context.engine.scoped_session() as session:
_check_preview_permissions_if_needed(context=context, session=session, tableUri=tableUri)
run: models.DatasetProfilingRun = (
api.DatasetProfilingRun.get_table_last_profiling_run(
session=session, tableUri=tableUri
@@ -102,7 +93,7 @@ def get_last_table_profiling_run(context: Context, source, tableUri=None):
environment = api.Environment.get_environment_by_uri(
session, dataset.environmentUri
)
content = get_profiling_results_from_s3(
content = _get_profiling_results_from_s3(
environment, dataset, table, run
)
if content:
@@ -121,7 +112,7 @@ def get_profiling_results_from_s3(environment, dataset, table, run):
return run


def get_profiling_results_from_s3(environment, dataset, table, run):
def _get_profiling_results_from_s3(environment, dataset, table, run):
s3 = SessionHelper.remote_session(environment.AwsAccountId).client(
's3', region_name=environment.region
)
@@ -141,7 +132,32 @@ def get_profiling_results_from_s3(environment, dataset, table, run):


def list_table_profiling_runs(context: Context, source, tableUri=None):
"""
Lists the runs of a profiling job on a Table.
For datasets "Unclassified" all users can perform this action.
For datasets "Secret" or "Official", only users with PREVIEW_DATASET_TABLE permissions can perform this action.
"""
with context.engine.scoped_session() as session:
_check_preview_permissions_if_needed(context=context, session=session, tableUri=tableUri)
return api.DatasetProfilingRun.list_table_profiling_runs(
session=session, tableUri=tableUri, filter={}
)


def _check_preview_permissions_if_needed(context, session, tableUri):
table: models.DatasetTable = db.api.DatasetTable.get_dataset_table_by_uri(
session, tableUri
)
dataset = db.api.Dataset.get_dataset_by_uri(session, table.datasetUri)
if (
dataset.confidentiality
!= models.ConfidentialityClassification.Unclassified.value
):
ResourcePolicy.check_user_resource_permission(
session=session,
username=context.username,
groups=context.groups,
resource_uri=table.tableUri,
permission_name=permissions.PREVIEW_DATASET_TABLE,
)
return True
8 changes: 4 additions & 4 deletions backend/dataall/api/Objects/DatasetProfiling/schema.py
@@ -1,8 +1,8 @@
from ... import gql
from .resolvers import (
resolve_dataset,
get_profiling_run_status,
get_profiling_results,
resolve_profiling_run_status,
resolve_profiling_results,
)

DatasetProfilingRun = gql.ObjectType(
@@ -16,11 +16,11 @@
gql.Field(name='GlueTriggerName', type=gql.String),
gql.Field(name='GlueTableName', type=gql.String),
gql.Field(name='AwsAccountId', type=gql.String),
gql.Field(name='results', type=gql.String, resolver=get_profiling_results),
gql.Field(name='results', type=gql.String, resolver=resolve_profiling_results),
gql.Field(name='created', type=gql.String),
gql.Field(name='updated', type=gql.String),
gql.Field(name='owner', type=gql.String),
gql.Field('status', type=gql.String, resolver=get_profiling_run_status),
gql.Field('status', type=gql.String, resolver=resolve_profiling_run_status),
gql.Field(name='dataset', type=gql.Ref('Dataset'), resolver=resolve_dataset),
],
)
8 changes: 8 additions & 0 deletions backend/dataall/api/Objects/Environment/queries.py
@@ -175,6 +175,14 @@
test_scope='Environment',
)

getCDKExecPolicyPresignedUrl = gql.QueryField(
name='getCDKExecPolicyPresignedUrl',
args=[gql.Argument(name='organizationUri', type=gql.NonNullableType(gql.String))],
type=gql.String,
resolver=get_cdk_exec_policy_template,
test_scope='Environment',
)


getPivotRoleExternalId = gql.QueryField(
name='getPivotRoleExternalId',
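The `get_cdk_exec_policy_template` resolver backing the new query is not shown in this diff; queries of this shape typically stage a policy template in S3 and return a time-limited link. A hypothetical sketch — bucket name, key layout, and expiry are all assumptions:

```python
import boto3

def get_cdk_exec_policy_presigned_url(organization_uri: str) -> str:
    """Hypothetical: return a presigned GET URL for a CDK execution-policy
    template. Bucket name, key layout, and expiry are assumptions."""
    s3 = boto3.client('s3')
    return s3.generate_presigned_url(
        'get_object',
        Params={
            'Bucket': 'dataall-resources',                      # assumption
            'Key': f'cdk-exec-policy/{organization_uri}.yaml',  # assumption
        },
        ExpiresIn=900,  # 15 minutes
    )
```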