Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Kotlin and C++ support #969

Merged
merged 5 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions patchwork/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import signal
import traceback
from collections import deque
from contextlib import nullcontext
from pathlib import Path
from typing import Any
from contextlib import nullcontext

import click
import yaml
Expand Down Expand Up @@ -130,7 +130,7 @@ def sigint_handler(signum, frame):
],
case_sensitive=False,
),
is_eager=True
is_eager=True,
)
@click.argument("patchflow", nargs=1, required=True)
@click.argument("opts", nargs=-1, type=click.UNPROCESSED, required=False)
Expand All @@ -154,7 +154,7 @@ def cli(
data_format: str,
patched_api_key: str | None,
disable_telemetry: bool,
debug: bool
debug: bool,
):
setup_cli()

Expand All @@ -169,7 +169,7 @@ def cli(
possbile_module_paths = deque((module_path,))

panel = logger.panel("Initializing Patchwork CLI") if debug else nullcontext()

with panel:
inputs = {}
if patched_api_key is not None:
Expand Down Expand Up @@ -227,24 +227,24 @@ def cli(
else:
# treat --key=value as a key-value pair
inputs[key] = value
patchflow_panel = nullcontext() if debug else logger.panel(f"Patchflow {patchflow} inputs")

patchflow_panel = nullcontext() if debug else logger.panel(f"Patchflow {patchflow} inputs")

with patchflow_panel as _:
if debug is True:
logger.info("DEBUGGING ENABLED. INPUTS WILL BE SHOWN BEFORE EACH STEP BEFORE PROCEEDING TO RUN IT.")
try:
patched = PatchedClient(inputs.get("patched_api_key"))
if not disable_telemetry:
patched.send_public_telemetry(patchflow_name, inputs)

with patched.patched_telemetry(patchflow_name, {}):
patchflow_instance = patchflow_class(inputs)
patchflow_instance.run()
except Exception as e:
logger.debug(traceback.format_exc())
logger.error(f"Error running patchflow {patchflow}: {e}")
exit(1)
if debug is True:
logger.info("DEBUGGING ENABLED. INPUTS WILL BE SHOWN BEFORE EACH STEP BEFORE PROCEEDING TO RUN IT.")
try:
patched = PatchedClient(inputs.get("patched_api_key"))
if not disable_telemetry:
patched.send_public_telemetry(patchflow_name, inputs)

with patched.patched_telemetry(patchflow_name, {}):
patchflow_instance = patchflow_class(inputs)
patchflow_instance.run()
except Exception as e:
logger.debug(traceback.format_exc())
logger.error(f"Error running patchflow {patchflow}: {e}")
exit(1)

if output is not None:
serialize = _DATA_FORMAT_MAPPING.get(data_format, json.dumps)
Expand Down
21 changes: 20 additions & 1 deletion patchwork/common/context_strategy/context_strategies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from .cpp import CppBlockStrategy, CppClassStrategy, CppMethodStrategy
from .generic import FullFileStrategy, NoopStrategy
from .java import JavaBlockStrategy, JavaClassStrategy, JavaMethodStrategy
from .javascript import (
Expand All @@ -10,6 +11,7 @@
JsxClassStrategy,
JsxFunctionStrategy,
)
from .kotlin import KotlinClassStrategy, KotlinMethodStrategy
from .protocol import ContextStrategyProtocol
from .python import PythonBlockStrategy, PythonFunctionStrategy

Expand All @@ -24,6 +26,14 @@ class ContextStrategies:
JAVA_CLASS = "JAVA_CLASS"
JAVA_METHOD = "JAVA_METHOD"
JAVA_BLOCK = "JAVA_BLOCK"
# Cpp strategies
CPP_CLASS = "CPP_CLASS"
CPP_METHOD = "CPP_METHOD"
CPP_BLOCK = "CPP_BLOCK"
# Java strategies
KOTLIN_CLASS = "KOTLIN_CLASS"
KOTLIN_METHOD = "KOTLIN_METHOD"
KOTLIN_BLOCK = "KOTLIN_BLOCK"
# JavaScript strategies
JAVASCRIPT_CLASS = "JAVASCRIPT_CLASS"
JAVASCRIPT_FUNCTION = "JAVASCRIPT_FUNCTION"
Expand All @@ -35,19 +45,23 @@ class ContextStrategies:

PYTHON_PARTIAL_STRATEGIES = [PYTHON_FUNCTION, PYTHON_BLOCK]
JAVA_PARTIAL_STRATEGIES = [JAVA_CLASS, JAVA_METHOD, JAVA_BLOCK]
CPP_PARTIAL_STRATEGIES = [CPP_CLASS, CPP_METHOD, CPP_BLOCK]
KOTLIN_PARTIAL_STRATEGIES = [KOTLIN_CLASS, KOTLIN_METHOD, KOTLIN_BLOCK]
JAVASCRIPT_PARTIAL_STRATEGIES = [JAVASCRIPT_CLASS, JAVASCRIPT_FUNCTION, JAVASCRIPT_BLOCK]
JSX_PARTIAL_STRATEGIES = [JSX_CLASS, JSX_FUNCTION, JSX_BLOCK]

ALL = [
FULL_FILE,
*PYTHON_PARTIAL_STRATEGIES,
*JAVA_PARTIAL_STRATEGIES,
*CPP_PARTIAL_STRATEGIES,
*KOTLIN_PARTIAL_STRATEGIES,
*JAVASCRIPT_PARTIAL_STRATEGIES,
*JSX_PARTIAL_STRATEGIES,
NOOP,
]

FUNCTION = [PYTHON_FUNCTION, JAVA_METHOD, JAVASCRIPT_FUNCTION, JSX_FUNCTION]
FUNCTION = [PYTHON_FUNCTION, JAVA_METHOD, CPP_METHOD, KOTLIN_METHOD, JAVASCRIPT_FUNCTION, JSX_FUNCTION]

__MAPPING: dict[str, ContextStrategyProtocol] = {
FULL_FILE: FullFileStrategy(),
Expand All @@ -57,6 +71,11 @@ class ContextStrategies:
JAVA_CLASS: JavaClassStrategy(),
JAVA_METHOD: JavaMethodStrategy(),
JAVA_BLOCK: JavaBlockStrategy(),
CPP_CLASS: CppClassStrategy(),
CPP_METHOD: CppMethodStrategy(),
CPP_BLOCK: CppBlockStrategy(),
KOTLIN_CLASS: KotlinClassStrategy(),
KOTLIN_METHOD: KotlinMethodStrategy(),
JAVASCRIPT_CLASS: JavascriptClassStrategy(),
JAVASCRIPT_FUNCTION: JavascriptFunctionStrategy(),
JAVASCRIPT_BLOCK: JavascriptBlockStrategy(),
Expand Down
85 changes: 85 additions & 0 deletions patchwork/common/context_strategy/cpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from patchwork.common.context_strategy.langugues import CppLanguage
from patchwork.common.context_strategy.protocol import TreeSitterStrategy


class CppStrategy(TreeSitterStrategy):
def __init__(self, query: str):
"""
Initialize the JavaSearcher instance.
CTY-git marked this conversation as resolved.
Show resolved Hide resolved

Args:
query (str): The search query string to be used for Java file search.
"""

# exts from https://gcc.gnu.org/onlinedocs/gcc-4.4.1/gcc/Overall-Options.html#index-file-name-suffix-71
exts = [
".ii",
".h",
".cc",
".cp",
".cxx",
".cpp",
".CPP",
".c++",
".C",
".hh",
".H",
".hp",
".hxx",
".hpp",
".HPP",
".h++",
".tcc",
]
super().__init__("cpp", query, exts, CppLanguage())
self.query = query


class CppClassStrategy(CppStrategy):
def __init__(self):
"""
Initialize the current class by calling the parent class's __init__ method.
The specific class to be initialized should have a class_declaration marked by @node.
"""
super().__init__(
"""
(class_specifier) @node
""".strip()
)


class CppMethodStrategy(CppStrategy):
def __init__(self):
"""
Initialize the newly created object by inheriting properties and
methods from the parent class.

Parameters:
- self: instance of the class

Returns:
- None
"""
super().__init__(
"""
[
(comment) @comment
(function_definition) @node
]
""".strip()
)


class CppBlockStrategy(CppStrategy):
def __init__(self):
"""
Initialize the class by calling the parent class's constructor.

Parameters:
- self: The object instance.
"""
super().__init__(
"""
(compound_statement) @node
""".strip()
)
49 changes: 49 additions & 0 deletions patchwork/common/context_strategy/kotlin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from patchwork.common.context_strategy.langugues import JavaLanguage
from patchwork.common.context_strategy.protocol import TreeSitterStrategy


class KotlinStrategy(TreeSitterStrategy):
def __init__(self, query: str):
"""
Initialize the JavaSearcher instance.

Args:
query (str): The search query string to be used for Java file search.
"""
super().__init__("kotlin", query, ["kt"], JavaLanguage())
self.query = query


class KotlinClassStrategy(KotlinStrategy):
def __init__(self):
"""
Initialize the current class by calling the parent class's __init__ method.
The specific class to be initialized should have a class_declaration marked by @node.
"""
super().__init__(
"""
(class_declaration) @node
""".strip()
)


class KotlinMethodStrategy(KotlinStrategy):
def __init__(self):
"""
Initialize the newly created object by inheriting properties and
methods from the parent class.

Parameters:
- self: instance of the class

Returns:
- None
"""
super().__init__(
"""
[
(multiline_comment) @comment
(function_declaration) @node
]
""".strip()
)
13 changes: 8 additions & 5 deletions patchwork/common/context_strategy/langugues.py
CTY-git marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def __init__(self):
"""
self._comment_format = """\
/**
* <Method description>
*
* @param <Parameter name> <Parameter description>
* @return <Return description>
*/
* <Method description>
*
* @param <Parameter name> <Parameter description>
* @return <Return description>
*/
"""

@property
Expand All @@ -56,6 +56,9 @@ def docstring_format(self) -> str:
return self._comment_format


CppLanguage = JavaLanguage
CTY-git marked this conversation as resolved.
Show resolved Hide resolved


class PythonLanguage(LanguageProtocol):
def __init__(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion patchwork/common/context_strategy/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def query_src(self, src: list[str]):
"""
language = get_language(self.tree_sitter_language)
parser = get_parser(self.tree_sitter_language)
tree = parser.parse("".join(src).encode("utf-8-sig"))
tree = parser.parse("".join(src).encode())
return language.query(self.query).captures(tree.root_node)

def get_contexts(self, src: list[str]) -> list[Position]:
Expand Down
18 changes: 13 additions & 5 deletions patchwork/common/utils/input_parsing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations
from typing_extensions import Union, AnyStr

from collections.abc import Iterable, Mapping

from typing_extensions import AnyStr, Union

__ITEM_TYPE = Union[AnyStr, Mapping]


def __parse_to_list_handle_str(input_value: AnyStr, possible_delimiters: Iterable[AnyStr | None]) -> list[str]:
for possible_delimiter in possible_delimiters:
if possible_delimiter is None:
Expand All @@ -14,14 +17,18 @@ def __parse_to_list_handle_str(input_value: AnyStr, possible_delimiters: Iterabl

return []


def __parse_to_list_handle_dict(input_value: Mapping, possible_keys: Iterable[AnyStr | None]) -> list[str]:
for possible_key in possible_keys:
if input_value.get(possible_key) is not None:
return input_value.get(possible_key)

return []

def __parse_to_list_handle_iterable(input_value: Iterable[__ITEM_TYPE], possible_keys: Iterable[AnyStr | None]) -> list[str]:

def __parse_to_list_handle_iterable(
input_value: Iterable[__ITEM_TYPE], possible_keys: Iterable[AnyStr | None]
) -> list[str]:
rv = []
for item in input_value:
if isinstance(item, dict):
Expand All @@ -33,10 +40,11 @@ def __parse_to_list_handle_iterable(input_value: Iterable[__ITEM_TYPE], possible

return rv


def parse_to_list(
input_value: __ITEM_TYPE | Iterable[__ITEM_TYPE],
possible_delimiters: Iterable[AnyStr | None] | None = None ,
possible_keys: Iterable[AnyStr | None] | None = None
input_value: __ITEM_TYPE | Iterable[__ITEM_TYPE],
possible_delimiters: Iterable[AnyStr | None] | None = None,
possible_keys: Iterable[AnyStr | None] | None = None,
) -> list[str]:
if len(input_value) < 1:
return []
Expand Down
Loading
Loading