Skip to content

Commit

Permalink
refactoring cur
Browse files Browse the repository at this point in the history
  • Loading branch information
iakov-aws committed Dec 24, 2023
1 parent 6d8f314 commit 900a42a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 81 deletions.
24 changes: 12 additions & 12 deletions cid/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def cur(self) -> CUR:
if not _cur.configured:
raise CidCritical("Error: please ensure CUR is enabled, if yes allow it some time to propagate")

print(f'\tAthena table: {_cur.tableName}')
print(f"\tResource IDs: {'yes' if _cur.hasResourceIDs else 'no'}")
if not _cur.hasResourceIDs:
print(f'\tAthena table: {_cur.table_name}')
print(f"\tResource IDs: {'yes' if _cur.has_resource_ids else 'no'}")
if not _cur.has_resource_ids:
raise CidCritical("Error: CUR has to be created with Resource IDs")
print(f"\tSavingsPlans: {'yes' if _cur.hasSavingsPlans else 'no'}")
print(f"\tReserved Instances: {'yes' if _cur.hasReservations else 'no'}")
print(f"\tSavingsPlans: {'yes' if _cur.has_savings_plans else 'no'}")
print(f"\tReserved Instances: {'yes' if _cur.has_reservations else 'no'}")
print('\n')
self._clients.update({
'cur': _cur
Expand Down Expand Up @@ -1361,7 +1361,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non

# Check for required views
_views = dataset_definition.get('dependsOn', {}).get('views', [])
required_views = [(self.cur.tableName if cur_required and name =='${cur_table_name}' else name) for name in _views]
required_views = [(self.cur.table_name if cur_required and name =='${cur_table_name}' else name) for name in _views]

self.athena.discover_views(required_views)
found_views = utils.intersection(required_views, self.athena._metadata.keys())
Expand All @@ -1370,7 +1370,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
if recursive:
print(f"Detected views: {', '.join(found_views)}")
for view_name in found_views:
if cur_required and view_name == self.cur.tableName:
if cur_required and view_name == self.cur.table_name:
logger.debug(f'Dependancy view {view_name} is a CUR. Skip.')
continue
if view_name == 'account_map':
Expand All @@ -1389,7 +1389,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
columns_tpl = {
'athena_datasource_arn': athena_datasource.arn,
'athena_database_name': self.athena.DatabaseName,
'cur_table_name': self.cur.tableName if cur_required else None
'cur_table_name': self.cur.table_name if cur_required else None
}

logger.debug(f'dataset_id={dataset_id}')
Expand Down Expand Up @@ -1600,11 +1600,11 @@ def get_view_query(self, view_name: str) -> str:
# View path
view_definition = self.get_definition("view", name=view_name)
cur_required = view_definition.get('dependsOn', dict()).get('cur')
if cur_required and self.cur.hasSavingsPlans and self.cur.hasReservations and view_definition.get('spriFile'):
if cur_required and self.cur.has_savings_plans and self.cur.has_reservations and view_definition.get('spriFile'):
view_definition['File'] = view_definition.get('spriFile')
elif cur_required and self.cur.hasSavingsPlans and view_definition.get('spFile'):
elif cur_required and self.cur.has_savings_plans and view_definition.get('spFile'):
view_definition['File'] = view_definition.get('spFile')
elif cur_required and self.cur.hasReservations and view_definition.get('riFile'):
elif cur_required and self.cur.has_reservations and view_definition.get('riFile'):
view_definition['File'] = view_definition.get('riFile')
elif view_definition.get('File') or view_definition.get('Data') or view_definition.get('data'):
pass
Expand All @@ -1621,7 +1621,7 @@ def get_view_query(self, view_name: str) -> str:

# Prepare template parameters
columns_tpl = {
'cur_table_name': self.cur.tableName if cur_required else None,
'cur_table_name': self.cur.table_name if cur_required else None,
'athenaTableName': view_name,
'athena_database_name': self.athena.DatabaseName,
}
Expand Down
4 changes: 2 additions & 2 deletions cid/helpers/account_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def create(self, name) -> bool:
# Fill in TPLs
columns_tpl = {
'metadata_table_name': self._athena_table_name,
'cur_table_name': self.cur.tableName # only for trends
'cur_table_name': self.cur.table_name # only for trends
}
for key, val in self.mappings.get(name).get(self._athena_table_name).items():
logger.debug(f'Mapping field {key} to {val}')
Expand All @@ -185,7 +185,7 @@ def get_dummy_account_mapping_sql(self, name) -> list:
).decode('utf-8'))
columns_tpl = {
'athena_view_name': name,
'cur_table_name': self.cur.tableName
'cur_table_name': self.cur.table_name
}
compiled_query = template.safe_substitute(columns_tpl)
return compiled_query
Expand Down
123 changes: 56 additions & 67 deletions cid/helpers/cur.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
""" Manage AWS CUR
"""
import json
import logging

Expand All @@ -10,6 +12,8 @@


class CUR(CidBase):
""" Manage AWS CUR
"""
cur_minimal_required_columns = [
'bill_bill_type',
'bill_billing_entity',
Expand All @@ -25,7 +29,6 @@ class CUR(CidBase):
'line_item_operation',
'line_item_product_code',
'line_item_resource_id',
#'line_item_resource_id',
'line_item_unblended_cost',
'line_item_usage_account_id',
'line_item_usage_amount',
Expand All @@ -34,20 +37,6 @@ class CUR(CidBase):
'line_item_usage_type',
'pricing_term',
'pricing_unit',
# 'product_database_engine',
# 'product_deployment_option',
# 'product_from_location',
# 'product_group',
# 'product_instance_type',
# 'product_instance_type_family',
# 'product_operating_system',
# 'product_product_family',
# 'product_product_name',
# 'product_region',
# 'product_servicecode',
# 'product_storage',
# 'product_to_location',
# 'product_volume_api_name',
]
ri_required_columns = [
'reservation_reservation_a_r_n',
Expand All @@ -67,43 +56,40 @@ class CUR(CidBase):
'savings_plan_offering_type',
'savings_plan_payment_option'
]
_tableName = None
_table_name = None
_metadata = None
_clients = dict()
_hasResourceIDs = None
_hasSavingsPlans = None
_hasReservations = None
_clients = {}
_has_resource_ids = None
_has_savings_plans = None
_has_reservations = None
_configured = None
_status = str()
_status = {}


def __init__(self, session) -> None:
super().__init__(session)

@property
def athena(self) -> Athena:
""" Get Athena Client """
if not self._clients.get('athena'):
self._clients.update({
'athena': Athena(self.session)
})
self._clients['athena'] = Athena(self.session)
return self._clients.get('athena')

@athena.setter
def athena(self, client) -> Athena:
""" Set Athena Client """
if not self._clients.get('athena'):
self._clients.update({
'athena': client
})
self._clients['athena'] = client
return self._clients.get('athena')

@property
def glue(self) -> Glue:
""" Get Glue Client """
if not self._clients.get('glue'):
self._clients['glue'] = Glue(self.session)
return self._clients.get('glue')

Check notice

Code scanning / CodeGuru Reviewer Scanner

Risky use of dict get method Low

You are using the get method without a default argument to return the value of a key in a dictionary. We recommended that you use a default argument so that if the value for your key is not found, a default value is returned. If a default value is not provided and the key is not found, then None is returned.

Learn more

@glue.setter
def glue(self, client) -> Glue:
""" Set Glue client """
if not self._clients.get('glue'):
self._clients['glue'] = client
return self._clients.get('glue')
Expand All @@ -112,39 +98,40 @@ def glue(self, client) -> Glue:
def configured(self) -> bool:
""" Check if AWS Data Catalog and Athena database exist """
if self._configured is None:
if self.athena.CatalogName and self.athena.DatabaseName:
self._configured = True
else:
self._configured = False
self._configured = bool(self.athena.CatalogName and self.athena.DatabaseName)
return self._configured

@property
def tableName(self) -> str:
def table_name(self) -> str:
""" Get Athena table name """
if self.metadata is None:
raise CidCritical('Error: Cannot detect any CUR table. Hint: Check if AWS Lake Formation is activated on your account, verify that the LakeFormationEnabled parameter is set to yes on the deployment stack')
return self.metadata.get('Name')

@property
def hasResourceIDs(self) -> bool:
if self._configured and self._hasResourceIDs is None:
self._hasResourceIDs = 'line_item_resource_id' in self.fields
return self._hasResourceIDs
def has_resource_ids(self) -> bool:
""" Return True if CUR has resource ids """
if self._configured and self._has_resource_ids is None:
self._has_resource_ids = 'line_item_resource_id' in self.fields
return self._has_resource_ids

@property
def hasReservations(self) -> bool:
if self._configured and self._hasReservations is None:
logger.debug(f'{self.ri_required_columns}: {[c in self.fields for c in self.ri_required_columns]}')
self._hasReservations=all([c in self.fields for c in self.ri_required_columns])
logger.info(f'Reserved Instances: {self._hasReservations}')
return self._hasReservations
def has_reservations(self) -> bool:
""" Return True if CUR has reservation fields """
if self._configured and self._has_reservations is None:
logger.debug(f'{self.ri_required_columns}: {[col in self.fields for col in self.ri_required_columns]}')
self._has_reservations = all(col in self.fields for col in self.ri_required_columns)
logger.info(f'Reserved Instances: {self._has_reservations}')
return self._has_reservations

@property
def hasSavingsPlans(self) -> bool:
if self._configured and self._hasSavingsPlans is None:
logger.debug(f'{self.sp_required_columns}: {[c in self.fields for c in self.sp_required_columns]}')
self._hasSavingsPlans=all([c in self.fields for c in self.sp_required_columns])
logger.info(f'Savings Plans: {self._hasSavingsPlans}')
return self._hasSavingsPlans
def has_savings_plans(self) -> bool:
""" Return True if CUR has savings plan """
if self._configured and self._has_savings_plans is None:
logger.debug(f'{self.sp_required_columns}: {[col in self.fields for col in self.sp_required_columns]}')
self._has_savings_plans=all(col in self.fields for col in self.sp_required_columns)
logger.info(f'Savings Plans: {self._has_savings_plans}')
return self._has_savings_plans

def get_type_of_column(self, column: str):
""" Return an Athena type of a given non existent CUR column """
Expand All @@ -155,7 +142,7 @@ def get_type_of_column(self, column: str):
return 'DOUBLE'
if column.endswith('_date') and not column.endswith('_to_date'):
return 'TIMESTAMP'
SPECIAL = {
special_cases = {
"reservation_amortized_upfront_cost_for_usage": "DOUBLE",
"reservation_amortized_upfront_fee_for_billing_period": "DOUBLE",
"reservation_recurring_fee_for_usage": "DOUBLE",
Expand All @@ -173,7 +160,7 @@ def get_type_of_column(self, column: str):
"savings_plan_net_amortized_upfront_commitment_for_billing_period": "DOUBLE",
"savings_plan_recurring_commitment_for_billing_period": "DOUBLE",
}
return SPECIAL.get(column, 'STRING')
return special_cases.get(column, 'STRING')

def ensure_column(self, column: str, column_type: str=None):
""" Ensure column is in the cur. If it is not there - add column """
Expand All @@ -191,22 +178,22 @@ def ensure_column(self, column: str, column_type: str=None):

column_type = column_type or self.get_type_of_column(column)
try:
self.athena.query(f'ALTER TABLE {self._tableName} ADD COLUMNS ({column} {column_type})')
self.athena.query(f'ALTER TABLE {self._table_name} ADD COLUMNS ({column} {column_type})')
except (self.athena.client.exceptions.ClientError, CidCritical) as exc:
raise CidCritical(f'Column {column} is not found in CUR and we were unable to add it. Please check FAQ.') from exc
self._metadata = self.athena.get_table_metadata(self._tableName) # refresh table metadata
logger.critical(f"Column '{column}' was added to CUR ({self._tableName}). Please make sure crawler do not override that columns. Crawler='{crawler_name}'")
self._metadata = self.athena.get_table_metadata(self._table_name) # refresh table metadata
logger.critical(f"Column '{column}' was added to CUR ({self._table_name}). Please make sure crawler do not override that columns. Crawler='{crawler_name}'")

def table_is_cur(self, table: dict=None, name: str=None, return_reason: bool=False) -> bool:
""" return True if table metadata fits CUR definition. """
try:
table = table or self.athena.get_table_metadata(name)
except Exception as exc:
except Exception as exc: #pylint: disable=broad-exception-caught
logger.debug(exc)
return False if not return_reason else (False, f'cannot get table {name}. {exc}.')

table_name = table.get('Name')
columns = [cols.get('Name') for cols in table.get('Columns')]
columns = [col.get('Name') for col in table.get('Columns')]
missing_columns = [col for col in self.cur_minimal_required_columns if col not in columns]
if missing_columns:
return False if not return_reason else (False, f"Table {table_name} does not contain columns: {','.join(missing_columns)}. You can try ALTER TABLE {table_name} ADD COLUMNS (missing_column string).")
Expand All @@ -215,18 +202,19 @@ def table_is_cur(self, table: dict=None, name: str=None, return_reason: bool=Fal

@property
def metadata(self) -> dict:
"""get Athena metadata for the table of CUR """
if self._metadata:
return self._metadata

if get_parameters().get('cur-table-name'):
self._tableName = get_parameters().get('cur-table-name')
self._table_name = get_parameters().get('cur-table-name')
try:
self._metadata = self.athena.get_table_metadata(self._tableName)
self._metadata = self.athena.get_table_metadata(self._table_name)
except self.athena.client.exceptions.ResourceNotFoundException as exc:
raise CidCritical('Provided cur-table-name "{self._tableName}" is not found. Please make sure the table exists.') from exc
raise CidCritical('Provided cur-table-name "{self._table_name}" is not found. Please make sure the table exists.') from exc
res, message = self.table_is_cur(table=self._metadata, return_reason=True)
if not res:
raise CidCritical(f'Table {self._tableName} does not look like CUR. {message}')
raise CidCritical(f'Table {self._table_name} does not look like CUR. {message}')
else:
# Look all tables and filter ones with CUR fields
all_tables = self.athena.list_table_metadata()
Expand All @@ -241,20 +229,21 @@ def metadata(self) -> dict:
raise CidCritical(f'CUR table not found. (scanned {len(all_tables)} tables in Athena Database {self.athena.DatabaseName} in {self.athena.region}). But none has required fields: {self.cur_minimal_required_columns}.')
if len(cur_tables) == 1:
self._metadata = cur_tables[0]
self._tableName = self._metadata.get('Name')
logger.info('1 CUR table found: %s', self._tableName)
self._table_name = self._metadata.get('Name')
logger.info('1 CUR table found: %s', self._table_name)
elif len(cur_tables) > 1:
self._tableName = get_parameter(
self._table_name = get_parameter(
param_name='cur-table-name',
message="Multiple CUR tables found, please select one",
choices=sorted([v.get('Name') for v in cur_tables], reverse=True),
)
self._metadata = self.athena.get_table_metadata(self._tableName)
self._metadata = self.athena.get_table_metadata(self._table_name)
return self._metadata

@property
def fields(self) -> list:
return [v.get('Name') for v in self.metadata.get('Columns', list())]
"""get CUR fields """
return [col.get('Name') for col in self.metadata.get('Columns', [])]

@property
def tag_and_cost_category_fields(self) -> list:
Expand Down

0 comments on commit 900a42a

Please sign in to comment.