diff --git a/project/main.py b/project/main.py index 7c0d198..8d222ae 100644 --- a/project/main.py +++ b/project/main.py @@ -9,7 +9,6 @@ import ipaddress import re import time -from .auto_report import check_all_rules, get_real_ip #The new report module. Will check rules + report from socket import gethostbyaddr, herror from flask import request, redirect, url_for, render_template, jsonify, Response, send_from_directory, g, after_this_request, flash, Blueprint, current_app, make_response from flask_login import login_required, current_user @@ -17,6 +16,7 @@ from functools import wraps from dateutil.parser import parse from math import ceil #for pagination +from .auto_report import check_all_rules, get_real_ip #auto_report.py. from . import cache main = Blueprint('main', __name__) @@ -100,7 +100,7 @@ def decorated_function(*args, **kwargs): # Validate an IP address def validate_ip_query(_ip): - """ Validate queried IP. """ + """ Validate queried IP. (For a glob query, not for validating an ip addr itself)""" # IPv4/v6 chars + GLOB chars. Loose max length to account for glob queries, otherwise 39. ip_pattern = r'^[0-9A-Fa-f.:*\[\]\-^]{1,60}$' regex = re.compile(ip_pattern) @@ -148,6 +148,15 @@ def validate_header_key(_hk): else: return False +def validate_regex(pattern: str) -> bool: + ''' Validate a regex pattern. Return True if valid, otherwise false. ''' + try: + re.compile(pattern) + except re.error as e: + #raise ValidationError(f"Invalid regular expression: {e}") + return False + return True + ### SQLite function callbacks #Callback for SQLite CIDR user function @@ -160,8 +169,8 @@ def cidr_match(item, subnet): except ValueError as e: return False -#Callback for SQLite REGEXP function def regexp(expr, item): + ''' Callback for SQLite REGEXP user function. ''' reg = re.compile(expr) return reg.search(str(item)) is not None @@ -186,11 +195,12 @@ def compare_time_b(timestamp, num_days): ### Other utility functions -def get_pagination_data(stats, view_args: bool = False): +def paginate_data(stats, view_args: bool = False): ''' Pagination logic. Returns a dict of variables used for pagination of results. Args: : Results from query in the route; final data to be sent to the HTML template. - : Bool. Routes that define view args need to use request.view_args instead of request.args to build the pagination. Set to 1/True for these routes, otherwise 0/False.''' + : Bool. Routes that define view args (i.e., `def route(view_arg):`) need to use request.view_args instead of request.args to build the pagination. + Set to 1/True for these routes, otherwise 0/False.''' page = int(request.args.get('page', 1)) items_per_page = int(request.args.get('per_page', 100)) total_items = len(stats) @@ -205,7 +215,7 @@ def get_pagination_data(stats, view_args: bool = False): args_for_pagination = request.view_args else: args_for_pagination = request.args.to_dict() - # Remove the page# from the URL, so we can add a new one to the pagination links + # Remove the page# from current request args, so we can pass it separately to the template. if 'page' in args_for_pagination: del args_for_pagination['page'] @@ -216,6 +226,28 @@ def get_pagination_data(stats, view_args: bool = False): 'args_for_pagination': args_for_pagination } +def create_error_response(status_code: int = 400, error_msg: str = None, help_msg: str = None) -> tuple: + ''' Create a tuple containing an error response content. ''' + # Set error_name based on given status code + match status_code: + case 400: + err_name = 'Bad request' + case 401: + err_Name = 'Unauthorized' #Not authenticated + case 403: + err_name = 'Forbidden' #Authenticated but not authorized + # Set the response content + response_content = ( + [ + err_name, + {'error': error_msg}, + {'help': help_msg}, + ], + status_code, + ) + + return response_content + ### Flask app routes # Will use this in a couple places so I don't have to list them all out @@ -402,8 +434,7 @@ def stats(): if records_limit.isnumeric(): records_limit = int(records_limit) else: - flash('Bad request: `limit` must be a positive integer.', 'error') - return render_template('index.html') + return create_error_response(400, 'Invalid param', '`Limit` must be a positive integer.') with sqlite3.connect(requests_db) as conn: conn.row_factory = sqlite3.Row @@ -465,7 +496,7 @@ def stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats) + pagination_data = paginate_data(stats) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -520,8 +551,7 @@ def ipStats(ipAddr): """ Get records of an individual IP. The IP column on stats page will link to this route. """ # Validate the given IP first: if not validate_ip_query(ipAddr): - flash('Bad request: Contains invalid characters.', 'errorn') - return render_template('index.html') + return create_error_response(400, 'Invalid param', 'Query contains invalid characters.') with sqlite3.connect(requests_db) as conn: conn.row_factory = sqlite3.Row @@ -535,7 +565,7 @@ def ipStats(ipAddr): c.close() conn.close() - pagination_data = get_pagination_data(stats, True) + pagination_data = paginate_data(stats, True) flash('Note: Use * for wildcard, i.e. /stats/ip/1.2.3.*', 'info') return render_template('stats.html', @@ -558,7 +588,7 @@ def subnet_stats(): ipaddress.ip_network(test_subnet, strict=False) except ValueError as e: logging.error(f'Invalid CIDR: {str(e)}') - return ('Bad request: Must be a valid CIDR subnet.', 400) + return create_error_response(400, 'Invalid param', 'Param `net` must be a valid CIDR subnet.') """ Get rows that originated from a given CIDR subnet. """ with sqlite3.connect(requests_db) as conn: @@ -579,7 +609,7 @@ def subnet_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -596,8 +626,7 @@ def top_ten_ips(): """ Return top ten most common IPs. """ _num_of_ips = request.args.get('limit', '10') # num of IPs to include, i.e. Top X IPs. default 10 if not _num_of_ips.isnumeric(): - flash('Bad request: `limit` must be numeric', 'error') - return render_template('index.html') + return create_error_response(400, 'Invalid param', 'Param `limit` must be a positive integer.') with sqlite3.connect(requests_db) as conn: conn.row_factory = sqlite3.Row @@ -640,8 +669,7 @@ def methodStats(method): """ Get records by request method """ # Flash an error message if querying for a method not in db if method not in HTTP_METHODS: - flash('Bad request. Must query for a valid HTTP method, try /method/GET or /method/POST, etc.', 'error') - return render_template('index.html') + return create_error_response(400, 'Invalid param', 'Param `method` must be a valid HTTP method. Try /method/GET or /method/POST, etc') with sqlite3.connect(requests_db) as conn: conn.row_factory = sqlite3.Row @@ -655,7 +683,7 @@ def methodStats(method): c.close() conn.close() - pagination_data = get_pagination_data(stats, True) + pagination_data = paginate_data(stats, True) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -685,7 +713,7 @@ def uaStats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -715,7 +743,7 @@ def urlStats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) flash('Note: Use * for wildcard, i.e. url=*.example.com/*', 'info') return render_template('stats.html', @@ -747,7 +775,7 @@ def path_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) flash('Note: Use % for wildcard, i.e. path=/admin/%', 'info') return render_template('stats.html', @@ -778,7 +806,7 @@ def host_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -808,7 +836,7 @@ def queriesStats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -838,7 +866,7 @@ def bodyStats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) #flash('Note: LIKE query- %25 for wildcard', 'info') return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -853,8 +881,11 @@ def bodyStats(): @main.route('/stats/body_raw', methods = ['GET']) @login_required def bodyRawStats(): - """ Get records matching the request body. Regex query. (body_raw column, stored as blob) """ - body = unquote(request.args.get('body', '')) + ''' Get records matching the request body. Regex query. (body_raw column, stored as blob) ''' + body_pattern = unquote(request.args.get('body', '')) + + if not body_pattern or not validate_regex(body_pattern): + return create_error_response(400, 'Invalid param', 'Missing or invalid regex pattern.') with sqlite3.connect(requests_db) as conn: conn.row_factory = sqlite3.Row @@ -862,13 +893,15 @@ def bodyRawStats(): # Query for matching request body, order by most recent. conn.create_function("REGEXP", 2, regexp) sql_query = '''SELECT * FROM bots WHERE body_raw REGEXP (?) ORDER BY id DESC;''' - data_tuple = (body,) + data_tuple = (body_pattern,) + c.execute(sql_query, data_tuple) stats = c.fetchall() + c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -876,8 +909,8 @@ def bodyRawStats(): total_pages = pagination_data['total_pages'], args_for_pagination = pagination_data['args_for_pagination'], totalHits = len(stats), - statName = f"Request body like:", - subtitle = f'{body}', + statName = f"Request body regex:", + subtitle = f'{body_pattern}', ) @main.route('/stats/content-type', methods = ['GET']) @@ -906,7 +939,7 @@ def content_type_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -954,7 +987,7 @@ def date_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -973,8 +1006,7 @@ def reported_stats(): reported_status = request.args.get('reported', '1') # Validate if reported_status not in ('0', '1'): - flash('Bad request. Try reported=0 or reported=1', 'error') - return render_template('index.html') + return create_error_response(400, 'Invalid param', 'Param `reported` must equal 0 or 1.') with sqlite3.connect(requests_db) as conn: conn.row_factory = sqlite3.Row @@ -1007,7 +1039,7 @@ def reported_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) # Flash a message based on reported or unreported if reported_status == '1' and top_reported: @@ -1049,7 +1081,7 @@ def proxy_connection_header_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -1082,7 +1114,7 @@ def header_string_search(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -1112,7 +1144,7 @@ def hostname_stats(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) flash('Note: Includes hostnames that are subdomains of the query.', 'info') return render_template('stats.html', @@ -1130,7 +1162,7 @@ def headers_single_json(request_id): """ Pull headers from db by ID#, and display on headers_json.html. """ if not request_id or not request_id.isnumeric(): - return ('Bad request: ID must be numeric.', 400) + return create_error_response(400, 'Invalid param', 'ID must be numeric.') request_id = int(request_id) next_request_id = request_id + 1 @@ -1163,7 +1195,7 @@ def headers_single_json(request_id): except TypeError as e: # Catch TypeError when headers_json field is NULL (i.e. database isn't updated) # Only need this until I update the existing database. - flash('Bad request; ID doesn\'t exist.', 'error') + flash('Error: headers_json field is NULL for this row. Are you using an old database?', 'error') return render_template('index.html') #logging.debug(f'Request headers: {data}') @@ -1173,7 +1205,7 @@ def headers_single_json(request_id): request_id = request_id, next_request_id = next_request_id, prev_request_id = prev_request_id - ) + ) @main.route('/stats/headers/key_search', methods = ['GET']) @login_required @@ -1181,7 +1213,7 @@ def headers_key_search(): """ Find requests which include a given header. """ header_name = request.args.get('key', 'no input') if not validate_header_key(header_name): - return (['Bad Request', {'Error': 'Invalid characters'}], 400) + return create_error_response(400, 'Invalid param', 'Invalid header name; value may contain only letters and hyphen.') #query db with sqlite3.connect(requests_db) as conn: @@ -1209,7 +1241,7 @@ def headers_key_search(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -1252,7 +1284,7 @@ def stats_by_id_multiple(): request_id = request.args.get('request_id', '') if not validate_id_glob(request_id): - return ('bad request', 400) + return create_error_response(400, 'Invalid param', '`request_id` must be numeric/glob.') with sqlite3.connect(requests_db) as conn: conn.row_factory = sqlite3.Row @@ -1300,7 +1332,7 @@ def full_search(): c.close() conn.close() - pagination_data = get_pagination_data(stats, False) + pagination_data = paginate_data(stats, False) return render_template('stats.html', stats = pagination_data['stats_on_page'], @@ -1312,6 +1344,48 @@ def full_search(): subtitle = q, ) +@main.route('/search/all_regex', methods = ['GET']) +@login_required +def full_search_regex(): + """ Search entire db (all fields) for given regex pattern. """ + q = request.args.get('q', '') + + if not q or not validate_regex(q): + return create_error_response(400, 'Invalid param', 'Missing or invalid regex pattern for `q`.') + + with sqlite3.connect(requests_db) as conn: + conn.row_factory = sqlite3.Row + conn.create_function("REGEXP", 2, regexp) + c = conn.cursor() + + #Get column names + sql_query = "PRAGMA table_info(bots)" + c.execute(sql_query) + columns = [column[1] for column in c.fetchall()] + + sql_query = "SELECT * FROM bots WHERE " + conditions = [f"{column} REGEXP ?" for column in columns] #If using regexp over LIKE + sql_query += ' OR '.join(conditions) + sql_query += ' ORDER BY id DESC;' + data_list = [q for i in enumerate(columns)] + c.execute(sql_query, data_list) + stats = c.fetchall() + + c.close() + conn.close() + + pagination_data = paginate_data(stats, False) + + return render_template('stats.html', + stats = pagination_data['stats_on_page'], + page = pagination_data['page'], + total_pages = pagination_data['total_pages'], + args_for_pagination = pagination_data['args_for_pagination'], + totalHits = len(stats), + statName = f'All fields - regex', + subtitle = q, + ) + @main.route('/admin/delete_single', methods = ['POST']) @login_required @admin_required @@ -1321,8 +1395,7 @@ def delete_record_by_id(): # Get the id# from the request args, and validate that it's numeric request_id = request.args.get('request_id') if not request_id.isnumeric(): - flash('ID# must be numeric', 'errorn') - return redirect(request.referrer) + return create_error_response(400, 'Invalid param', 'ID# must be numeric.') # Delete the row from database with sqlite3.connect(requests_db) as conn: @@ -1350,8 +1423,7 @@ def delete_login_record(): _id = request.args.get('login_id', '') # Get the id# from the request args, and validate that it's numeric if not _id.isnumeric(): - flash('ID# must be numeric', 'errorn') - return redirect(request.referrer) + return create_error_response(400, 'Invalid param', 'ID# must be numeric.') with sqlite3.connect(requests_db) as conn: c = conn.cursor()