Skip to content

Commit

Permalink
Use jmespath instead of loops, simplify UsagesClass.
Browse files Browse the repository at this point in the history
  • Loading branch information
cerrussell committed Jan 15, 2024
1 parent a30b8e8 commit 5e473a4
Show file tree
Hide file tree
Showing 7 changed files with 61,644 additions and 207 deletions.
32 changes: 20 additions & 12 deletions atom_tools/cli/commands/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,26 @@ class ConvertCommand(Command):
flag=False,
default=None
),
option(
'reachables-slice',
'r',
'Reachables slice file',
flag=False,
default=None,
),
# option(
# 'reachables-slice',
# 'r',
# 'Reachables slice file',
# flag=False,
# default=None,
# ),
option(
'language',
'l',
'',
flag=False,
default='Java',
default='java',
),
option(
'output-file',
'o',
'Output file',
flag=False,
default='openapi_from_slice.json',
)
]
help = """The convert command converts an atom slice to a different format.
Expand All @@ -121,15 +128,16 @@ def handle(self):
Executes the convert command and performs the conversion.
"""
match self.option('format'):
case 'openapi3.1.0':
case 'openapi3.1.0' | 'openapi3.0.1':
converter_instance = OpenAPI(
self.option('format'),
self.option('language'),
self.option('usages-slice'),
self.option('reachables-slice'),
# self.option('reachables-slice'),
)
if result := converter_instance.convert_slices():
with open('output.json', 'w', encoding='utf-8') as f:
if result := converter_instance.endpoints_to_openapi():
with open(self.option('output-file'), 'w',
encoding='utf-8') as f:
json.dump(result, f, indent=4)
case _:
raise ValueError(
Expand Down
53 changes: 11 additions & 42 deletions atom_tools/lib/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,58 +75,27 @@ def create_paths_item(self):
"""
raise NotImplementedError

def convert_slices(self):
def endpoints_to_openapi(self):
"""
Converts slices.
This function converts available slices
Returns:
The converted slices, or None if no slices are available.
"""
if self.usages and self.reachables:
return self.combine_converted(self.usages, self.reachables)
if self.usages:
return self.convert_usages()
return self.convert_reachables() if self.reachables else None

def combine_converted(self, usages, reachables):
"""
Combines converted usages and reachables into a single document.
"""
raise NotImplementedError

def convert_usages(self):
Combines usages and reachables endpoints into a single document.
"""
Converts usages to OpenAPI.
"""
endpoints = sorted(set(self.usages.generate_endpoints()))
paths_obj = {endpoint: {} for endpoint in endpoints}
endpoints = self.convert_usages()
paths_obj = {r: {} for r in endpoints} or {}
return {
'openapi': self.openapi_version,
'info': {'title': 'Atom Usages', 'version': '1.0.0'},
'paths': paths_obj
}

def convert_usages(self):
"""
Converts usages to OpenAPI.
"""
return sorted(
set(self.usages.generate_endpoints())) if self.usages else []

def convert_reachables(self):
"""
Converts reachables to OpenAPI.
"""
raise NotImplementedError
# endpoints = self.bom.generate_endpoints()
# with open(self.reachables_file, 'r', encoding='utf-8') as f:
# data = json.load(f).get('reachables')
#
# with open('schemas/template.json', 'r', encoding='utf-8') as f:
# template = json.load(f)
#
# paths = []
# reached = {}
#
# for endpoint in endpoints:
# new_path = {
# endpoint: {
# 'x-atom-reachables': {}
# }
# }
138 changes: 44 additions & 94 deletions atom_tools/lib/slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,29 @@
import logging
import os.path
import re
import jmespath


class UsageSlice:
"""
Represents a usage slice.
This class is responsible for importing and storing usage slices.
Args:
filename (str): The path to the JSON file.
Attributes:
object_slices (list): A list of object slices.
user_defined_types (list): A list of user-defined types.
usages (dict): The dictionary loaded from the usages JSON file.
Methods:
import_slice: Imports a slice from a JSON file.
generate_endpoints: Generates a list of endpoints from a slice.
extract_endpoints: Extracts a list of endpoints from a usage.
"""
ENDPOINTS_REGEX = re.compile(r'[\'"](\S*?)[\'"]', re.IGNORECASE)

def __init__(self, filename, language):
[self.object_slices, self.user_defined_types] = self.import_slice(
filename
)
self.usages = self.import_slice(filename)
self.language = language
self.endpoints_regex = re.compile(r'[\'"](\S*?)[\'"]', re.IGNORECASE)

@staticmethod
def import_slice(filename):
Expand All @@ -42,9 +39,7 @@ def import_slice(filename):
filename (str): The path to the JSON file.
Returns:
tuple: A tuple containing the object slices and user-defined types.
The object slices is a list of object slices.
The user-defined types is a list of user-defined types.
dict: The contents of the JSON file.
Raises:
JSONDecodeError: If the JSON file cannot be decoded.
Expand All @@ -55,46 +50,53 @@ def import_slice(filename):
If the JSON file is not a valid usage slice, a warning is logged.
"""
if not filename:
return [], []
return {}
try:
with open(filename, 'r', encoding='utf-8') as f:
content = json.load(f)
if content.get('objectSlices'):
return content.get('objectSlices', []), content.get(
'userDefinedTypes', []
)
# except [json.decoder.JSONDecodeError, UnicodeDecodeError]:
# logging.warning(
# f'Failed to load usages slice: {filename}\nPlease check '
# f'that you specified a valid json file.'
# )
return content
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
logging.warning(
f'Failed to load usages slice: {filename}\nPlease check '
f'that you specified a valid json file.')
except FileNotFoundError:
logging.warning(
f'Failed to locate the usages slice file in the location '
f'specified: {filename}'
)
f'specified: {filename}')

logging.warning(
f'This does not appear to be a valid usage slice: '
f'{filename}\nPlease check that you specified the '
f'correct usages slice file.'
)
return [], []
logging.warning(f'This does not appear to be a valid usage slice: '
f'{filename}\nPlease check that you specified the '
f'correct usages slice file.')
return {}

def generate_endpoints(self):
"""
Generates and returns a dictionary of endpoints based on the object
slices and user-defined types (UDTs).
Returns:
list: A list of endpoints.
list: A list of unique endpoints.
"""
# Surely there is a way to combine these...
target_obj_pattern = jmespath.compile(
'objectSlices[].usages[].targetObj.resolvedMethod')
defined_by_pattern = jmespath.compile(
'objectSlices[].usages[].definedBy.resolvedMethod')
invoked_calls_pattern = jmespath.compile(
'objectSlices[].usages[].invokedCalls[].resolvedMethod')
udt_jmespath_query = jmespath.compile(
'userDefinedTypes[].fields[].name')
methods = target_obj_pattern.search(self.usages) or []
methods.extend(defined_by_pattern.search(self.usages) or [])
methods.extend(invoked_calls_pattern.search(self.usages) or [])
methods.extend(udt_jmespath_query.search(self.usages) or [])
methods = list(set(methods))

endpoints = []
for object_slice in self.object_slices:
endpoints.extend(self.extract_endpoints_from_usages(
object_slice.get('usages', []))
)
endpoints.extend(self.extract_endpoints_from_udts())
if methods:
for method in methods:
endpoints.extend(self.extract_endpoints(method))

return list(set(endpoints))

def extract_endpoints(self, code):
Expand All @@ -103,15 +105,17 @@ def extract_endpoints(self, code):
Args:
code (str): The code from which to extract endpoints.
pkg (str): The package name to prepend to the extracted endpoints.
Returns:
list: A list of extracted endpoints.
Raises:
None.
"""
endpoints = []
if not code:
return endpoints
matches = re.findall(UsageSlice.ENDPOINTS_REGEX, code) or []
matches = re.findall(self.endpoints_regex, code) or []
match self.language:
case 'java' | 'jar':
if code.startswith('@') and (
Expand Down Expand Up @@ -149,52 +153,9 @@ def extract_endpoints(self, code):
])
return endpoints

def extract_endpoints_from_usages(self, usages):
"""
Extracts endpoints from the given list of usages.
Args:
usages (List[Dict]): A list of dicts representing the usages.
pkg (str): The package name.
Returns:
List: A list of extracted endpoints.
"""
endpoints = []
for usage in usages:
target_obj = usage.get('targetObj', {})
defined_by = usage.get('definedBy', {})
invoked_calls = usage.get('invokedCalls', [])
if resolved_method := target_obj.get('resolvedMethod'):
endpoints.extend(self.extract_endpoints(resolved_method))
elif resolved_method := defined_by.get('resolvedMethod'):
endpoints.extend(self.extract_endpoints(resolved_method))
if invoked_calls:
for call in invoked_calls:
if resolved_method := call.get('resolvedMethod'):
endpoints.extend(
self.extract_endpoints(resolved_method))
return endpoints

def extract_endpoints_from_udts(self):
"""
Extracts endpoints from user-defined types.
Returns:
list: A list of endpoints extracted from the user-defined types.
"""
endpoints = []
for udt in self.user_defined_types:
if fields := udt.get('fields'):
for f in fields:
endpoints.extend(self.extract_endpoints(f.get('name')))
return endpoints


class ReachablesSlice:
"""
Represents a slice of reachables.
This class is responsible for importing and storing reachables slices.
Args:
Expand All @@ -220,9 +181,7 @@ def import_slice(filename):
filename (str): The path to the JSON file.
Returns:
tuple: A tuple containing the object slices and user-defined types.
The object slices is a list of object slices.
The user-defined types is a list of user-defined types.
list: A list of the reachables slices.
Raises:
JSONDecodeError: If the JSON file cannot be decoded.
Expand All @@ -238,7 +197,7 @@ def import_slice(filename):
with open(filename, 'r', encoding='utf-8') as f:
content = json.load(f)
return content.get('reachables', [])
except [json.decoder.JSONDecodeError, UnicodeDecodeError]:
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
logging.warning(
f'Failed to load usages slice: {filename}\nPlease check '
f'that you specified a valid json file.'
Expand All @@ -254,12 +213,3 @@ def import_slice(filename):
f'Please check that you specified the correct usages slice file.'
)
return []

@property
def slices(self):
"""
Get reachables
Returns:
A list of reachables.
"""
return self.reachables
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description = "Collection of tools for use with appthreat/atom."
authors = [
{ name = "Team AppThreat", email = "[email protected]" },
]
dependencies = ["cleo>=1.0.0"]
dependencies = ["cleo>=1.0.0", "jmespath"]
maintainers = [
{ name = "Caroline Russell", email = "[email protected]" },
]
Expand Down Expand Up @@ -51,7 +51,6 @@ build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = [
"atom_tools",
"atom_tools.atom_data",
"atom_tools.cli",
"atom_tools.cli.commands",
"atom_tools.lib"
Expand Down
Loading

0 comments on commit 5e473a4

Please sign in to comment.