Skip to content

Commit

Permalink
be able to get codemods by tool from registry (#920)
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna authored Nov 18, 2024
1 parent 2cf33fd commit f00d4c0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/codemodder/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import re
from collections import defaultdict
from dataclasses import dataclass
from importlib.metadata import EntryPoint, entry_points
from itertools import chain
Expand Down Expand Up @@ -36,6 +37,7 @@ class CodemodRegistry:

def __init__(self):
self._codemods_by_id = {}
self._codemods_by_tool = defaultdict(list)
self._default_include_paths = set()

@property
Expand All @@ -50,6 +52,9 @@ def codemods(self):
def default_include_paths(self) -> list[str]:
return list(self._default_include_paths)

def codemods_by_tool(self, tool_name: str) -> list[BaseCodemod]:
return self._codemods_by_tool.get(tool_name, [])

def add_codemod_collection(self, collection: CodemodCollection):
for codemod in collection.codemods:
wrapper = codemod() if isinstance(codemod, type) else codemod
Expand All @@ -59,6 +64,7 @@ def add_codemod_collection(self, collection: CodemodCollection):
)

self._codemods_by_id[wrapper.id] = wrapper
self._codemods_by_tool[collection.origin].append(wrapper)
self._default_include_paths.update(
chain(
*[
Expand Down
27 changes: 26 additions & 1 deletion tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from codemodder.registry import CodemodCollection, CodemodRegistry
from codemodder.registry import (
CodemodCollection,
CodemodRegistry,
load_registered_codemods,
)


def test_default_extensions(mocker):
Expand All @@ -18,3 +22,24 @@ def test_default_extensions(mocker):
"*.py",
"*.txt",
]


def test_codemods_by_tool(mocker):
registry = CodemodRegistry()
assert not registry._codemods_by_tool

CodemodA = mocker.MagicMock()
CodemodB = mocker.MagicMock()

registry.add_codemod_collection(
CodemodCollection(origin="origin", codemods=[CodemodA, CodemodB])
)

assert len(registry.codemods_by_tool("origin")) == 2


def test_current_codemods_by_tool():
codemod_registry = load_registered_codemods()
assert len(codemod_registry.codemods_by_tool("sonar")) > 0
assert len(codemod_registry.codemods_by_tool("semgrep")) > 0
assert len(codemod_registry.codemods_by_tool("pixee")) > 0

0 comments on commit f00d4c0

Please sign in to comment.