diff --git a/src/codemodder/registry.py b/src/codemodder/registry.py index 2f4af598..971ac5eb 100644 --- a/src/codemodder/registry.py +++ b/src/codemodder/registry.py @@ -2,6 +2,7 @@ import os import re +from collections import defaultdict from dataclasses import dataclass from importlib.metadata import EntryPoint, entry_points from itertools import chain @@ -36,6 +37,7 @@ class CodemodRegistry: def __init__(self): self._codemods_by_id = {} + self._codemods_by_tool = defaultdict(list) self._default_include_paths = set() @property @@ -50,6 +52,9 @@ def codemods(self): def default_include_paths(self) -> list[str]: return list(self._default_include_paths) + def codemods_by_tool(self, tool_name: str) -> list[BaseCodemod]: + return self._codemods_by_tool.get(tool_name, []) + def add_codemod_collection(self, collection: CodemodCollection): for codemod in collection.codemods: wrapper = codemod() if isinstance(codemod, type) else codemod @@ -59,6 +64,7 @@ def add_codemod_collection(self, collection: CodemodCollection): ) self._codemods_by_id[wrapper.id] = wrapper + self._codemods_by_tool[collection.origin].append(wrapper) self._default_include_paths.update( chain( *[ diff --git a/tests/test_registry.py b/tests/test_registry.py index 648b7a95..b846a9dc 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,4 +1,8 @@ -from codemodder.registry import CodemodCollection, CodemodRegistry +from codemodder.registry import ( + CodemodCollection, + CodemodRegistry, + load_registered_codemods, +) def test_default_extensions(mocker): @@ -18,3 +22,24 @@ def test_default_extensions(mocker): "*.py", "*.txt", ] + + +def test_codemods_by_tool(mocker): + registry = CodemodRegistry() + assert not registry._codemods_by_tool + + CodemodA = mocker.MagicMock() + CodemodB = mocker.MagicMock() + + registry.add_codemod_collection( + CodemodCollection(origin="origin", codemods=[CodemodA, CodemodB]) + ) + + assert len(registry.codemods_by_tool("origin")) == 2 + + +def test_current_codemods_by_tool(): + codemod_registry = load_registered_codemods() + assert len(codemod_registry.codemods_by_tool("sonar")) > 0 + assert len(codemod_registry.codemods_by_tool("semgrep")) > 0 + assert len(codemod_registry.codemods_by_tool("pixee")) > 0