This repository was archived by the owner on Oct 6, 2021. It is now read-only.

KeyError fix on tram launch; ML Service class tidy-up #61

Closed
wants to merge 11 commits into from
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
# Python cache
__pycache__/
# Local database
database/tram.db

# IDEs
.idea/
5 changes: 4 additions & 1 deletion README.md
@@ -7,6 +7,9 @@ Threat Report ATT&CK<sup>®</sup> Mapping (TRAM) is a tool to aid analyst in map
- Google Chrome is our only supported/tested browser

## Installation

Please note: if your environment has multiple Python interpreters (e.g. `python` for Python 2.x and `python3` for Python 3.x), adjust the commands below accordingly. For example, `pip install ...` may become `python3 -m pip install ...`, and `python tram.py` may become `python3 tram.py`.

Start by cloning this repository.
```
git clone https://github.com/mitre-attack/tram.git
@@ -53,4 +56,4 @@ limitations under the License.

This project makes use of ATT&CK®

ATT&CK® Terms of Use - https://attack.mitre.org/resources/terms-of-use/
ATT&CK® Terms of Use - https://attack.mitre.org/resources/terms-of-use/
5 changes: 4 additions & 1 deletion conf/config.yml
@@ -1,8 +1,11 @@
---

# The host and port where TRAM will run
host: 0.0.0.0
port: 9999
# Either 'taxii-server' to pull the latest data or 'local-json' to use local data (from json_file)
taxii-local: taxii-server
# Whether the database should be rebuilt on launch of TRAM
build: True
# The JSON file containing ATT&CK data; ensure the file is in the /models directory
json_file: enterprise-attack.json
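
For reference, a minimal sketch of how a config file with these keys could be loaded in Python (assuming PyYAML is installed); this is illustrative only and is not the loader tram.py actually uses:

```
# Illustrative sketch only (assumes PyYAML); not taken from the TRAM codebase.
import yaml

with open('conf/config.yml') as fh:
    config = yaml.safe_load(fh)

host = config.get('host', '0.0.0.0')
port = config.get('port', 9999)
# 'taxii-server' pulls live ATT&CK data; 'local-json' reads the file named by json_file
taxii_local = config.get('taxii-local', 'taxii-server')
rebuild_db = config.get('build', True)
json_file = 'models/{}'.format(config.get('json_file', 'enterprise-attack.json'))
```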

7 changes: 1 addition & 6 deletions handlers/web_api.py
@@ -1,19 +1,16 @@
from aiohttp_jinja2 import template, web
import nltk
import json


class WebAPI:

def __init__(self, services):

self.dao = services.get('dao')
self.data_svc = services['data_svc']
self.web_svc = services['web_svc']
self.ml_svc = services['ml_svc']
self.reg_svc = services['reg_svc']
self.rest_svc = services['rest_svc']
self.tokenizer_sen = nltk.data.load('tokenizers/punkt/english.pickle')

@template('about.html')
async def about(self, request):
@@ -189,8 +186,6 @@ async def rebuild_ml(self, request):
for i in true_negs:
true_negatives.append(i['sentence'])
list_of_legacy, list_of_techs = await self.data_svc.ml_reg_split(techniques)
self.ml_svc.build_pickle_file(self, list_of_techs, techniques, true_negatives, force=True)
await self.ml_svc.build_pickle_file(list_of_techs, techniques, true_negatives, force=True)

return {'text': 'ML Rebuilt!'}
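
The substantive change in this handler is dropping the stray `self` argument from the `build_pickle_file` call (and passing `true_negatives` through). Because `build_pickle_file` is a bound method of `ml_svc`, the extra `self` (here the `WebAPI` instance) shifts every positional argument by one and raises a `TypeError`. A small toy sketch of that failure mode, with stand-in class and argument names rather than TRAM's real signatures:

```
# Toy illustration of the bug fixed above; names and signatures are stand-ins.
class MLService:
    def build_pickle_file(self, list_of_techs, techniques, true_negatives, force=False):
        return len(list_of_techs), len(techniques), len(true_negatives), force


class WebAPI:
    def __init__(self, ml_svc):
        self.ml_svc = ml_svc

    def rebuild_ml_old(self):
        # Python already binds ml_svc as the method's self, so passing self here
        # supplies one positional argument too many and breaks the call.
        return self.ml_svc.build_pickle_file(self, ['T1003'], {}, [], force=True)

    def rebuild_ml_new(self):
        # Corrected call: arguments line up with the method's parameters.
        return self.ml_svc.build_pickle_file(['T1003'], {}, [], force=True)


api = WebAPI(MLService())
print(api.rebuild_ml_new())   # (1, 0, 0, True)
# api.rebuild_ml_old()        # TypeError: got multiple values for argument 'force'
```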


10 changes: 5 additions & 5 deletions models/attack_dict.json
@@ -330,7 +330,7 @@
"has been known to dump credentials."
],
"id": "T1003",
"name": "Credential Dumping",
"name": "OS Credential Dumping",
"similar_words": [
"Credential Dumping"
]
@@ -1248,7 +1248,7 @@
"used SMTP as a communication channel in various implants, initially using self-registered Google Mail accounts and later compromised email servers of its victims. Later implants such as use a blend of HTTP and other legitimate channels, depending on module configuration."
],
"id": "T1071",
"name": "Standard Application Layer Protocol",
"name": "Application Layer Protocol",
"similar_words": [
"Standard Application Layer Protocol"
]
@@ -2820,7 +2820,7 @@
"Malware used by can run commands on the command-line interface."
],
"id": "T1059",
"name": "Command-Line Interface",
"name": "Command and Scripting Interpreter",
"similar_words": [
"Command-Line Interface"
]
@@ -3423,7 +3423,7 @@
"transferred compressed and encrypted RAR files containing exfiltration through the established backdoor command and control channel during operations."
],
"id": "T1041",
"name": "Exfiltration Over Command and Control Channel",
"name": "Exfiltration Over C2 Channel",
"similar_words": [
"Exfiltration Over Command and Control Channel"
]
@@ -4961,7 +4961,7 @@
"has downloaded additional files, including by using a first-stage downloader to contact the C2 server to obtain the second-stage implant."
],
"id": "T1105",
"name": "Remote File Copy",
"name": "Ingress Tool Transfer",
"similar_words": [
"Remote File Copy"
]
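The renames above track current ATT&CK technique names while the old names are kept under `similar_words`. The `ml_svc.py` changes later in this diff use that to resolve techniques whose display name has changed. A toy sketch of the lookup order (name, then TID, then similar words), using in-memory dictionaries rather than TRAM's real database layer:

```
# Toy data mirroring the fallback order added in ml_techniques_found(); illustrative only.
attack_uids = [
    {'uid': 'attack-pattern--0001', 'tid': 'T1003', 'name': 'OS Credential Dumping'},
]
similar_words = [
    {'attack_uid': 'attack-pattern--0001', 'similar_word': 'Credential Dumping'},
]


def find_attack(technique):
    # 1. Try the current ATT&CK name
    hits = [a for a in attack_uids if a['name'] == technique]
    # 2. Fall back to the technique ID (e.g. 'T1003')
    if not hits:
        hits = [a for a in attack_uids if a['tid'] == technique]
    # 3. Finally, try legacy names kept in similar_words
    if not hits:
        links = [s for s in similar_words if s['similar_word'] == technique]
        if links:
            hits = [a for a in attack_uids if a['uid'] == links[0]['attack_uid']]
    return hits[0] if hits else None


print(find_attack('Credential Dumping'))   # still resolves despite the rename
```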
27 changes: 20 additions & 7 deletions service/data_svc.py
@@ -1,9 +1,17 @@
import re
import json
import logging
from taxii2client import Collection
import uuid

from stix2 import TAXIICollectionSource, Filter

try:
# Appropriate import for taxii2-client v2.x; this may fail on older taxii2-client versions
from taxii2client.v20 import Collection
except ModuleNotFoundError:
# Fall back to the original import for older taxii2-client versions
from taxii2client import Collection


def defang_text(text):
"""
@@ -52,7 +60,8 @@ async def insert_attack_stix_data(self):
references[i["id"]] = {"name": i["name"], "id": i["external_references"][0]["external_id"],
"example_uses": [],
"description": i['description'].replace('<code>', '').replace('</code>', '').replace(
'\n', '').encode('ascii', 'ignore').decode('ascii'),
'\n', '').encode('ascii', 'ignore').decode('ascii') if hasattr(i, "description")
else 'No description provided',
"similar_words": [i["name"]]}

for i in attack["relationships"]:
@@ -94,11 +103,13 @@ async def insert_attack_stix_data(self):
await self.dao.insert('attack_uids', dict(uid=k, description=defang_text(v['description']), tid=v['id'],
name=v['name']))
if 'regex_patterns' in v:
[await self.dao.insert('regex_patterns', dict(uid=k, regex_pattern=defang_text(x))) for x in
v['regex_patterns']]
[await self.dao.insert('regex_patterns', dict(uid=str(uuid.uuid4()), attack_uid=k,
regex_pattern=defang_text(x)))
for x in v['regex_patterns']]
if 'similar_words' in v:
[await self.dao.insert('similar_words', dict(uid=k, similar_word=defang_text(x))) for x in
v['similar_words']]
[await self.dao.insert('similar_words', dict(uid=str(uuid.uuid4()), attack_uid=k,
similar_word=defang_text(x)))
for x in v['similar_words']]
if 'false_negatives' in v:
[await self.dao.insert('false_negatives', dict(uid=k, false_negative=defang_text(x))) for x in
v['false_negatives']]
@@ -139,7 +150,9 @@ async def insert_attack_json_data(self, buildfile):
loaded_items[item['id']] = {'id': tid, 'name': item['name'],
'examples': [],
'similar_words': [],
'description': item['description'],
'description': item['description']
if 'description' in item
else 'No description provided',
'example_uses': []}
else:
logging.critical('[!] Error: multiple MITRE sources: {} {}'.format(item['id'], items))
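To show what the version-tolerant import above supports, here is a standalone sketch of pulling ATT&CK data over TAXII with `taxii2client` and `stix2`. The server URL and collection ID below are MITRE's publicly documented TAXII 2.0 values, included as assumptions for illustration; they are not taken from this diff:

```
# Standalone sketch (not TRAM's code): fetch ATT&CK techniques over TAXII.
from stix2 import TAXIICollectionSource, Filter

try:
    # taxii2-client v2.x namespaces the TAXII 2.0 client under .v20
    from taxii2client.v20 import Collection
except ModuleNotFoundError:
    # Older taxii2-client releases expose Collection at the top level
    from taxii2client import Collection

# Assumed values: MITRE's TAXII 2.0 server and the Enterprise ATT&CK collection ID
COLLECTION_URL = ('https://cti-taxii.mitre.org/stix/collections/'
                  '95ecc380-afe9-11e4-9b6c-951ba248d152/')

collection = Collection(COLLECTION_URL)
source = TAXIICollectionSource(collection)
techniques = source.query([Filter('type', '=', 'attack-pattern')])
print('{} techniques retrieved'.format(len(techniques)))
```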
146 changes: 104 additions & 42 deletions service/ml_svc.py
@@ -1,11 +1,14 @@
import asyncio
import logging
import nltk
import os
import pandas as pd
import pickle
import random

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import os, pickle, random
import nltk
import logging
import asyncio
from sklearn.model_selection import train_test_split


class MLService:
@@ -16,51 +19,58 @@ def __init__(self, web_svc, dao):
self.dao = dao

async def build_models(self, tech_name, techniques, true_negatives):
"""Function to build Logistic Regression Classification models based off of the examples provided"""
"""Function to build Logistic Regression Classification models based off of the examples provided."""
lst1, lst2, false_list, sampling = [], [], [], []
getuid = ""
getuid = ''
len_truelabels = 0

for k, v in techniques.items():
if v['name'] == tech_name:
for i in v['example_uses']:
lst1.append(self.web_svc.tokenize(self, i))
lst1.append(await self.web_svc.tokenize(i))
lst2.append(True)
len_truelabels += 1
getuid = k
# collect the false_positive samples here too, which are the incorrectly labeled texts from reviewed reports, we will include these in the Negative Class.
for fp in v['false_positives']:
sampling.append(fp)
# Collect the false_positive samples here too, which are the incorrectly labeled texts from
# reviewed reports, we will include these in the Negative Class.
if 'false_positives' in v.keys():
for fp in v['false_positives']:
sampling.append(fp)
else:
for i in v['example_uses']:
false_list.append(self.web_svc.tokenize(self, i))
false_list.append(await self.web_svc.tokenize(i))

# at least 90% of total labels for both classes, use this for determining how many labels to use for classifier's negative class
# At least 90% of total labels for both classes
# use this for determining how many labels to use for classifier's negative class
kval = int((len_truelabels * 10))

# make first half random set of true negatives that have no relation/label to ANY technique
sampling.extend(random.choices(true_negatives, k=kval))
# Make first half random set of true negatives that have no relation/label to ANY technique
# Need if-checks because an empty list will cause an error with random.choices()
if true_negatives:
sampling.extend(random.choices(true_negatives, k=kval))

# do second random half set, these are true/positive labels for OTHER techniques, use list obtained from above
sampling.extend(random.choices(false_list, k=kval))
# Do second random half set, these are true/positive labels for OTHER techniques, use list obtained from above
if false_list:
sampling.extend(random.choices(false_list, k=kval))

# Finally, create the Negative Class for this technique's classification model, include False as the labels for this training data
# Finally, create the Negative Class for this technique's classification model
# and include False as the labels for this training data
for false_label in sampling:
lst1.append(self.web_svc.tokenize(self, false_label))
lst1.append(await self.web_svc.tokenize(false_label))
lst2.append(False)

# convert into a dataframe
# Convert into a dataframe
df = pd.DataFrame({'text': lst1, 'category': lst2})

# build model based on that technique
# Build model based on that technique
cv = CountVectorizer(max_features=2000)
X = cv.fit_transform(df['text']).toarray()
y = df['category']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
logreg = LogisticRegression(max_iter=2500, solver='lbfgs')
logreg.fit(X_train, y_train)

print("{} - {}".format(tech_name, logreg.score(X_test, y_test)))
logging.info('Technique test score: {} - {}'.format(tech_name, logreg.score(X_test, y_test)))
return (cv, logreg)

async def analyze_document(self, cv, logreg, sentences):
Expand All @@ -73,27 +83,66 @@ async def analyze_document(self, cv, logreg, sentences):
df2['category'] = y_pred.tolist()
return df2

async def build_pickle_file(self, list_of_techs, techniques, force=False):
if not os.path.isfile('models/model_dict.p') or force:
model_dict = {}
total = len(list_of_techs)
count = 1
print(
"Building Classification Models.. This could take anywhere from ~30-60+ minutes. Please do not close terminal.")
for i in list_of_techs:
print('[#] Building.... {}/{}'.format(count, total))
count += 1
model_dict[i] = self.build_models(self, i, techniques)
print('[#] Saving models to pickled file: model_dict.p')
pickle.dump(model_dict, open('models/model_dict.p', 'wb'))
else:
print('[#] Loading models from pickled file: model_dict.p')
model_dict = pickle.load(open('models/model_dict.p', 'rb'))
async def build_pickle_file(self, list_of_techs, techniques, true_negatives, force=False):
"""Returns the classification models for the data provided."""
# Specify the location of the models file
dict_loc = 'models/model_dict.p'
# If we are not forcing the models to be rebuilt, obtain the previously used models
if not force:
model_dict = self.get_pre_saved_models(dict_loc)
# If the models were obtained successfully, return them
if model_dict:
return model_dict
# Else proceed with building the models
model_dict = {}
total = len(list_of_techs)
count = 1
logging.info('Building Classification Models.. This could take anywhere from ~30-60+ minutes. '
'Please do not close terminal.')
for i in list_of_techs:
logging.info('[#] Building.... {}/{}'.format(count, total))
count += 1
model_dict[i] = await self.build_models(i, techniques, true_negatives)
logging.info('[#] Saving models to pickled file: ' + os.path.basename(dict_loc))
# Save the newly-built models
with open(dict_loc, 'wb') as saved_dict:
pickle.dump(model_dict, saved_dict)
return model_dict

@staticmethod
def get_pre_saved_models(dictionary_location):
"""Function to retrieve previously-saved models via pickle."""
# Check the given location is a valid filepath
if os.path.isfile(dictionary_location):
logging.info('[#] Loading models from pickled file: ' + os.path.basename(dictionary_location))
# Open the model file
with open(dictionary_location, 'rb') as pre_saved_dict:
# Attempt to load the model file's contents
try:
# A UserWarning can appear stating the risks of using a different pickle version from sklearn
return pickle.load(pre_saved_dict)
# In a previous run the missing module was sklearn.linear_model.logistic; this might be related to the UserWarning above
except ModuleNotFoundError as mnfe:
logging.warning('Could not load existing models: ' + str(mnfe))
# An empty file has been passed to pickle.load()
except EOFError as eofe:
logging.warning('Existing models file may be empty: ' + str(eofe))
# The provided location was not a valid filepath
else:
logging.warning('Invalid location given for existing models file.')
# return None if pickle.load() was not successful or a valid filepath was not provided
return None

async def analyze_html(self, list_of_techs, model_dict, list_of_sentences):
for i in list_of_techs:
cv, logreg = model_dict[i]
# If an older model_dict has been loaded, its keys may be out of sync with list_of_techs
try:
cv, logreg = model_dict[i]
except KeyError: # Report to user if a model can't be retrieved
logging.warning('Technique \'' + i + '\' has no model to analyse with. You can try deleting/moving '
'models/model_dict.p to trigger re-build of models.')
# Skip this technique and move onto the next one
continue
final_df = await self.analyze_document(cv, logreg, list_of_sentences)
count = 0
for vals in final_df['category']:
@@ -109,14 +158,28 @@ async def ml_techniques_found(self, report_id, sentence):
found_status="true"))
for technique in sentence['ml_techniques_found']:
attack_uid = await self.dao.get('attack_uids', dict(name=technique))
# If the attack cannot be found via the 'name' column, try the 'tid' column
if not attack_uid:
attack_uid = await self.dao.get('attack_uids', dict(tid=technique))
# If the attack has still not been retrieved, try searching the similar_words table
if not attack_uid:
similar_word = await self.dao.get('similar_words', dict(similar_word=technique))
# If a similar word was found, use its attack_uid to lookup the attack_uids table
if similar_word and similar_word[0] and similar_word[0]['attack_uid']:
attack_uid = await self.dao.get('attack_uids', dict(uid=similar_word[0]['attack_uid']))
# If the attack has still not been retrieved, report to user that this cannot be saved against the sentence
if not attack_uid:
logging.warning(' '.join(('Sentence ID:', str(sentence_id), 'ML Technique:', technique, '- Technique '
+ 'could not be retrieved from the database; cannot save this technique\'s '
'association with the sentence.')))
# Skip this technique and continue with the next one
continue
attack_technique = attack_uid[0]['uid']
attack_technique_name = '{} (m)'.format(attack_uid[0]['name'])
attack_tid = attack_uid[0]['tid']
await self.dao.insert('report_sentence_hits',
dict(uid=sentence_id, attack_uid=attack_technique,
attack_technique_name=attack_technique_name, report_uid=report_id, attack_tid = attack_tid))
attack_technique_name=attack_technique_name, report_uid=report_id, attack_tid=attack_tid))

async def get_true_negs(self):
true_negs = await self.dao.get('true_negatives')
@@ -147,5 +210,4 @@ async def check_nltk_packs(self):
except LookupError:
logging.warning('Could not find the stopwords pack, downloading now')
nltk.download('stopwords')


self.web_svc.initialise_tokenizer()
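
For readers skimming the diff, a small standalone sketch (not part of the PR) of the two failure modes the new guards address: sampling from an empty list with `random.choices()`, and looking up a technique that is missing from a stale `model_dict.p`. The data below is hypothetical:

```
# Illustrative only: the failure modes guarded against in build_models() and analyze_html().
import random

true_negatives = []   # e.g. an empty true_negatives table on a fresh database

# 1. random.choices() raises IndexError on an empty population,
#    hence the "if true_negatives:" / "if false_list:" checks.
try:
    random.choices(true_negatives, k=10)
except IndexError:
    print('Cannot sample from an empty list')

# 2. A model_dict.p pickled on an earlier run may lack newer techniques,
#    hence the try/except KeyError around model_dict[i].
model_dict = {'Technique A': ('cv', 'logreg')}   # hypothetical cached models
for tech in ('Technique A', 'Technique B'):      # 'Technique B' added after the pickle was built
    try:
        cv, logreg = model_dict[tech]
    except KeyError:
        print('No model for {} - skipping'.format(tech))
        continue
```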