diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89228a63f0f..e11535302b1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -92,6 +92,7 @@ repos: - types-python-dateutil - types-requests - types-croniter + - boto3-stubs[s3] exclude: | (?x)( ^boefjes/tools | diff --git a/boefjes/boefjes/__main__.py b/boefjes/boefjes/__main__.py index aef329cf194..a6ab713b15c 100644 --- a/boefjes/boefjes/__main__.py +++ b/boefjes/boefjes/__main__.py @@ -37,7 +37,7 @@ @click.command() @click.argument("worker_type", type=click.Choice([q.value for q in WorkerManager.Queue])) @click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]), help="Log level", default="INFO") -def cli(worker_type: str, log_level: str): +def cli(worker_type: str, log_level: str) -> None: logger.setLevel(log_level) logger.info("Starting runtime for %s", worker_type) diff --git a/boefjes/boefjes/api.py b/boefjes/boefjes/api.py index 28bd5d24c92..409508cba3a 100644 --- a/boefjes/boefjes/api.py +++ b/boefjes/boefjes/api.py @@ -33,7 +33,7 @@ def __init__(self, config: Config): self.server = Server(config=config) self.config = config - def stop(self): + def stop(self) -> None: self.terminate() def run(self, *args, **kwargs): @@ -88,7 +88,7 @@ def boefje_input( task_id: UUID, scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client), plugin_service: PluginService = Depends(get_plugin_service), -): +) -> BoefjeInput: task = get_task(task_id, scheduler_client) if task.status is not TaskStatus.RUNNING: @@ -108,7 +108,7 @@ def boefje_output( scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client), bytes_client: BytesAPIClient = Depends(get_bytes_client), plugin_service: PluginService = Depends(get_plugin_service), -): +) -> Response: task = get_task(task_id, scheduler_client) if task.status is not TaskStatus.RUNNING: @@ -127,7 +127,7 @@ def boefje_output( for file in boefje_output.files: raw = base64.b64decode(file.content) # when supported, also save file.name to Bytes - bytes_client.save_raw(task_id, raw, mime_types.union(file.tags)) + bytes_client.save_raw(task_id, raw, mime_types.union(file.tags) if file.tags else mime_types) if boefje_output.status == StatusEnum.COMPLETED: scheduler_client.patch_task(task_id, TaskStatus.COMPLETED) diff --git a/boefjes/boefjes/app.py b/boefjes/boefjes/app.py index 4077edffc45..af30d6319da 100644 --- a/boefjes/boefjes/app.py +++ b/boefjes/boefjes/app.py @@ -80,7 +80,7 @@ def run(self, queue_type: WorkerManager.Queue) -> None: raise - def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue): + def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue) -> None: if task_queue.qsize() > self.settings.pool_size: time.sleep(self.settings.worker_heartbeat) return @@ -189,7 +189,7 @@ def _cleanup_pending_worker_task(self, worker: BaseProcess) -> None: def _worker_args(self) -> tuple: return self.task_queue, self.item_handler, self.scheduler_client, self.handling_tasks - def exit(self, signum: int | None = None): + def exit(self, signum: int | None = None) -> None: try: if signum: logger.info("Received %s, exiting", signal.Signals(signum).name) @@ -238,7 +238,7 @@ def _start_working( handler: Handler, scheduler_client: SchedulerClientInterface, handling_tasks: dict[int, str], -): +) -> None: logger.info("Started listening for tasks from worker[pid=%s]", os.getpid()) while True: diff --git a/boefjes/boefjes/dependencies/encryption.py b/boefjes/boefjes/dependencies/encryption.py index 45e001a9fb1..43c56376b16 100644 --- a/boefjes/boefjes/dependencies/encryption.py +++ b/boefjes/boefjes/dependencies/encryption.py @@ -34,8 +34,8 @@ def __init__(self, private_key: str, public_key: str): def encode(self, contents: str) -> str: encrypted_contents = self.box.encrypt(contents.encode()) - encrypted_contents = base64.b64encode(encrypted_contents) - return encrypted_contents.decode() + base64_encrypted_contents = base64.b64encode(encrypted_contents) + return base64_encrypted_contents.decode() def decode(self, contents: str) -> str: encrypted_binary = base64.b64decode(contents) diff --git a/boefjes/boefjes/katalogus/root.py b/boefjes/boefjes/katalogus/root.py index 4320c6042b4..542e35a854b 100644 --- a/boefjes/boefjes/katalogus/root.py +++ b/boefjes/boefjes/katalogus/root.py @@ -73,22 +73,22 @@ @app.exception_handler(NotFound) -def entity_not_found_handler(request: Request, exc: NotFound): +def entity_not_found_handler(request: Request, exc: NotFound) -> JSONResponse: return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content={"message": exc.message}) @app.exception_handler(NotAllowed) -def not_allowed_handler(request: Request, exc: NotAllowed): +def not_allowed_handler(request: Request, exc: NotAllowed) -> JSONResponse: return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": exc.message}) @app.exception_handler(IntegrityError) -def integrity_error_handler(request: Request, exc: IntegrityError): +def integrity_error_handler(request: Request, exc: IntegrityError) -> JSONResponse: return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": exc.message}) @app.exception_handler(StorageError) -def storage_error_handler(request: Request, exc: StorageError): +def storage_error_handler(request: Request, exc: StorageError) -> JSONResponse: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": exc.message}) diff --git a/boefjes/boefjes/katalogus/settings.py b/boefjes/boefjes/katalogus/settings.py index 4ec2b48cb83..24ef4198a92 100644 --- a/boefjes/boefjes/katalogus/settings.py +++ b/boefjes/boefjes/katalogus/settings.py @@ -6,19 +6,23 @@ @router.get("", response_model=dict) -def list_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)): +def list_settings( + organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service) +) -> dict[str, str]: return plugin_service.get_all_settings(organisation_id, plugin_id) @router.put("") def upsert_settings( organisation_id: str, plugin_id: str, values: dict, plugin_service: PluginService = Depends(get_plugin_service) -): +) -> None: with plugin_service as p: p.upsert_settings(values, organisation_id, plugin_id) @router.delete("") -def remove_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)): +def remove_settings( + organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service) +) -> None: with plugin_service as p: p.delete_settings(organisation_id, plugin_id) diff --git a/boefjes/boefjes/migrations/env.py b/boefjes/boefjes/migrations/env.py index 883101aee81..0ec4438da97 100644 --- a/boefjes/boefjes/migrations/env.py +++ b/boefjes/boefjes/migrations/env.py @@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config, pool from boefjes.config import settings -from boefjes.sql.db_models import SQL_BASE +from boefjes.sql.db import SQL_BASE # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py b/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py index f76286dee3c..03489b76f85 100644 --- a/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py +++ b/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py @@ -42,7 +42,7 @@ def upgrade() -> None: # ### end Alembic commands ### -def upgrade_encrypted_settings(conn: Connection): +def upgrade_encrypted_settings(conn: Connection) -> None: encrypter = create_encrypter() with conn.begin(): @@ -90,7 +90,7 @@ def downgrade() -> None: # ### end Alembic commands ### -def downgrade_encrypted_settings(conn: Connection): +def downgrade_encrypted_settings(conn: Connection) -> None: encrypter = create_encrypter() with conn.begin(): diff --git a/boefjes/boefjes/plugins/kat_crt_sh/main.py b/boefjes/boefjes/plugins/kat_crt_sh/main.py index cb2d891bae0..318a45efaf1 100644 --- a/boefjes/boefjes/plugins/kat_crt_sh/main.py +++ b/boefjes/boefjes/plugins/kat_crt_sh/main.py @@ -31,7 +31,13 @@ ) -def request_certs(search_string, search_type="Identity", match="=", deduplicate=True, json_output=True) -> str: +def request_certs( + search_string: str, + search_type: str = "Identity", + match: str = "=", + deduplicate: bool = True, + json_output: bool = True, +) -> str: """Queries the public service CRT.sh for certificate information the searchtype can be specified and defaults to Identity. the type of sql matching can be specified and defaults to "=" diff --git a/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py b/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py index 92670e81f45..d93a87dc9c9 100644 --- a/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py +++ b/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py @@ -4,12 +4,12 @@ from octopoes.models import Reference from octopoes.models.ooi.findings import CVEFindingType, Finding from octopoes.models.ooi.software import Software, SoftwareInstance -from packaging import version +from packaging.version import Version, parse VULNERABLE_RANGES: list[tuple[str, str]] = [("0", "11.8.1.1"), ("11.9.0.0", "11.9.1.1"), ("11.10.0.0", "11.10.0.2")] -def extract_js_version(html_content: str) -> version.Version | bool: +def extract_js_version(html_content: str) -> Version | bool: telltale = "/mifs/scripts/auth.js?" telltale_position = html_content.find(telltale) if telltale_position == -1: @@ -20,10 +20,10 @@ def extract_js_version(html_content: str) -> version.Version | bool: version_string = html_content[telltale_position + len(telltale) : version_end] if not version_string: return False - return version.parse(" ".join(strip_vsp_and_build(version_string))) + return parse(" ".join(strip_vsp_and_build(version_string))) -def extract_css_version(html_content: str) -> version.Version | bool: +def extract_css_version(html_content: str) -> Version | bool: telltale = "/mifs/css/windowsAllAuth.css?" telltale_position = html_content.find(telltale) if telltale_position == -1: @@ -34,7 +34,7 @@ def extract_css_version(html_content: str) -> version.Version | bool: version_string = html_content[telltale_position + len(telltale) : version_end] if not version_string: return False - return version.parse(" ".join(strip_vsp_and_build(version_string))) + return parse(" ".join(strip_vsp_and_build(version_string))) def strip_vsp_and_build(url: str) -> Iterable[str]: @@ -47,9 +47,7 @@ def strip_vsp_and_build(url: str) -> Iterable[str]: yield part -def is_vulnerable_version( - vulnerable_ranges: list[tuple[version.Version, version.Version]], detected_version: version.Version -) -> bool: +def is_vulnerable_version(vulnerable_ranges: list[tuple[Version, Version]], detected_version: Version) -> bool: return any(start <= detected_version < end for start, end in vulnerable_ranges) @@ -70,11 +68,11 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]: yield software_instance if js_detected_version: vulnerable = is_vulnerable_version( - [(version.parse(start), version.parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version + [(parse(start), parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version ) else: # The CSS version only included the first two parts of the version number so we don't know the patch level - vulnerable = css_detected_version < version.parse("11.8") + vulnerable = css_detected_version < parse("11.8") if vulnerable: finding_type = CVEFindingType(id="CVE-2023-35078") finding = Finding( diff --git a/boefjes/boefjes/plugins/kat_dnssec/main.py b/boefjes/boefjes/plugins/kat_dnssec/main.py index 9fc385134f2..4af36d11220 100644 --- a/boefjes/boefjes/plugins/kat_dnssec/main.py +++ b/boefjes/boefjes/plugins/kat_dnssec/main.py @@ -2,7 +2,7 @@ import subprocess -def run(boefje_meta: dict): +def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]: input_ = boefje_meta["arguments"]["input"] domain = input_["name"] @@ -19,6 +19,4 @@ def run(boefje_meta: dict): output.check_returncode() - results = [({"openkat/dnssec-output"}, output.stdout)] - - return results + return [({"openkat/dnssec-output"}, output.stdout)] diff --git a/boefjes/boefjes/plugins/kat_manual/csv/normalize.py b/boefjes/boefjes/plugins/kat_manual/csv/normalize.py index a25bf73f0d1..262ba102cbe 100644 --- a/boefjes/boefjes/plugins/kat_manual/csv/normalize.py +++ b/boefjes/boefjes/plugins/kat_manual/csv/normalize.py @@ -7,7 +7,7 @@ from pydantic import ValidationError from boefjes.job_models import NormalizerDeclaration, NormalizerOutput -from octopoes.models import Reference +from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network from octopoes.models.ooi.web import URL @@ -30,7 +30,7 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]: yield from process_csv(raw, reference_cache) -def process_csv(csv_raw_data, reference_cache) -> Iterable[NormalizerOutput]: +def process_csv(csv_raw_data: bytes, reference_cache: dict) -> Iterable[NormalizerOutput]: csv_data = io.StringIO(csv_raw_data.decode("UTF-8")) object_type = get_object_type(csv_data) @@ -74,7 +74,7 @@ def get_object_type(csv_data: io.StringIO) -> str: def get_ooi_from_csv( - ooi_type_name: str, values: dict[str, str], reference_cache + ooi_type_name: str, values: dict[str, str], reference_cache: dict ) -> tuple[OOIType, list[NormalizerDeclaration]]: skip_properties = ("object_type", "scan_profile", "primary_key") @@ -85,7 +85,7 @@ def get_ooi_from_csv( if field not in skip_properties ] - kwargs = {} + kwargs: dict[str, Reference | str | None] = {} extra_declarations: list[NormalizerDeclaration] = [] for field, is_reference, required in ooi_fields: @@ -109,7 +109,7 @@ def get_ooi_from_csv( return ooi_type(**kwargs), extra_declarations -def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache): +def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache: dict) -> OOI: ooi_type_name = next(filter(lambda x: x.casefold() == ooi_type_name.casefold(), OOI_TYPES.keys())) # get from cache diff --git a/boefjes/boefjes/plugins/kat_masscan/main.py b/boefjes/boefjes/plugins/kat_masscan/main.py index 6ef83ece35c..f2604d00dc3 100644 --- a/boefjes/boefjes/plugins/kat_masscan/main.py +++ b/boefjes/boefjes/plugins/kat_masscan/main.py @@ -10,7 +10,7 @@ FILE_PATH = "/tmp/output.json" # noqa: S108 -def run_masscan(target_ip) -> bytes: +def run_masscan(target_ip: str) -> bytes: """Run Masscan in Docker.""" client = docker.from_env() diff --git a/boefjes/boefjes/plugins/kat_nmap_tcp/main.py b/boefjes/boefjes/plugins/kat_nmap_tcp/main.py index 3b14b10292c..9e9d052c2a0 100644 --- a/boefjes/boefjes/plugins/kat_nmap_tcp/main.py +++ b/boefjes/boefjes/plugins/kat_nmap_tcp/main.py @@ -5,7 +5,7 @@ TOP_PORTS_DEFAULT = 250 -def run(boefje_meta: dict): +def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]: top_ports_key = "TOP_PORTS" if boefje_meta["boefje"]["id"] == "nmap-udp": top_ports_key = "TOP_PORTS_UDP" @@ -22,6 +22,4 @@ def run(boefje_meta: dict): output.check_returncode() - results = [({"openkat/nmap-output"}, output.stdout.decode())] - - return results + return [({"openkat/nmap-output"}, output.stdout.decode())] diff --git a/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py b/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py index 48538de257c..bad69556c3b 100644 --- a/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py +++ b/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py @@ -5,6 +5,7 @@ import requests from forcediphttpsadapter.adapters import ForcedIPHTTPSAdapter from requests import Session +from requests.models import Response from boefjes.job_models import BoefjeMeta @@ -41,7 +42,10 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: elif response.status_code in [301, 302, 307, 308]: uri = response.headers["Location"] response = requests.get(uri, stream=True, timeout=30, verify=False) # noqa: S501 - ip = response.raw._connection.sock.getpeername()[0] + if response.raw._connection: + ip = response.raw._connection.sock.getpeername()[0] + else: + ip = "" results[path] = { "content": response.content.decode(), "url": response.url, @@ -53,7 +57,7 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: return [(set(), json.dumps(results))] -def do_request(hostname: str, session: Session, uri: str, useragent: str): +def do_request(hostname: str, session: Session, uri: str, useragent: str) -> Response: response = session.get( uri, headers={"Host": hostname, "User-Agent": useragent}, verify=False, allow_redirects=False ) diff --git a/boefjes/boefjes/plugins/kat_snyk/check_version.py b/boefjes/boefjes/plugins/kat_snyk/check_version.py index 442978e0f4c..a758fc12246 100644 --- a/boefjes/boefjes/plugins/kat_snyk/check_version.py +++ b/boefjes/boefjes/plugins/kat_snyk/check_version.py @@ -70,7 +70,7 @@ def check_version(version1: str, version2: str) -> VersionCheck: return check_version(version1_split[1], version2_split[1]) -def check_version_agains_versionlist(my_version: str, all_versions: list[str]): +def check_version_agains_versionlist(my_version: str, all_versions: list[str]) -> tuple[bool, list[str] | None]: lowerbound = all_versions.pop(0).strip() upperbound = None @@ -164,10 +164,12 @@ def check_version_agains_versionlist(my_version: str, all_versions: list[str]): return True, all_versions -def check_version_in(version: str, versions: str): +def check_version_in(version: str, versions: str) -> bool: if not version: return False - all_versions = versions.split(",") # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks + all_versions: list[str] | None = versions.split( + "," + ) # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks in_range = False while not in_range and all_versions: in_range, all_versions = check_version_agains_versionlist(version, all_versions) diff --git a/boefjes/boefjes/plugins/kat_webpage_analysis/main.py b/boefjes/boefjes/plugins/kat_webpage_analysis/main.py index f9fef3921fa..65d124c9d60 100644 --- a/boefjes/boefjes/plugins/kat_webpage_analysis/main.py +++ b/boefjes/boefjes/plugins/kat_webpage_analysis/main.py @@ -7,6 +7,7 @@ import requests from forcediphttpsadapter.adapters import ForcedIPHTTPSAdapter from requests import Session +from requests.models import Response from boefjes.job_models import BoefjeMeta @@ -54,9 +55,9 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: body_mimetypes.add(content_type) # Pick up the content type for the body from the server and split away encodings to make normalization easier - content_type = content_type.split(";") - if content_type[0] in ALLOWED_CONTENT_TYPES: - body_mimetypes.add(content_type[0]) + content_type_splitted = content_type.split(";") + if content_type_splitted[0] in ALLOWED_CONTENT_TYPES: + body_mimetypes.add(content_type_splitted[0]) # in case of a full response object, we hexdump to avoid issues with binary data or different encoding response_dump = json.dumps(create_response_object(response)) @@ -87,7 +88,7 @@ def create_response_object(response: requests.Response) -> dict: } -def do_request(hostname: str, session: Session, uri: str, useragent: str): +def do_request(hostname: str, session: Session, uri: str, useragent: str) -> Response: response = session.get( uri, headers={"Host": hostname, "User-Agent": useragent}, verify=False, allow_redirects=False ) diff --git a/boefjes/boefjes/plugins/kat_webpage_capture/main.py b/boefjes/boefjes/plugins/kat_webpage_capture/main.py index 6ee7e8dd44e..afecb4bbc89 100644 --- a/boefjes/boefjes/plugins/kat_webpage_capture/main.py +++ b/boefjes/boefjes/plugins/kat_webpage_capture/main.py @@ -10,11 +10,11 @@ class WebpageCaptureException(Exception): """Exception raised when webpage capture fails.""" - def __init__(self, message, container_log=None): + def __init__(self, message: str, container_log: str): self.message = message self.container_log = container_log - def __str__(self): + def __str__(self) -> str: return str(self.message) + "\n\nContainer log:\n" + self.container_log diff --git a/boefjes/boefjes/plugins/models.py b/boefjes/boefjes/plugins/models.py index e7753604486..364f14fafb9 100644 --- a/boefjes/boefjes/plugins/models.py +++ b/boefjes/boefjes/plugins/models.py @@ -102,7 +102,7 @@ def hash_path(path: Path) -> str: return folder_hash.hexdigest() -def _default_mime_types(boefje: Boefje): +def _default_mime_types(boefje: Boefje) -> set: mime_types = {f"boefje/{boefje.id}"} if boefje.version is not None: diff --git a/boefjes/boefjes/runtime_interfaces.py b/boefjes/boefjes/runtime_interfaces.py index 0a8375bdb86..70bacfb5554 100644 --- a/boefjes/boefjes/runtime_interfaces.py +++ b/boefjes/boefjes/runtime_interfaces.py @@ -4,7 +4,7 @@ class Handler: - def handle(self, item: BoefjeMeta | NormalizerMeta): + def handle(self, item: BoefjeMeta | NormalizerMeta) -> None: raise NotImplementedError() diff --git a/boefjes/boefjes/sql/db.py b/boefjes/boefjes/sql/db.py index 97fa858b75b..ea3fad670c2 100644 --- a/boefjes/boefjes/sql/db.py +++ b/boefjes/boefjes/sql/db.py @@ -52,5 +52,5 @@ def session_managed_iterator(service_factory: Callable[[Session], Any]) -> Itera class ObjectNotFoundException(Exception): - def __init__(self, cls: type | UnionType, **kwargs): # type: ignore + def __init__(self, cls: type | UnionType, **kwargs): super().__init__(f"The object of type {cls} was not found for query parameters {kwargs}") diff --git a/boefjes/boefjes/sql/organisation_storage.py b/boefjes/boefjes/sql/organisation_storage.py index 2124f9a2901..6c69f28abed 100644 --- a/boefjes/boefjes/sql/organisation_storage.py +++ b/boefjes/boefjes/sql/organisation_storage.py @@ -59,7 +59,7 @@ def to_organisation(organisation_in_db: OrganisationInDB) -> Organisation: return Organisation(id=organisation_in_db.id, name=organisation_in_db.name) -def create_organisation_storage(session) -> SQLOrganisationStorage: +def create_organisation_storage(session: Session) -> SQLOrganisationStorage: return SQLOrganisationStorage(session, settings) diff --git a/bytes/bytes/api/metrics.py b/bytes/bytes/api/metrics.py index d6a44010b4a..0d1f760cedf 100644 --- a/bytes/bytes/api/metrics.py +++ b/bytes/bytes/api/metrics.py @@ -23,7 +23,7 @@ logger = structlog.get_logger(__name__) -def ignore_arguments_key(meta_repository: MetaDataRepository): +def ignore_arguments_key(meta_repository: MetaDataRepository) -> str: return "" diff --git a/bytes/bytes/api/root.py b/bytes/bytes/api/root.py index 53661005c33..399bfa588f7 100644 --- a/bytes/bytes/api/root.py +++ b/bytes/bytes/api/root.py @@ -47,7 +47,7 @@ def health() -> ServiceHealth: @router.get("/metrics", dependencies=[Depends(authenticate_token)]) -def metrics(meta_repository: MetaDataRepository = Depends(create_meta_data_repository)): +def metrics(meta_repository: MetaDataRepository = Depends(create_meta_data_repository)) -> Response: collector_registry = get_registry(meta_repository) data = prometheus_client.generate_latest(collector_registry) diff --git a/bytes/bytes/api/router.py b/bytes/bytes/api/router.py index 0b6f0cb901a..716177cc8d6 100644 --- a/bytes/bytes/api/router.py +++ b/bytes/bytes/api/router.py @@ -271,7 +271,7 @@ def get_raw_count_per_mime_type( return cached_counts_per_mime_type(meta_repository, query_filter) -def ignore_arguments_key(meta_repository: MetaDataRepository, query_filter: RawDataFilter): +def ignore_arguments_key(meta_repository: MetaDataRepository, query_filter: RawDataFilter) -> str: """Helper to not cache based on the stateful meta_repository, but only use the query parameters as a key.""" return query_filter.model_dump_json() diff --git a/bytes/bytes/database/migrations/env.py b/bytes/bytes/database/migrations/env.py index f91d8947d18..15d1e25a528 100644 --- a/bytes/bytes/database/migrations/env.py +++ b/bytes/bytes/database/migrations/env.py @@ -6,7 +6,7 @@ # this is the Alembic Config object, which provides # access to the values within the .ini file in use. from bytes.config import get_settings -from bytes.database.db_models import SQL_BASE +from bytes.database.db import SQL_BASE config = context.config diff --git a/bytes/bytes/database/sql_meta_repository.py b/bytes/bytes/database/sql_meta_repository.py index 93ee36cdc6f..e912d0fb409 100644 --- a/bytes/bytes/database/sql_meta_repository.py +++ b/bytes/bytes/database/sql_meta_repository.py @@ -229,7 +229,7 @@ def create_meta_data_repository() -> Iterator[MetaDataRepository]: class ObjectNotFoundException(Exception): - def __init__(self, cls: type[SQL_BASE], **kwargs): + def __init__(self, cls: type[SQL_BASE], **kwargs: str): super().__init__(f"The object of type {cls} was not found for query parameters {kwargs}") diff --git a/bytes/bytes/models.py b/bytes/bytes/models.py index 03ae39506aa..44ea08c8c85 100644 --- a/bytes/bytes/models.py +++ b/bytes/bytes/models.py @@ -38,10 +38,10 @@ def _validate_timezone_aware_datetime(value: datetime) -> datetime: class MimeType(BaseModel): value: str - def __hash__(self): + def __hash__(self) -> int: return hash(self.value) - def __lt__(self, other: MimeType): + def __lt__(self, other: MimeType) -> bool: return self.value < other.value diff --git a/bytes/bytes/rabbitmq.py b/bytes/bytes/rabbitmq.py index 4ffa498ddd8..8ae5b83f446 100644 --- a/bytes/bytes/rabbitmq.py +++ b/bytes/bytes/rabbitmq.py @@ -41,7 +41,7 @@ def publish(self, event: Event) -> None: logger.info("Published event [event_id=%s] to queue %s", event.event_id, queue_name) - def _check_connection(self): + def _check_connection(self) -> None: if self.connection.is_closed: self.connection = pika.BlockingConnection(pika.URLParameters(self.queue_uri)) self.channel = self.connection.channel() diff --git a/bytes/bytes/raw/file_raw_repository.py b/bytes/bytes/raw/file_raw_repository.py index 6d0fd6c843b..605a6505dc5 100644 --- a/bytes/bytes/raw/file_raw_repository.py +++ b/bytes/bytes/raw/file_raw_repository.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging from pathlib import Path +from typing import TYPE_CHECKING from uuid import UUID import structlog @@ -14,6 +17,10 @@ logger = structlog.get_logger(__name__) +if TYPE_CHECKING: + from mypy_boto3_s3.service_resource import Bucket + + def create_raw_repository(settings: Settings) -> RawRepository: if settings.s3_bucket_name or settings.s3_bucket_prefix: return S3RawRepository( @@ -87,7 +94,7 @@ def __init__( set_boto3_stream_logger("", logging.WARNING) self._s3resource = BotoSession().resource("s3") - def get_or_create_bucket(self, organization: str): + def get_or_create_bucket(self, organization: str) -> Bucket: # Create a bucket, and if it exists already return that instead bucket_name = f"{self.s3_bucket_prefix}{organization}" if self.bucket_per_org else self.s3_bucket_name diff --git a/cveapi/cveapi.py b/cveapi/cveapi.py index 09fcab2aa27..0fdc4d1765f 100644 --- a/cveapi/cveapi.py +++ b/cveapi/cveapi.py @@ -13,7 +13,7 @@ logger = logging.getLogger("cveapi") -def download_files(directory, last_update, update_timestamp): +def download_files(directory: pathlib.Path, last_update: datetime | None, update_timestamp: datetime) -> None: index = 0 client = httpx.Client() error_count = 0 @@ -66,7 +66,7 @@ def download_files(directory, last_update, update_timestamp): logger.info("Downloaded new information of %s CVEs", response_json["totalResults"]) -def run(): +def run() -> None: loglevel = os.getenv("CVEAPI_LOGLEVEL", "INFO") numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): diff --git a/mula/scheduler/__init__.py b/mula/scheduler/__init__.py index 211a13feb2f..3bd881035da 100644 --- a/mula/scheduler/__init__.py +++ b/mula/scheduler/__init__.py @@ -1,4 +1,4 @@ from .app import App -from .version import version +from .version import __version__ -__version__ = version +__all__ = ["App", "__version__"] diff --git a/mula/scheduler/clients/amqp/__init__.py b/mula/scheduler/clients/amqp/__init__.py index ba7a3fab3ef..85194b47cb5 100644 --- a/mula/scheduler/clients/amqp/__init__.py +++ b/mula/scheduler/clients/amqp/__init__.py @@ -1,3 +1,5 @@ from .listeners import Listener, RabbitMQ from .raw_data import RawData from .scan_profile import ScanProfileMutation + +__all__ = ["Listener", "RabbitMQ", "RawData", "ScanProfileMutation"] diff --git a/mula/scheduler/clients/amqp/listeners.py b/mula/scheduler/clients/amqp/listeners.py index 10ab1665476..6a8955a9107 100644 --- a/mula/scheduler/clients/amqp/listeners.py +++ b/mula/scheduler/clients/amqp/listeners.py @@ -181,7 +181,7 @@ def callback( # Submit the message to the thread pool executor self.executor.submit(self.dispatch, channel, method.delivery_tag, body) - def dispatch(self, channel, delivery_tag, body: bytes) -> None: + def dispatch(self, channel: pika.channel.Channel, delivery_tag: int, body: bytes) -> None: # Check if we still have a connection if self.connection is None or self.connection.is_closed: self.logger.debug("No connection available, cannot dispatch message!") diff --git a/mula/scheduler/clients/connector.py b/mula/scheduler/clients/connector.py index 857fd0ad1c4..7fdb34949a2 100644 --- a/mula/scheduler/clients/connector.py +++ b/mula/scheduler/clients/connector.py @@ -1,6 +1,7 @@ import socket import time from collections.abc import Callable +from typing import Any import httpx import structlog @@ -47,7 +48,7 @@ def is_host_healthy(self, host: str, health_endpoint: str) -> bool: self.logger.warning("Exception: %s", exc) return False - def retry(self, func: Callable, *args, **kwargs) -> bool: + def retry(self, func: Callable, *args: Any, **kwargs: Any) -> bool: """Retry a function until it returns True. Args: diff --git a/mula/scheduler/context/__init__.py b/mula/scheduler/context/__init__.py index 61627b2c729..ae686bb6e0c 100644 --- a/mula/scheduler/context/__init__.py +++ b/mula/scheduler/context/__init__.py @@ -1 +1,3 @@ from .context import AppContext + +__all__ = ["AppContext"] diff --git a/mula/scheduler/context/context.py b/mula/scheduler/context/context.py index 00b4d8f16f4..540a86ba545 100644 --- a/mula/scheduler/context/context.py +++ b/mula/scheduler/context/context.py @@ -34,6 +34,9 @@ class AppContext: the schedulers. """ + metrics_qsize: Gauge + metrics_task_status_counts: Gauge + def __init__(self) -> None: """Initializer of the AppContext class.""" self.config: settings.Settings = settings.Settings() diff --git a/mula/scheduler/schedulers/rankers/__init__.py b/mula/scheduler/schedulers/rankers/__init__.py index 9fd14d2b21e..f2467f78e87 100644 --- a/mula/scheduler/schedulers/rankers/__init__.py +++ b/mula/scheduler/schedulers/rankers/__init__.py @@ -1,3 +1,5 @@ from .boefje import BoefjeRanker, BoefjeRankerTimeBased from .normalizer import NormalizerRanker from .ranker import Ranker + +__all__ = ["BoefjeRanker", "NormalizerRanker", "Ranker"] diff --git a/mula/scheduler/schedulers/schedulers/boefje.py b/mula/scheduler/schedulers/schedulers/boefje.py index 5b9fc5653fb..260b5cb40db 100644 --- a/mula/scheduler/schedulers/schedulers/boefje.py +++ b/mula/scheduler/schedulers/schedulers/boefje.py @@ -926,7 +926,7 @@ def has_boefje_task_grace_period_passed(self, task: BoefjeTask) -> bool: return True - def get_boefjes_for_ooi(self, ooi) -> list[Plugin]: + def get_boefjes_for_ooi(self, ooi: OOI) -> list[Plugin]: """Get available all boefjes (enabled and disabled) for an ooi. Args: diff --git a/mula/scheduler/server/__init__.py b/mula/scheduler/server/__init__.py index b7f2cf59516..09ed39ca17f 100644 --- a/mula/scheduler/server/__init__.py +++ b/mula/scheduler/server/__init__.py @@ -1 +1,3 @@ from .server import Server + +__all__ = ["Server"] diff --git a/mula/scheduler/storage/filters/__init__.py b/mula/scheduler/storage/filters/__init__.py index ddf32f56ef3..eb44f2d36e1 100644 --- a/mula/scheduler/storage/filters/__init__.py +++ b/mula/scheduler/storage/filters/__init__.py @@ -1,3 +1,5 @@ from .casting import cast_expression from .filters import Filter, FilterRequest from .functions import apply_filter + +__all__ = ["cast_expression", "Filter", "FilterRequest", "apply_filter"] diff --git a/mula/scheduler/storage/filters/functions.py b/mula/scheduler/storage/filters/functions.py index 805d7d0591d..f50aeb88fb3 100644 --- a/mula/scheduler/storage/filters/functions.py +++ b/mula/scheduler/storage/filters/functions.py @@ -1,4 +1,5 @@ import sqlalchemy +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.query import Query from sqlalchemy.sql.elements import BinaryExpression @@ -9,7 +10,7 @@ from .operators import FILTER_OPERATORS -def apply_filter(entity, query: Query, filter_request: FilterRequest) -> Query: +def apply_filter(entity: DeclarativeBase, query: Query, filter_request: FilterRequest) -> Query: """Apply the filter criteria to a SQLAlchemy query. Args: diff --git a/mula/scheduler/utils/__init__.py b/mula/scheduler/utils/__init__.py index 47047ddb189..cb8b2af51d1 100644 --- a/mula/scheduler/utils/__init__.py +++ b/mula/scheduler/utils/__init__.py @@ -2,3 +2,5 @@ from .dict_utils import ExpiredError, ExpiringDict, deep_get from .functions import remove_trailing_slash from .thread import ThreadRunner + +__all__ = ["GUID", "ExpiredError", "ExpiringDict", "deep_get", "remove_trailing_slash", "ThreadRunner"] diff --git a/mula/scheduler/utils/cron.py b/mula/scheduler/utils/cron.py index ac632c9ab71..45be36accb4 100644 --- a/mula/scheduler/utils/cron.py +++ b/mula/scheduler/utils/cron.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from croniter import croniter # type: ignore +from croniter import croniter def next_run(expression: str, start_time: datetime | None = None) -> datetime: diff --git a/octopoes/bits/check_cve_2021_41773/bit.py b/octopoes/bits/check_cve_2021_41773/bit.py index 367183e5f8d..3e32a458c9c 100644 --- a/octopoes/bits/check_cve_2021_41773/bit.py +++ b/octopoes/bits/check_cve_2021_41773/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="check_cve_2021_41773", diff --git a/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py b/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py index d07b81268c0..83076aed418 100644 --- a/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py +++ b/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py @@ -3,7 +3,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import CVEFindingType, Finding -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/check_hsts_header/bit.py b/octopoes/bits/check_hsts_header/bit.py index 5e8436084f6..6b98f2d86a9 100644 --- a/octopoes/bits/check_hsts_header/bit.py +++ b/octopoes/bits/check_hsts_header/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="check-hsts-header", diff --git a/octopoes/bits/check_hsts_header/check_hsts_header.py b/octopoes/bits/check_hsts_header/check_hsts_header.py index 824e6948ace..de92e5e38dc 100644 --- a/octopoes/bits/check_hsts_header/check_hsts_header.py +++ b/octopoes/bits/check_hsts_header/check_hsts_header.py @@ -4,7 +4,7 @@ from octopoes.models import OOI, Reference from octopoes.models.ooi.findings import Finding, KATFindingType -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/cipher_classification/cipher_classification.py b/octopoes/bits/cipher_classification/cipher_classification.py index 5f701df86cb..88cf99765b3 100644 --- a/octopoes/bits/cipher_classification/cipher_classification.py +++ b/octopoes/bits/cipher_classification/cipher_classification.py @@ -1,6 +1,7 @@ import csv from collections.abc import Iterator from pathlib import Path +from typing import Any from octopoes.models import OOI from octopoes.models.ooi.findings import Finding, KATFindingType @@ -13,7 +14,7 @@ } -def get_severity_and_reasons(cipher_suite) -> list[tuple[str, str]]: +def get_severity_and_reasons(cipher_suite: str) -> list[tuple[str, str]]: with Path.open(Path(__file__).parent / "list-ciphers-openssl-with-finding-type.csv", newline="") as csvfile: reader = csv.DictReader(csvfile) data = [{k.strip(): v.strip() for k, v in row.items() if k} for row in reader] @@ -76,7 +77,7 @@ def get_highest_severity_and_all_reasons(cipher_suites: dict) -> tuple[str, str] return highest_severity, all_reasons_str -def run(input_ooi: TLSCipher, additional_oois, config) -> Iterator[OOI]: +def run(input_ooi: TLSCipher, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: # Get the highest severity and all reasons for the cipher suite highest_severity, all_reasons = get_highest_severity_and_all_reasons(input_ooi.suites) diff --git a/octopoes/bits/missing_certificate/missing_certificate.py b/octopoes/bits/missing_certificate/missing_certificate.py index 04721d51dc6..ae08a2a214e 100644 --- a/octopoes/bits/missing_certificate/missing_certificate.py +++ b/octopoes/bits/missing_certificate/missing_certificate.py @@ -6,7 +6,7 @@ from octopoes.models.ooi.web import Website -def run(input_ooi: Website, additional_oois, config: dict[str, Any]) -> Iterator[OOI]: +def run(input_ooi: Website, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: if input_ooi.ip_service.tokenized.service.name.lower() != "https": return diff --git a/octopoes/bits/nxdomain_flag/bit.py b/octopoes/bits/nxdomain_flag/bit.py index 7c593bb563c..d5f44c3fe45 100644 --- a/octopoes/bits/nxdomain_flag/bit.py +++ b/octopoes/bits/nxdomain_flag/bit.py @@ -1,6 +1,6 @@ from bits.definitions import BitDefinition, BitParameterDefinition +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname -from octopoes.models.types import NXDOMAIN BIT = BitDefinition( id="nxdomain-flag", diff --git a/octopoes/bits/nxdomain_flag/nxdomain_flag.py b/octopoes/bits/nxdomain_flag/nxdomain_flag.py index 34401e6ea96..b26aeb368c9 100644 --- a/octopoes/bits/nxdomain_flag/nxdomain_flag.py +++ b/octopoes/bits/nxdomain_flag/nxdomain_flag.py @@ -2,9 +2,9 @@ from typing import Any from octopoes.models import OOI +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.findings import Finding, KATFindingType -from octopoes.models.types import NXDOMAIN def run(input_ooi: Hostname, additional_oois: list[NXDOMAIN], config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/nxdomain_header_flag/bit.py b/octopoes/bits/nxdomain_header_flag/bit.py index 3d4883adec1..296ac1a3580 100644 --- a/octopoes/bits/nxdomain_header_flag/bit.py +++ b/octopoes/bits/nxdomain_header_flag/bit.py @@ -1,7 +1,7 @@ from bits.definitions import BitDefinition, BitParameterDefinition +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import NXDOMAIN BIT = BitDefinition( id="nxdomain-header-flag", diff --git a/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py b/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py index f55c6b54561..7a4f78a004e 100644 --- a/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py +++ b/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py @@ -2,10 +2,10 @@ from typing import Any from octopoes.models import OOI +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.findings import Finding, KATFindingType from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import NXDOMAIN def run( diff --git a/octopoes/bits/oois_in_headers/bit.py b/octopoes/bits/oois_in_headers/bit.py index 70beea24333..ef2dd5c40d8 100644 --- a/octopoes/bits/oois_in_headers/bit.py +++ b/octopoes/bits/oois_in_headers/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="oois-in-headers", diff --git a/octopoes/bits/oois_in_headers/oois_in_headers.py b/octopoes/bits/oois_in_headers/oois_in_headers.py index cd49dd8dffb..23b5eff3008 100644 --- a/octopoes/bits/oois_in_headers/oois_in_headers.py +++ b/octopoes/bits/oois_in_headers/oois_in_headers.py @@ -7,8 +7,8 @@ from octopoes.models import OOI from octopoes.models.ooi.dns.zone import Hostname -from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import URL, HTTPHeader, HTTPHeaderURL, Network +from octopoes.models.ooi.network import Network +from octopoes.models.ooi.web import URL, HTTPHeader, HTTPHeaderHostname, HTTPHeaderURL def is_url(input_str): diff --git a/octopoes/bits/retire_js/retire_js.py b/octopoes/bits/retire_js/retire_js.py index 5b57382ffc4..b79a06ec6fc 100644 --- a/octopoes/bits/retire_js/retire_js.py +++ b/octopoes/bits/retire_js/retire_js.py @@ -7,7 +7,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import CVEFindingType, Finding, RetireJSFindingType from octopoes.models.ooi.software import Software, SoftwareInstance -from packaging import version +from packaging.version import parse def run(input_ooi: Software, additional_oois: list[SoftwareInstance], config: dict[str, Any]) -> Iterator[OOI]: @@ -40,7 +40,7 @@ def run(input_ooi: Software, additional_oois: list[SoftwareInstance], config: di ) -def _check_vulnerabilities(name, package_version: str, known_vulnerabilities: dict) -> dict[str, list[str]]: +def _check_vulnerabilities(name: str, package_version: str, known_vulnerabilities: dict) -> dict[str, list[str]]: vulnerabilities: dict[str, list[str]] = {"CVE": [], "RetireJS": []} processed_name = _process_name(name) found_brands = [brand for brand in known_vulnerabilities if processed_name == _process_name(brand)] @@ -70,10 +70,10 @@ def _hash_identifiers(identifiers: dict[str, str | list[str]]) -> str: def _check_versions(package_version: str, known_vulnerability: dict) -> bool: - below = version.parse(package_version) < version.parse(known_vulnerability["below"]) + below = parse(package_version) < parse(known_vulnerability["below"]) # Some packages are only vulnerable below a version and not above above = ( - version.parse(package_version) >= version.parse(known_vulnerability["atOrAbove"]) + parse(package_version) >= parse(known_vulnerability["atOrAbove"]) if "atOrAbove" in known_vulnerability else True ) diff --git a/octopoes/bits/runner.py b/octopoes/bits/runner.py index da851076c34..931297ae5df 100644 --- a/octopoes/bits/runner.py +++ b/octopoes/bits/runner.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from importlib import import_module from inspect import isfunction, signature -from typing import Any, Protocol +from typing import Any from bits.definitions import BitDefinition from octopoes.models import OOI @@ -11,15 +11,11 @@ class ModuleException(Exception): """General error for modules""" -class Runnable(Protocol): - def run(self, *args, **kwargs) -> Any: ... - - class BitRunner: def __init__(self, bit_definition: BitDefinition): self.module = bit_definition.module - def run(self, *args, **kwargs) -> list[OOI]: + def run(self, *args: Any, **kwargs: Any) -> list[OOI]: module = import_module(self.module) if not hasattr(module, "run") or not isfunction(module.run): @@ -31,7 +27,7 @@ def run(self, *args, **kwargs) -> list[OOI]: ) return list(module.run(*args, **kwargs)) - def __str__(self): + def __str__(self) -> str: return f"BitRunner {self.module}" diff --git a/octopoes/bits/spf_discovery/spf_discovery.py b/octopoes/bits/spf_discovery/spf_discovery.py index 36a10f16003..a094cc28cfb 100644 --- a/octopoes/bits/spf_discovery/spf_discovery.py +++ b/octopoes/bits/spf_discovery/spf_discovery.py @@ -10,7 +10,7 @@ from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network -def run(input_ooi: DNSTXTRecord, additional_oois, config: dict[str, Any]) -> Iterator[OOI]: +def run(input_ooi: DNSTXTRecord, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: if input_ooi.value.startswith("v=spf1"): spf_value = input_ooi.value.replace("%(d)", input_ooi.hostname.tokenized.name) parsed = parse(spf_value) diff --git a/octopoes/octopoes/api/router.py b/octopoes/octopoes/api/router.py index 62db165bd40..5816d2a2d88 100644 --- a/octopoes/octopoes/api/router.py +++ b/octopoes/octopoes/api/router.py @@ -443,8 +443,8 @@ def get_scan_profile_inheritance( def list_findings( exclude_muted: bool = True, only_muted: bool = False, - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), severities: set[RiskLevelSeverity] = Query(DEFAULT_SEVERITY_FILTER), @@ -459,8 +459,8 @@ def list_findings( @router.get("/reports", tags=["Reports"]) def list_reports( - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), ) -> Paginated[tuple[Report, list[Report | None]]]: diff --git a/octopoes/octopoes/connector/__init__.py b/octopoes/octopoes/connector/__init__.py index a5ca0237d37..5f56b2ee59e 100644 --- a/octopoes/octopoes/connector/__init__.py +++ b/octopoes/octopoes/connector/__init__.py @@ -1,12 +1,8 @@ -# Keep for backwards compatibility -from octopoes.models.exception import ObjectNotFoundException - - class ConnectorException(Exception): def __init__(self, value: str): self.value = value - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/connector/octopoes.py b/octopoes/octopoes/connector/octopoes.py index d76a03480e0..de2004fe0ce 100644 --- a/octopoes/octopoes/connector/octopoes.py +++ b/octopoes/octopoes/connector/octopoes.py @@ -1,5 +1,5 @@ import json -from collections.abc import Sequence, Set +from collections.abc import Iterable, Sequence, Set from datetime import datetime from typing import Literal from uuid import UUID @@ -206,7 +206,7 @@ def save_affirmation(self, affirmation: Affirmation) -> None: self.logger.info("Saved affirmation", affirmation=affirmation, event_code=DECLARATION_CREATED) - def save_scan_profile(self, scan_profile: ScanProfile, valid_time: datetime): + def save_scan_profile(self, scan_profile: ScanProfile, valid_time: datetime) -> None: params = {"valid_time": str(valid_time)} self.session.put( f"/{self.client}/scan_profiles", @@ -262,7 +262,7 @@ def count_findings_by_severity(self, valid_time: datetime) -> dict[str, int]: def list_findings( self, - severities: set[RiskLevelSeverity], + severities: Iterable[RiskLevelSeverity], valid_time: datetime, exclude_muted: bool = True, only_muted: bool = False, @@ -301,8 +301,8 @@ def get_report(self, report_id: str) -> Report: return TypeAdapter(Report).validate_json(res.content) - def load_objects_bulk(self, references: set[Reference], valid_time): - params = {"valid_time": valid_time} + def load_objects_bulk(self, references: set[Reference], valid_time: datetime) -> dict[Reference, OOIType]: + params = {"valid_time": str(valid_time)} res = self.session.post( f"/{self.client}/objects/load_bulk", params=params, json=[str(ref) for ref in references] ) diff --git a/octopoes/octopoes/core/app.py b/octopoes/octopoes/core/app.py index 9840f0beaf1..af61d604f11 100644 --- a/octopoes/octopoes/core/app.py +++ b/octopoes/octopoes/core/app.py @@ -20,7 +20,7 @@ def get_xtdb_client(base_uri: str, client: str) -> XTDBHTTPClient: return XTDBHTTPClient(f"{base_uri}/_xtdb", client) -def close_rabbit_channel(queue_uri: str): +def close_rabbit_channel(queue_uri: str) -> None: rabbit_channel = get_rabbit_channel(queue_uri) try: diff --git a/octopoes/octopoes/core/service.py b/octopoes/octopoes/core/service.py index 135ad39f8ba..ce019172c25 100644 --- a/octopoes/octopoes/core/service.py +++ b/octopoes/octopoes/core/service.py @@ -143,7 +143,7 @@ def list_ooi( def get_ooi_tree( self, reference: Reference, valid_time: datetime, search_types: set[type[OOI]] | None = None, depth: int = 1 - ): + ) -> ReferenceTree: tree = self.ooi_repository.get_tree(reference, valid_time, search_types, depth) self._populate_scan_profiles(tree.store.values(), valid_time) return tree @@ -257,7 +257,7 @@ def _run_inference(self, origin: Origin, valid_time: datetime) -> None: logger.exception("Error running inference", exc_info=e) @staticmethod - def check_path_level(path_level: int | None, current_level: int): + def check_path_level(path_level: int | None, current_level: int) -> bool: return path_level is not None and path_level >= current_level def recalculate_scan_profiles(self, valid_time: datetime) -> None: @@ -379,7 +379,7 @@ def recalculate_scan_profiles(self, valid_time: datetime) -> None: ) logger.info("Recalculated scan profiles") - def process_event(self, event: DBEvent): + def process_event(self, event: DBEvent) -> None: # handle event event_handler_name = f"_on_{event.operation_type.value}_{event.entity_type}" handler: Callable[[DBEvent], None] | None = getattr(self, event_handler_name) diff --git a/octopoes/octopoes/models/__init__.py b/octopoes/octopoes/models/__init__.py index 44153335a2c..362b4c0d254 100644 --- a/octopoes/octopoes/models/__init__.py +++ b/octopoes/octopoes/models/__init__.py @@ -43,12 +43,12 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa return core_schema.with_info_after_validator_function(cls.validate, core_schema.str_schema()) @classmethod - def validate(cls, v, info: ValidationInfo): + def validate(cls, v: str, info: ValidationInfo) -> Any: if not isinstance(v, str): raise TypeError("string required") return cls(str(v)) - def __repr__(self): + def __repr__(self) -> str: return f"Reference({super().__repr__()})" @classmethod @@ -124,7 +124,7 @@ class OOI(BaseModel): def model_post_init(self, __context: Any) -> None: # noqa: F841 self.primary_key = self.primary_key or f"{self.get_object_type()}|{self.natural_key}" - def __str__(self): + def __str__(self) -> str: return self.primary_key @classmethod @@ -191,11 +191,11 @@ def get_reverse_relation_name(cls, attr: str) -> str: return cls._reverse_relation_names.get(attr, f"{cls.get_object_type()}_{attr}") @classmethod - def get_tokenized_primary_key(cls, natural_key: str): + def get_tokenized_primary_key(cls, natural_key: str) -> PrimaryKeyToken: token_tree = build_token_tree(cls) natural_key_parts = natural_key.split("|") - def hydrate(node) -> dict | str: + def hydrate(node: dict[str, dict | str]) -> dict | str: for key, value in node.items(): if isinstance(value, dict): node[key] = hydrate(value) @@ -256,10 +256,10 @@ def format_id_short(id_: str) -> str: class PrimaryKeyToken(RootModel): root: dict[str, str | PrimaryKeyToken] - def __getattr__(self, item) -> Any: + def __getattr__(self, item: str) -> Any: return self.root[item] - def __getitem__(self, item) -> Any: + def __getitem__(self, item: str) -> Any: return self.root[item] diff --git a/octopoes/octopoes/models/ooi/dns/records.py b/octopoes/octopoes/models/ooi/dns/records.py index 972b20e4a07..fda89849a0f 100644 --- a/octopoes/octopoes/models/ooi/dns/records.py +++ b/octopoes/octopoes/models/ooi/dns/records.py @@ -144,7 +144,7 @@ class CAATAGS(Enum): ISSUEVMC = "issuevmc" ISSUEMAIL = "issuemail" - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/models/ooi/findings.py b/octopoes/octopoes/models/ooi/findings.py index f7bce1365e9..af82e2de756 100644 --- a/octopoes/octopoes/models/ooi/findings.py +++ b/octopoes/octopoes/models/ooi/findings.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from functools import total_ordering from typing import Annotated, Literal @@ -24,10 +26,10 @@ class RiskLevelSeverity(Enum): # unknown = the third party has been contacted, but third party has not determined the risk level (yet) UNKNOWN = "unknown" - def __gt__(self, other: "RiskLevelSeverity") -> bool: + def __gt__(self, other: RiskLevelSeverity) -> bool: return severity_order.index(self.value) > severity_order.index(other.value) - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/models/ooi/network.py b/octopoes/octopoes/models/ooi/network.py index 83094b15dc0..5157796a518 100644 --- a/octopoes/octopoes/models/ooi/network.py +++ b/octopoes/octopoes/models/ooi/network.py @@ -84,7 +84,7 @@ class IPPort(OOI): _information_value = ["protocol", "port"] @classmethod - def format_reference_human_readable(cls, reference: Reference): + def format_reference_human_readable(cls, reference: Reference) -> str: tokenized = reference.tokenized return f"{tokenized.address.address}:{tokenized.port}/{tokenized.protocol}" diff --git a/octopoes/octopoes/models/origin.py b/octopoes/octopoes/models/origin.py index b39ec31e7a2..755fdf98b4c 100644 --- a/octopoes/octopoes/models/origin.py +++ b/octopoes/octopoes/models/origin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from uuid import UUID @@ -21,7 +23,7 @@ class Origin(BaseModel): result: list[Reference] = Field(default_factory=list) task_id: UUID | None = None - def __sub__(self, other) -> set[Reference]: + def __sub__(self, other: Origin) -> set[Reference]: if isinstance(other, Origin): return set(self.result) - set(other.result) else: diff --git a/octopoes/octopoes/models/path.py b/octopoes/octopoes/models/path.py index 399a9945804..650c9e06f1d 100644 --- a/octopoes/octopoes/models/path.py +++ b/octopoes/octopoes/models/path.py @@ -42,7 +42,7 @@ def parse_step(cls, step: str) -> tuple[Direction, str, type[OOI] | None]: raise ValueError(f"Could not parse step: {step}") @classmethod - def calculate_step(cls, source_type: type[OOI], step: str): + def calculate_step(cls, source_type: type[OOI], step: str) -> Segment: direction, property_name, explicit_target_type = cls.parse_step(step) if explicit_target_type: @@ -90,7 +90,7 @@ def __eq__(self, other: object) -> bool: and self.property_name == other.property_name ) - def __str__(self): + def __str__(self) -> str: if self.direction == Direction.INCOMING: if self.target_type is None: raise ValueError("Direction cannot be incoming if target type is None") @@ -99,7 +99,7 @@ def __str__(self): else: return f"{self.property_name}" - def __repr__(self): + def __repr__(self) -> str: return str(self) @@ -108,7 +108,7 @@ def __init__(self, segments: list[Segment]): self.segments = segments @classmethod - def parse(cls, path: str): + def parse(cls, path: str) -> Path: start_type, step, *rest = path.split(".") segments = [Segment.calculate_step(type_by_name(start_type), step)] @@ -140,7 +140,7 @@ def __lt__(self, other): def __hash__(self): return hash(str(self)) - def __repr__(self): + def __repr__(self) -> str: return str(self) diff --git a/octopoes/octopoes/models/persistence.py b/octopoes/octopoes/models/persistence.py index 7c28805e807..82ea8ececd0 100644 --- a/octopoes/octopoes/models/persistence.py +++ b/octopoes/octopoes/models/persistence.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Any from pydantic import Field from pydantic.fields import FieldInfo @@ -11,7 +11,7 @@ def ReferenceField( *, max_issue_scan_level: int | None = None, max_inherit_scan_level: int | None = None, - **kwargs, + **kwargs: Any, ) -> FieldInfo: if not isinstance(object_type, str): object_type = object_type.get_object_type() diff --git a/octopoes/octopoes/models/tree.py b/octopoes/octopoes/models/tree.py index e6efdc3a3e2..06d7d92a70a 100644 --- a/octopoes/octopoes/models/tree.py +++ b/octopoes/octopoes/models/tree.py @@ -12,7 +12,7 @@ class ReferenceNode(BaseModel): reference: Reference children: dict[str, list[ReferenceNode]] - def filter_children(self, filter_fn: Callable[[ReferenceNode], bool]): + def filter_children(self, filter_fn: Callable[[ReferenceNode], bool]) -> bool: """ Mutable filter function to evict any children from the tree that do not adhere to the provided callback """ diff --git a/octopoes/octopoes/models/types.py b/octopoes/octopoes/models/types.py index 871620eb47d..7d073ce8f2f 100644 --- a/octopoes/octopoes/models/types.py +++ b/octopoes/octopoes/models/types.py @@ -2,6 +2,8 @@ from collections.abc import Iterator +from pydantic.fields import FieldInfo + from octopoes.models import OOI, Reference from octopoes.models.exception import TypeNotFound from octopoes.models.ooi.certificate import ( @@ -213,14 +215,14 @@ def to_concrete(object_types: set[type[OOI]]) -> set[type[OOI]]: return concrete_types -def type_by_name(type_name: str): +def type_by_name(type_name: str) -> type[OOI]: try: return next(t for t in ALL_TYPES if t.__name__ == type_name) except StopIteration: raise TypeNotFound -def related_object_type(field) -> type[OOI]: +def related_object_type(field: FieldInfo) -> type[OOI]: object_type: str | type[OOI] = field.json_schema_extra["object_type"] if isinstance(object_type, str): return type_by_name(object_type) diff --git a/octopoes/octopoes/repositories/ooi_repository.py b/octopoes/octopoes/repositories/ooi_repository.py index b582ffb6abd..7c432b6250a 100644 --- a/octopoes/octopoes/repositories/ooi_repository.py +++ b/octopoes/octopoes/repositories/ooi_repository.py @@ -134,7 +134,16 @@ def count_findings_by_severity(self, valid_time: datetime) -> Counter: raise NotImplementedError def list_findings( - self, severities, valid_time, exclude_muted, only_muted, offset, limit, search_string, order_by, asc_desc + self, + severities: set[RiskLevelSeverity], + valid_time: datetime, + exclude_muted: bool = False, + only_muted: bool = False, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, + search_string: str | None = None, + order_by: Literal["score", "finding_type"] = "score", + asc_desc: Literal["asc", "desc"] = "desc", ) -> Paginated[Finding]: raise NotImplementedError @@ -694,10 +703,10 @@ def list_findings( self, severities: set[RiskLevelSeverity], valid_time: datetime, - exclude_muted=False, - only_muted=False, - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + exclude_muted: bool = False, + only_muted: bool = False, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, search_string: str | None = None, order_by: Literal["score", "finding_type"] = "score", asc_desc: Literal["asc", "desc"] = "desc", diff --git a/octopoes/octopoes/repositories/origin_parameter_repository.py b/octopoes/octopoes/repositories/origin_parameter_repository.py index e6054f5a2e6..d4f335585c7 100644 --- a/octopoes/octopoes/repositories/origin_parameter_repository.py +++ b/octopoes/octopoes/repositories/origin_parameter_repository.py @@ -71,7 +71,7 @@ def list_by_origin(self, origin_id: set[str], valid_time: datetime) -> list[Orig results = self.session.client.query(query, valid_time=valid_time) return [self.deserialize(r[0]) for r in results] - def list_by_reference(self, reference: Reference, valid_time: datetime): + def list_by_reference(self, reference: Reference, valid_time: datetime) -> list[OriginParameter]: query = generate_pull_query( FieldSet.ALL_FIELDS, {"reference": str(reference), "type": OriginParameter.__name__} ) diff --git a/octopoes/octopoes/repositories/scan_profile_repository.py b/octopoes/octopoes/repositories/scan_profile_repository.py index e1b8fcf2bd3..c27954e0009 100644 --- a/octopoes/octopoes/repositories/scan_profile_repository.py +++ b/octopoes/octopoes/repositories/scan_profile_repository.py @@ -50,7 +50,7 @@ def commit(self): self.session.commit() @classmethod - def format_id(cls, ooi_reference: Reference): + def format_id(cls, ooi_reference: Reference) -> str: return f"{cls.object_type}|{ooi_reference}" @classmethod diff --git a/octopoes/octopoes/tasks/tasks.py b/octopoes/octopoes/tasks/tasks.py index 1434993e3c6..b9f734be4ec 100644 --- a/octopoes/octopoes/tasks/tasks.py +++ b/octopoes/octopoes/tasks/tasks.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from logging import config from pathlib import Path +from typing import Any import structlog import yaml @@ -65,7 +66,7 @@ def init_worker(**kwargs): @app.task(queue=QUEUE_NAME_OCTOPOES) -def handle_event(event: dict): +def handle_event(event: dict) -> None: try: parsed_event: DBEvent = TypeAdapter(DBEventType).validate_python(event) @@ -96,7 +97,7 @@ def schedule_scan_profile_recalculations(): @app.task(queue=QUEUE_NAME_OCTOPOES) -def recalculate_scan_profiles(org: str, *args, **kwargs): +def recalculate_scan_profiles(org: str, *args: Any, **kwargs: Any) -> None: session = XTDBSession(get_xtdb_client(str(settings.xtdb_uri), org)) octopoes = bootstrap_octopoes(settings, org, session) diff --git a/octopoes/octopoes/xtdb/client.py b/octopoes/octopoes/xtdb/client.py index 57561c43068..388dae2dfbc 100644 --- a/octopoes/octopoes/xtdb/client.py +++ b/octopoes/octopoes/xtdb/client.py @@ -56,7 +56,7 @@ def _get_xtdb_http_session(base_url: str) -> httpx.Client: class XTDBHTTPClient: - def __init__(self, base_url, client: str): + def __init__(self, base_url: str, client: str): self._client = client self._session = _get_xtdb_http_session(base_url) @@ -173,7 +173,7 @@ def export_transactions(self): self._verify_response(res) return res.json() - def sync(self, timeout: int | None = None): + def sync(self, timeout: int | None = None) -> Any: params = {} if timeout is not None: @@ -198,10 +198,10 @@ def __enter__(self): def __exit__(self, _exc_type: type[Exception], _exc_value: str, _exc_traceback: str) -> None: self.commit() - def add(self, operation: Operation): + def add(self, operation: Operation) -> None: self._operations.append(operation) - def put(self, document: str | dict[str, Any], valid_time: datetime): + def put(self, document: str | dict[str, Any], valid_time: datetime) -> None: self.add((OperationType.PUT, document, valid_time)) def commit(self) -> None: @@ -219,5 +219,5 @@ def commit(self) -> None: logger.info("Called %s callbacks after committing XTDBSession", len(self.post_commit_callbacks)) self.post_commit_callbacks = [] - def listen_post_commit(self, callback: Callable[[], None]): + def listen_post_commit(self, callback: Callable[[], None]) -> None: self.post_commit_callbacks.append(callback) diff --git a/octopoes/octopoes/xtdb/query.py b/octopoes/octopoes/xtdb/query.py index 3be2464028f..147c58174a0 100644 --- a/octopoes/octopoes/xtdb/query.py +++ b/octopoes/octopoes/xtdb/query.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, field from uuid import UUID, uuid4 @@ -76,13 +78,13 @@ class Query: _offset: int | None = None _order_by: tuple[Aliased, bool] | None = None - def where(self, ooi_type: Ref, **kwargs) -> "Query": + def where(self, ooi_type: Ref, **kwargs: Ref | str | set[str] | bool) -> Query: for field_name, value in kwargs.items(): self._where_field_is(ooi_type, field_name, value) return self - def where_in(self, ooi_type: Ref, **kwargs: list[str]) -> "Query": + def where_in(self, ooi_type: Ref, **kwargs: list[str]) -> Query: """Allows for filtering on multiple values for a specific field.""" for field_name, values in kwargs.items(): @@ -94,7 +96,7 @@ def format(self) -> str: return self._compile(separator="\n ") @classmethod - def from_path(cls, path: Path) -> "Query": + def from_path(cls, path: Path) -> Query: """ Create a query from a Path. @@ -147,14 +149,14 @@ def from_path(cls, path: Path) -> "Query": return query - def pull(self, ooi_type: Ref, *, fields: str = "[*]") -> "Query": + def pull(self, ooi_type: Ref, *, fields: str = "[*]") -> Query: """By default, we pull the target type. But when using find, count, etc., you have to pull explicitly.""" self._find_clauses.append(f"(pull {self._get_object_alias(ooi_type)} {fields})") return self - def find(self, item: Ref, *, index: int | None = None) -> "Query": + def find(self, item: Ref, *, index: int | None = None) -> Query: """Add a find clause, so we can select specific fields in a query to be returned as well.""" if index is None: @@ -164,27 +166,27 @@ def find(self, item: Ref, *, index: int | None = None) -> "Query": return self - def count(self, ooi_type: Ref) -> "Query": + def count(self, ooi_type: Ref) -> Query: self._find_clauses.append(f"(count {self._get_object_alias(ooi_type)})") return self - def limit(self, limit: int) -> "Query": + def limit(self, limit: int) -> Query: self._limit = limit return self - def offset(self, offset: int) -> "Query": + def offset(self, offset: int) -> Query: self._offset = offset return self - def order_by(self, ref: Aliased, ascending: bool = True) -> "Query": + def order_by(self, ref: Aliased, ascending: bool = True) -> Query: self._order_by = (ref, ascending) return self - def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str]) -> None: + def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str] | bool) -> None: """ We need isinstance(value, type) checks to verify value is an OOIType, as issubclass() fails on non-classes: @@ -321,7 +323,7 @@ def _assert_type(self, ref: Ref, ooi_type: type[OOI]) -> str: def _to_object_type_statement(self, ref: Ref, other_type: type[OOI]) -> str: return f'[ {self._get_object_alias(ref)} :object_type "{other_type.get_object_type()}" ]' - def _compile_where_clauses(self, *, separator=" ") -> str: + def _compile_where_clauses(self, *, separator: str = " ") -> str: """Sorted and deduplicated where clauses, since they are both idempotent and commutative""" return separator + separator.join(sorted(set(self._where_clauses))) @@ -329,7 +331,7 @@ def _compile_where_clauses(self, *, separator=" ") -> str: def _compile_find_clauses(self) -> str: return " ".join(self._find_clauses) - def _compile(self, *, separator=" ") -> str: + def _compile(self, *, separator: str = " ") -> str: result_ooi_type = self.result_type.type if isinstance(self.result_type, Aliased) else self.result_type self._where_clauses.append(self._assert_type(self.result_type, result_ooi_type)) @@ -365,7 +367,7 @@ def _get_object_alias(self, object_type: Ref) -> str: def __str__(self) -> str: return self._compile() - def __eq__(self, other: object): + def __eq__(self, other: object) -> bool: if not isinstance(other, Query): return NotImplemented diff --git a/octopoes/octopoes/xtdb/query_builder.py b/octopoes/octopoes/xtdb/query_builder.py index 172cfc745c2..54335f5c8d8 100644 --- a/octopoes/octopoes/xtdb/query_builder.py +++ b/octopoes/octopoes/xtdb/query_builder.py @@ -2,7 +2,8 @@ from collections.abc import Iterable, Mapping from typing import Any -from octopoes.xtdb.related_field_generator import FieldSet, RelatedFieldNode +from octopoes.xtdb import FieldSet +from octopoes.xtdb.related_field_generator import RelatedFieldNode def join_csv(values: Iterable[Any]) -> str: diff --git a/octopoes/octopoes/xtdb/related_field_generator.py b/octopoes/octopoes/xtdb/related_field_generator.py index 1b44112338c..4a26f9a2114 100644 --- a/octopoes/octopoes/xtdb/related_field_generator.py +++ b/octopoes/octopoes/xtdb/related_field_generator.py @@ -51,7 +51,7 @@ def construct_incoming_relations(self): RelatedFieldNode(self.data_model, {foreign_object_type}, self.path + (foreign_key,)) ) - def build_tree(self, depth: int): + def build_tree(self, depth: int) -> None: if depth > 0: self.construct_outgoing_relations() for child_node in self.relations_out.values(): @@ -61,7 +61,7 @@ def build_tree(self, depth: int): for child_node in self.relations_in.values(): child_node.build_tree(depth - 1) - def generate_field(self, field_set: FieldSet, pk_prefix: str): + def generate_field(self, field_set: FieldSet, pk_prefix: str) -> str: queried_fields = pk_prefix if field_set is FieldSet.ONLY_ID else "*" """ Output dicts in XTDB Query Language @@ -105,10 +105,10 @@ def search_nodes(self, search_object_types=set[str]): # Match self return not self.object_types.isdisjoint(search_object_types) - def __repr__(self): + def __repr__(self) -> str: return f"QueryNode[{self}]" - def __str__(self): + def __str__(self) -> str: return ",".join(self.object_types) def __eq__(self, other): diff --git a/pyproject.toml b/pyproject.toml index 021561d838d..7af9b569c1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,6 @@ python_version = "3.10" plugins = ["pydantic.mypy"] strict = true -follow_imports = "skip" -warn_unused_ignores = false # This gives false positives in pre-commit as long as we don't enable follow imports disallow_subclassing_any = false disallow_untyped_decorators = false # Needed for FastAPI decorators disallow_any_generics = false @@ -14,8 +12,17 @@ no_implicit_reexport = false warn_return_any = false [[tool.mypy.overrides]] -module = ["httpx.*"] -follow_imports = "normal" +# Following pydantic imports currently gives 2000 errors +module = ["pydantic.*"] +follow_imports = "skip" + +[[tool.mypy.overrides]] +module = ["bytes.*", "cveapi.*"] +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_incomplete_defs = true +no_implicit_reexport = true [tool.setuptools_scm] write_to = "_version.py" diff --git a/rocky/account/admin.py b/rocky/account/admin.py index 952d6ff76fc..e5a7e07a8d7 100644 --- a/rocky/account/admin.py +++ b/rocky/account/admin.py @@ -10,7 +10,6 @@ @admin.register(User) class KATUserAdmin(UserAdmin): - model = User list_display = ("email", "is_staff", "is_active") fieldsets = ( (None, {"fields": ("email", "password", "full_name")}), diff --git a/rocky/account/forms/__init__.py b/rocky/account/forms/__init__.py index 448a075559d..7da11d44645 100644 --- a/rocky/account/forms/__init__.py +++ b/rocky/account/forms/__init__.py @@ -11,3 +11,19 @@ from account.forms.login import LoginForm from account.forms.password_reset import PasswordResetForm from account.forms.token import TwoFactorBackupTokenForm, TwoFactorSetupTokenForm, TwoFactorVerifyTokenForm + +__all__ = [ + "AccountTypeSelectForm", + "IndemnificationAddForm", + "MemberRegistrationForm", + "OnboardingOrganizationUpdateForm", + "OrganizationForm", + "OrganizationMemberEditForm", + "OrganizationUpdateForm", + "SetPasswordForm", + "LoginForm", + "PasswordResetForm", + "TwoFactorBackupTokenForm", + "TwoFactorSetupTokenForm", + "TwoFactorVerifyTokenForm", +] diff --git a/rocky/account/forms/organization.py b/rocky/account/forms/organization.py index 64a6b426f60..7bc8b9e72ec 100644 --- a/rocky/account/forms/organization.py +++ b/rocky/account/forms/organization.py @@ -28,11 +28,11 @@ def populate_dropdown_list(self, user): organizations.append([organization.code, organization.name]) if organizations: - props = { - "required": True, - "label": _("Organizations"), - "help_text": _("The organization from which to clone settings."), - "error_messages": self.error_messages, - } - self.fields["organization"] = forms.ChoiceField(**props) + self.fields["organization"] = forms.ChoiceField( + required=True, + label=_("Organizations"), + help_text=_("The organization from which to clone settings."), + error_messages=self.error_messages, + ) + self.fields["organization"].choices = [BLANK_CHOICE] + organizations diff --git a/rocky/account/mixins.py b/rocky/account/mixins.py index 00dde712924..1ec7c87d61c 100644 --- a/rocky/account/mixins.py +++ b/rocky/account/mixins.py @@ -9,6 +9,7 @@ from django.http import Http404 from django.utils.translation import gettext_lazy as _ from django.views import View +from django.views.generic.base import ContextMixin from katalogus.client import KATalogus, get_katalogus from rest_framework.exceptions import ValidationError from rest_framework.request import Request @@ -31,7 +32,7 @@ class OrganizationPermLookupDict: def __init__(self, organization_member, app_label): self.organization_member, self.app_label = organization_member, app_label - def __repr__(self): + def __repr__(self) -> str: return str(self.organization_member.get_all_permissions) def __getitem__(self, perm_name): @@ -50,7 +51,7 @@ class OrganizationPermWrapper: def __init__(self, organization_member): self.organization_member = organization_member - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__qualname__}({self.organization_member!r})" def __getitem__(self, app_label): @@ -71,7 +72,7 @@ def __contains__(self, perm_name): return self[app_label][perm_name] -class OrganizationView(View): +class OrganizationView(ContextMixin, View): def setup(self, request, *args, **kwargs): super().setup(request, *args, **kwargs) diff --git a/rocky/account/models.py b/rocky/account/models.py index 2552a560146..55d7f75322e 100644 --- a/rocky/account/models.py +++ b/rocky/account/models.py @@ -138,7 +138,7 @@ class Meta: EVENT_CODES = {"created": 900111, "updated": 900122, "deleted": 900123} - def __str__(self): + def __str__(self) -> str: return f"{self.name} ({self.user})" def generate_new_token(self) -> str: diff --git a/rocky/account/views/account.py b/rocky/account/views/account.py index c16b78188ed..b287562b073 100644 --- a/rocky/account/views/account.py +++ b/rocky/account/views/account.py @@ -23,7 +23,7 @@ def post(self, request, *args, **kwargs): # Mypy doesn't have the information to understand this return self.get(request, *args, **kwargs) # type: ignore[attr-defined] - def handle_page_action(self, action: str): + def handle_page_action(self, action: str) -> None: if action == PageActions.ACCEPT_CLEARANCE.value: self.organization_member.acknowledged_clearance_level = self.organization_member.trusted_clearance_level elif action == PageActions.WITHDRAW_ACCEPTANCE.value: diff --git a/rocky/crisis_room/views.py b/rocky/crisis_room/views.py index 7a39cd7abcc..4a2f8163882 100644 --- a/rocky/crisis_room/views.py +++ b/rocky/crisis_room/views.py @@ -14,8 +14,7 @@ from octopoes.connector import ConnectorException from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models.ooi.findings import RiskLevelSeverity -from rocky.views.mixins import ObservedAtMixin -from rocky.views.ooi_view import ConnectorFormMixin +from rocky.views.mixins import ConnectorFormMixin, ObservedAtMixin logger = structlog.get_logger(__name__) diff --git a/rocky/katalogus/client.py b/rocky/katalogus/client.py index b10dba57d12..d9ace3a80ce 100644 --- a/rocky/katalogus/client.py +++ b/rocky/katalogus/client.py @@ -56,7 +56,7 @@ class Plugin(BaseModel): # make sense out of it: for which organization is this plugin in fact enabled? enabled: bool - def can_scan(self, member) -> bool: + def can_scan(self, member: OrganizationMember) -> bool: return member.has_perm("tools.can_scan_organization") @@ -73,7 +73,7 @@ class Boefje(Plugin): # use a custom field_serializer for `consumes` @field_serializer("consumes") - def serialize_consumes(self, consumes: set[type[OOI]]): + def serialize_consumes(self, consumes: set[type[OOI]]) -> set[str]: return {ooi_class.get_ooi_type() for ooi_class in consumes} @field_validator("boefje_schema") @@ -89,7 +89,7 @@ def json_schema_valid(cls, boefje_schema: dict) -> dict | None: return boefje_schema - def can_scan(self, member) -> bool: + def can_scan(self, member: OrganizationMember) -> bool: return super().can_scan(member) and member.has_clearance_level(self.scan_level.value) @@ -99,7 +99,7 @@ class Normalizer(Plugin): # use a custom field_serializer for `produces` @field_serializer("produces") - def serialize_produces(self, produces: set[type[OOI]]): + def serialize_produces(self, produces: set[type[OOI]]) -> set[str]: return {ooi_class.get_ooi_type() for ooi_class in produces} diff --git a/rocky/katalogus/forms/plugin_settings.py b/rocky/katalogus/forms/plugin_settings.py index 303f0df4816..22a7b38b282 100644 --- a/rocky/katalogus/forms/plugin_settings.py +++ b/rocky/katalogus/forms/plugin_settings.py @@ -1,3 +1,5 @@ +from typing import Any + from django import forms from django.utils.translation import gettext_lazy as _ from jsonschema.validators import Draft202012Validator @@ -11,7 +13,7 @@ class PluginSchemaForm(forms.Form): error_messages = {"required": _("This field is required.")} - def __init__(self, plugin_schema: dict, values: dict, *args, **kwargs): + def __init__(self, plugin_schema: dict, values: dict, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.plugin_schema = plugin_schema self.values = values diff --git a/rocky/katalogus/views/boefje_setup.py b/rocky/katalogus/views/boefje_setup.py index 6d386a55961..af868f229de 100644 --- a/rocky/katalogus/views/boefje_setup.py +++ b/rocky/katalogus/views/boefje_setup.py @@ -23,8 +23,8 @@ class BoefjeSetupView(OrganizationPermissionRequiredMixin, OrganizationView, For def setup(self, request, *args, **kwargs): super().setup(request, *args, **kwargs) - self.plugin_id = uuid.uuid4() - self.created = str(datetime.now()) + self.plugin_id = str(uuid.uuid4()) + self.created: str | None = str(datetime.now()) self.query_params = urlencode({"new_variant": True}) def get_success_url(self) -> str: @@ -209,7 +209,7 @@ def get_context_data(self, **kwargs): return context -def create_boefje_with_form_data(form_data, plugin_id: str, created: str): +def create_boefje_with_form_data(form_data, plugin_id: str, created: str | None): arguments = [] if not form_data["oci_arguments"] else form_data["oci_arguments"].split() consumes = [] if not form_data["consumes"] else form_data["consumes"].strip("[]").replace("'", "").split(", ") produces = [] if not form_data["produces"] else form_data["produces"].split(",") diff --git a/rocky/katalogus/views/mixins.py b/rocky/katalogus/views/mixins.py index 022cb16ae80..b742bf91f58 100644 --- a/rocky/katalogus/views/mixins.py +++ b/rocky/katalogus/views/mixins.py @@ -1,7 +1,9 @@ +from typing import Any + import structlog from account.mixins import OrganizationView from django.contrib import messages -from django.http import Http404 +from django.http import Http404, HttpRequest from django.shortcuts import redirect from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -17,7 +19,7 @@ class SinglePluginView(OrganizationView): katalogus_client: KATalogus plugin: Plugin - def setup(self, request, *args, plugin_id: str, **kwargs): + def setup(self, request: HttpRequest, *args: Any, plugin_id: str, **kwargs: Any) -> None: """ Prepare organization info and KAT-alogus API client. """ diff --git a/rocky/onboarding/view_helpers.py b/rocky/onboarding/view_helpers.py index 78146742e49..d7f57d41f54 100644 --- a/rocky/onboarding/view_helpers.py +++ b/rocky/onboarding/view_helpers.py @@ -2,7 +2,7 @@ from django.utils.translation import gettext_lazy as _ from reports.views.base import get_selection from tools.models import Organization -from tools.view_helpers import BreadcrumbsMixin, StepsMixin +from tools.view_helpers import Breadcrumb, BreadcrumbsMixin, StepsMixin ONBOARDING_PERMISSIONS = ( "tools.can_scan_organization", @@ -77,7 +77,7 @@ def build_steps(self): class OnboardingBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("step_introduction", kwargs={"organization_code": self.organization.code}), diff --git a/rocky/onboarding/views.py b/rocky/onboarding/views.py index 4da8bf5b43e..783fd0599e2 100644 --- a/rocky/onboarding/views.py +++ b/rocky/onboarding/views.py @@ -400,11 +400,8 @@ class OnboardingOrganizationSetupView(PermissionRequiredMixin, IntroductionRegis permission_required = "tools.add_organization" def get(self, request, *args, **kwargs): - members = OrganizationMember.objects.filter(user=self.request.user) - if members: - return redirect( - reverse("step_organization_update", kwargs={"organization_code": members.first().organization.code}) - ) + if member := OrganizationMember.objects.filter(user=self.request.user).first(): + return redirect(reverse("step_organization_update", kwargs={"organization_code": member.organization.code})) return super().get(request, *args, **kwargs) def post(self, request, *args, **kwargs): diff --git a/rocky/reports/forms.py b/rocky/reports/forms.py index e9a6e58815e..7ee01c1ee1d 100644 --- a/rocky/reports/forms.py +++ b/rocky/reports/forms.py @@ -1,4 +1,5 @@ from datetime import datetime, timezone +from typing import Any from django import forms from django.utils.translation import gettext_lazy as _ @@ -12,7 +13,7 @@ class OOITypeMultiCheckboxForReportForm(BaseRockyForm): label=_("Filter by OOI types"), required=False, widget=forms.CheckboxSelectMultiple ) - def __init__(self, ooi_types: list[str], *args, **kwargs): + def __init__(self, ooi_types: list[str], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.fields["ooi_type"].choices = ((ooi_type, ooi_type) for ooi_type in ooi_types) @@ -22,7 +23,7 @@ class ReportTypeMultiselectForm(BaseRockyForm): label=_("Report types"), required=False, widget=forms.CheckboxSelectMultiple ) - def __init__(self, report_types: set[Report], *args, **kwargs): + def __init__(self, report_types: set[Report], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) report_types_choices = ((report_type.id, report_type.name) for report_type in report_types) self.fields["report_type"].choices = report_types_choices diff --git a/rocky/reports/report_types/aggregate_organisation_report/report.py b/rocky/reports/report_types/aggregate_organisation_report/report.py index 46f1ae506cf..de7f1fbc61f 100644 --- a/rocky/reports/report_types/aggregate_organisation_report/report.py +++ b/rocky/reports/report_types/aggregate_organisation_report/report.py @@ -2,6 +2,7 @@ from typing import Any import structlog +from django.utils.translation import gettext_lazy as _ from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI @@ -24,7 +25,7 @@ class AggregateOrganisationReport(AggregateReport): id = "aggregate-organisation-report" - name = "Aggregate Organisation Report" + name = _("Aggregate Organisation Report") description = "Aggregate Organisation Report" reports = { "required": [SystemReport], @@ -411,7 +412,9 @@ def is_mail_compliant(result): "config_oois": config_oois, } - def collect_system_specific_data(self, data, services, system_type: str, report_id: str) -> dict[str, Any]: + def collect_system_specific_data( + self, data: dict, services: dict, system_type: str, report_id: str + ) -> dict[str, Any]: """Given a system, return a list of report data from the right sub-reports based on the related report_id""" report_data: dict[str, Any] = {} diff --git a/rocky/reports/report_types/definitions.py b/rocky/reports/report_types/definitions.py index ec7c4be9ed6..05262d3f3c0 100644 --- a/rocky/reports/report_types/definitions.py +++ b/rocky/reports/report_types/definitions.py @@ -3,6 +3,8 @@ from pathlib import Path from typing import Any, TypedDict, TypeVar +from django.utils.functional import Promise + from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname @@ -37,8 +39,8 @@ def report_plugins_union(report_types: list[type["BaseReport"]]) -> ReportPlugin class BaseReport: id: str - name: str - description: str + name: Promise + description: Promise template_path: str = "report.html" plugins: ReportPlugins input_ooi_types: set[type[OOI]] diff --git a/rocky/reports/report_types/multi_organization_report/report.py b/rocky/reports/report_types/multi_organization_report/report.py index bc9d66ba5b2..1fbc91a7c4b 100644 --- a/rocky/reports/report_types/multi_organization_report/report.py +++ b/rocky/reports/report_types/multi_organization_report/report.py @@ -255,7 +255,9 @@ def post_process_data(self, data: dict[str, Any]) -> dict[str, Any]: } -def collect_report_data(connector: OctopoesAPIConnector, input_ooi_references: list[str], observed_at: datetime): +def collect_report_data( + connector: OctopoesAPIConnector, input_ooi_references: list[str], observed_at: datetime +) -> dict: report_data = {} for ooi in [x for x in input_ooi_references if Reference.from_str(x).class_type == ReportData]: report_data[ooi] = connector.get(Reference.from_str(ooi), observed_at).model_dump() diff --git a/rocky/reports/report_types/name_server_report/report.py b/rocky/reports/report_types/name_server_report/report.py index 6d0c0a84ec3..cdec60e352c 100644 --- a/rocky/reports/report_types/name_server_report/report.py +++ b/rocky/reports/report_types/name_server_report/report.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Iterable from dataclasses import dataclass, field from datetime import datetime @@ -37,13 +39,13 @@ def has_dnssec(self): def has_valid_dnssec(self): return sum([check.has_valid_dnssec for check in self.checks]) - def __bool__(self): + def __bool__(self) -> bool: return all(bool(check) for check in self.checks) - def __len__(self): + def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "NameServerChecks"): + def __add__(self, other: NameServerChecks) -> NameServerChecks: return NameServerChecks(checks=self.checks + other.checks) diff --git a/rocky/reports/report_types/web_system_report/report.py b/rocky/reports/report_types/web_system_report/report.py index 20a6d22df70..47f62c760c9 100644 --- a/rocky/reports/report_types/web_system_report/report.py +++ b/rocky/reports/report_types/web_system_report/report.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Iterable from dataclasses import dataclass, field from datetime import datetime @@ -77,13 +79,13 @@ def certificates_not_expired(self): def certificates_not_expiring_soon(self): return sum([check.certificates_not_expiring_soon for check in self.checks]) - def __bool__(self): + def __bool__(self) -> bool: return all(bool(check) for check in self.checks) - def __len__(self): + def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "WebChecks"): + def __add__(self, other: WebChecks) -> WebChecks: return WebChecks(checks=self.checks + other.checks) diff --git a/rocky/reports/templates/report_overview/report_history_table.html b/rocky/reports/templates/report_overview/report_history_table.html index 9452ad6cb24..cee41f0f203 100644 --- a/rocky/reports/templates/report_overview/report_history_table.html +++ b/rocky/reports/templates/report_overview/report_history_table.html @@ -154,11 +154,11 @@
- {% blocktranslate count counter=report.total_children_reports %} - This report consist of {{counter}} subreport with the following report type and object. - {% plural %} - This report consist of {{counter}} subreports with the following report types and objects. - {% endblocktranslate %} + {% blocktranslate trimmed count counter=report.total_children_reports %} + This report consist of {{ counter }} subreport with the following report type and object. + {% plural %} + This report consist of {{ counter }} subreports with the following report types and objects. + {% endblocktranslate %}