diff --git a/README.md b/README.md index f7077df9f..b9ef624b6 100644 --- a/README.md +++ b/README.md @@ -391,7 +391,50 @@ swap the dataclass implementation from the builtin python dataclass to the pydantic dataclass. You must have pydantic as a dependency in your project for this to work. +## Configuration typing imports +By default typing types will be imported directly from typing. This sometimes can lead to issues in generation if types that are being generated conflict with the name. In this case you can configure the way types are imported from 3 different options: + +### Direct +``` +protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto +``` +this configuration is the default, and will import types as follows: +``` +from typing import ( + List, + Optional, + Union +) +... +value: List[str] = [] +value2: Optional[str] = None +value3: Union[str, int] = 1 +``` +### Root +``` +protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto +``` +this configuration loads the root typing module, and then access the types off of it directly: +``` +import typing +... +value: typing.List[str] = [] +value2: typing.Optional[str] = None +value3: typing.Union[str, int] = 1 +``` + +### 310 +``` +protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto +``` +this configuration avoid loading typing all together if possible and uses the python 3.10 pattern: +``` +... +value: list[str] = [] +value2: str | None = None +value3: str | int = 1 +``` ## Development diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index a973586a0..4221122b9 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -47,6 +47,7 @@ def get_type_reference( package: str, imports: set, source_type: str, + typing_compiler: "TypingCompiler", unwrap: bool = True, pydantic: bool = False, ) -> str: @@ -57,7 +58,7 @@ def get_type_reference( if unwrap: if source_type in WRAPPER_TYPES: wrapped_type = type(WRAPPER_TYPES[source_type]().value) - return f"Optional[{wrapped_type.__name__}]" + return typing_compiler.optional(wrapped_type.__name__) if source_type == ".google.protobuf.Duration": return "timedelta" diff --git a/src/betterproto/plugin/compiler.py b/src/betterproto/plugin/compiler.py index 510d64857..7eee733d0 100644 --- a/src/betterproto/plugin/compiler.py +++ b/src/betterproto/plugin/compiler.py @@ -1,4 +1,7 @@ import os.path +import sys + +from .module_validation import ModuleValidator try: @@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: lstrip_blocks=True, loader=jinja2.FileSystemLoader(templates_folder), ) - template = env.get_template("template.py.j2") + # Load the body first so we have a compleate list of imports needed. + body_template = env.get_template("template.py.j2") + header_template = env.get_template("header.py.j2") - code = template.render(output_file=output_file) + code = body_template.render(output_file=output_file) + code = header_template.render(output_file=output_file) + code code = isort.api.sort_code_string( code=code, show_diff=False, @@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: force_grid_wrap=2, known_third_party=["grpclib", "betterproto"], ) - return black.format_str( + code = black.format_str( src_contents=code, mode=black.Mode(), ) + + # Validate the generated code. + validator = ModuleValidator(iter(code.splitlines())) + if not validator.validate(): + message_builder = ["[WARNING]: Generated code has collisions in the module:"] + for collision, lines in validator.collisions.items(): + message_builder.append(f' "{collision}" on lines:') + for num, line in lines: + message_builder.append(f" {num}:{line}") + print("\n".join(message_builder), file=sys.stderr) + return code diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index a1a1a872f..4102b9675 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -29,10 +29,8 @@ reference to `A` to `B`'s `fields` attribute. """ - import builtins import re -import textwrap from dataclasses import ( dataclass, field, @@ -49,12 +47,6 @@ ) import betterproto -from betterproto import which_one_of -from betterproto.casing import sanitize_name -from betterproto.compile.importing import ( - get_type_reference, - parse_source_type_name, -) from betterproto.compile.naming import ( pythonize_class_name, pythonize_field_name, @@ -72,6 +64,7 @@ ) from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest +from .. import which_one_of from ..compile.importing import ( get_type_reference, parse_source_type_name, @@ -82,6 +75,10 @@ pythonize_field_name, pythonize_method_name, ) +from .typing_compiler import ( + DirectImportTypingCompiler, + TypingCompiler, +) # Create a unique placeholder to deal with @@ -173,6 +170,7 @@ class ProtoContentBase: """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler.""" source_file: FileDescriptorProto + typing_compiler: TypingCompiler path: List[int] comment_indent: int = 4 parent: Union["betterproto.Message", "OutputTemplate"] @@ -242,7 +240,6 @@ class OutputTemplate: input_files: List[str] = field(default_factory=list) imports: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set) - typing_imports: Set[str] = field(default_factory=set) pydantic_imports: Set[str] = field(default_factory=set) builtins_import: bool = False messages: List["MessageCompiler"] = field(default_factory=list) @@ -251,6 +248,7 @@ class OutputTemplate: imports_type_checking_only: Set[str] = field(default_factory=set) pydantic_dataclasses: bool = False output: bool = True + typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler) @property def package(self) -> str: @@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase): """Representation of a protobuf message.""" source_file: FileDescriptorProto + typing_compiler: TypingCompiler parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER @@ -319,7 +318,7 @@ def py_name(self) -> str: @property def annotation(self) -> str: if self.repeated: - return f"List[{self.py_name}]" + return self.typing_compiler.list(self.py_name) return self.py_name @property @@ -434,18 +433,6 @@ def datetime_imports(self) -> Set[str]: imports.add("datetime") return imports - @property - def typing_imports(self) -> Set[str]: - imports = set() - annotation = self.annotation - if "Optional[" in annotation: - imports.add("Optional") - if "List[" in annotation: - imports.add("List") - if "Dict[" in annotation: - imports.add("Dict") - return imports - @property def pydantic_imports(self) -> Set[str]: return set() @@ -458,7 +445,6 @@ def use_builtins(self) -> bool: def add_imports_to(self, output_file: OutputTemplate) -> None: output_file.datetime_imports.update(self.datetime_imports) - output_file.typing_imports.update(self.typing_imports) output_file.pydantic_imports.update(self.pydantic_imports) output_file.builtins_import = output_file.builtins_import or self.use_builtins @@ -488,7 +474,9 @@ def optional(self) -> bool: @property def mutable(self) -> bool: """True if the field is a mutable type, otherwise False.""" - return self.annotation.startswith(("List[", "Dict[")) + return self.annotation.startswith( + ("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[") + ) @property def field_type(self) -> str: @@ -562,6 +550,7 @@ def py_type(self) -> str: package=self.output_file.package, imports=self.output_file.imports, source_type=self.proto_obj.type_name, + typing_compiler=self.typing_compiler, pydantic=self.output_file.pydantic_dataclasses, ) else: @@ -573,9 +562,9 @@ def annotation(self) -> str: if self.use_builtins: py_type = f"builtins.{py_type}" if self.repeated: - return f"List[{py_type}]" + return self.typing_compiler.list(py_type) if self.optional: - return f"Optional[{py_type}]" + return self.typing_compiler.optional(py_type) return py_type @@ -623,11 +612,13 @@ def __post_init__(self) -> None: source_file=self.source_file, parent=self, proto_obj=nested.field[0], # key + typing_compiler=self.typing_compiler, ).py_type self.py_v_type = FieldCompiler( source_file=self.source_file, parent=self, proto_obj=nested.field[1], # value + typing_compiler=self.typing_compiler, ).py_type # Get proto types @@ -645,7 +636,7 @@ def field_type(self) -> str: @property def annotation(self) -> str: - return f"Dict[{self.py_k_type}, {self.py_v_type}]" + return self.typing_compiler.dict(self.py_k_type, self.py_v_type) @property def repeated(self) -> bool: @@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase): def __post_init__(self) -> None: # Add service to output file self.output_file.services.append(self) - self.output_file.typing_imports.add("Dict") super().__post_init__() # check for unset fields @property @@ -725,22 +715,6 @@ def __post_init__(self) -> None: # Add method to service self.parent.methods.append(self) - # Check for imports - if "Optional" in self.py_output_message_type: - self.output_file.typing_imports.add("Optional") - - # Check for Async imports - if self.client_streaming: - self.output_file.typing_imports.add("AsyncIterable") - self.output_file.typing_imports.add("Iterable") - self.output_file.typing_imports.add("Union") - - # Required by both client and server - if self.client_streaming or self.server_streaming: - self.output_file.typing_imports.add("AsyncIterator") - - # add imports required for request arguments timeout, deadline and metadata - self.output_file.typing_imports.add("Optional") self.output_file.imports_type_checking_only.add("import grpclib.server") self.output_file.imports_type_checking_only.add( "from betterproto.grpc.grpclib_client import MetadataLike" @@ -806,6 +780,7 @@ def py_input_message_type(self) -> str: package=self.output_file.package, imports=self.output_file.imports, source_type=self.proto_obj.input_type, + typing_compiler=self.output_file.typing_compiler, unwrap=False, pydantic=self.output_file.pydantic_dataclasses, ).strip('"') @@ -835,6 +810,7 @@ def py_output_message_type(self) -> str: package=self.output_file.package, imports=self.output_file.imports, source_type=self.proto_obj.output_type, + typing_compiler=self.output_file.typing_compiler, unwrap=False, pydantic=self.output_file.pydantic_dataclasses, ).strip('"') diff --git a/src/betterproto/plugin/module_validation.py b/src/betterproto/plugin/module_validation.py new file mode 100644 index 000000000..4cf05fdca --- /dev/null +++ b/src/betterproto/plugin/module_validation.py @@ -0,0 +1,163 @@ +import re +from collections import defaultdict +from dataclasses import ( + dataclass, + field, +) +from typing import ( + Dict, + Iterator, + List, + Tuple, +) + + +@dataclass +class ModuleValidator: + line_iterator: Iterator[str] + line_number: int = field(init=False, default=0) + + collisions: Dict[str, List[Tuple[int, str]]] = field( + init=False, default_factory=lambda: defaultdict(list) + ) + + def add_import(self, imp: str, number: int, full_line: str): + """ + Adds an import to be tracked. + """ + self.collisions[imp].append((number, full_line)) + + def process_import(self, imp: str): + """ + Filters out the import to its actual value. + """ + if " as " in imp: + imp = imp[imp.index(" as ") + 4 :] + + imp = imp.strip() + assert " " not in imp, imp + return imp + + def evaluate_multiline_import(self, line: str): + """ + Evaluates a multiline import from a starting line + """ + # Filter the first line and remove anything before the import statement. + full_line = line + line = line.split("import", 1)[1] + if "(" in line: + conditional = lambda line: ")" not in line + else: + conditional = lambda line: "\\" in line + + # Remove open parenthesis if it exists. + if "(" in line: + line = line[line.index("(") + 1 :] + + # Choose the conditional based on how multiline imports are formatted. + while conditional(line): + # Split the line by commas + imports = line.split(",") + + for imp in imports: + # Add the import to the namespace + imp = self.process_import(imp) + if imp: + self.add_import(imp, self.line_number, full_line) + # Get the next line + full_line = line = next(self.line_iterator) + # Increment the line number + self.line_number += 1 + + # validate the last line + if ")" in line: + line = line[: line.index(")")] + imports = line.split(",") + for imp in imports: + imp = self.process_import(imp) + if imp: + self.add_import(imp, self.line_number, full_line) + + def evaluate_import(self, line: str): + """ + Extracts an import from a line. + """ + whole_line = line + line = line[line.index("import") + 6 :] + values = line.split(",") + for v in values: + self.add_import(self.process_import(v), self.line_number, whole_line) + + def next(self): + """ + Evaluate each line for names in the module. + """ + line = next(self.line_iterator) + + # Skip lines with indentation or comments + if ( + # Skip indents and whitespace. + line.startswith(" ") + or line == "\n" + or line.startswith("\t") + or + # Skip comments + line.startswith("#") + or + # Skip decorators + line.startswith("@") + ): + self.line_number += 1 + return + + # Skip docstrings. + if line.startswith('"""') or line.startswith("'''"): + quote = line[0] * 3 + line = line[3:] + while quote not in line: + line = next(self.line_iterator) + self.line_number += 1 + return + + # Evaluate Imports. + if line.startswith("from ") or line.startswith("import "): + if "(" in line or "\\" in line: + self.evaluate_multiline_import(line) + else: + self.evaluate_import(line) + + # Evaluate Classes. + elif line.startswith("class "): + class_name = re.search(r"class (\w+)", line).group(1) + if class_name: + self.add_import(class_name, self.line_number, line) + + # Evaluate Functions. + elif line.startswith("def "): + function_name = re.search(r"def (\w+)", line).group(1) + if function_name: + self.add_import(function_name, self.line_number, line) + + # Evaluate direct assignments. + elif "=" in line: + assignment = re.search(r"(\w+)\s*=", line).group(1) + if assignment: + self.add_import(assignment, self.line_number, line) + + self.line_number += 1 + + def validate(self) -> bool: + """ + Run Validation. + """ + try: + while True: + self.next() + except StopIteration: + pass + + # Filter collisions for those with more than one value. + self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1} + + # Return True if no collisions are found. + return not bool(self.collisions) diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index f48533338..2e0d861a3 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -37,6 +37,12 @@ is_map, is_oneof, ) +from .typing_compiler import ( + DirectImportTypingCompiler, + NoTyping310TypingCompiler, + TypingCompiler, + TypingImportTypingCompiler, +) def traverse( @@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: output_package_name ].pydantic_dataclasses = True + # Gather any typing generation options. + typing_opts = [ + opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.") + ] + + if len(typing_opts) > 1: + raise ValueError("Multiple typing options provided") + # Set the compiler type. + typing_opt = typing_opts[0] if typing_opts else "direct" + if typing_opt == "direct": + request_data.output_packages[ + output_package_name + ].typing_compiler = DirectImportTypingCompiler() + elif typing_opt == "root": + request_data.output_packages[ + output_package_name + ].typing_compiler = TypingImportTypingCompiler() + elif typing_opt == "310": + request_data.output_packages[ + output_package_name + ].typing_compiler = NoTyping310TypingCompiler() + # Read Messages and Enums # We need to read Messages before Services in so that we can # get the references to input/output messages for each service @@ -166,6 +194,7 @@ def _make_one_of_field_compiler( parent=parent, proto_obj=proto_obj, path=path, + typing_compiler=output_package.typing_compiler, ) @@ -181,7 +210,11 @@ def read_protobuf_type( return # Process Message message_data = MessageCompiler( - source_file=source_file, parent=output_package, proto_obj=item, path=path + source_file=source_file, + parent=output_package, + proto_obj=item, + path=path, + typing_compiler=output_package.typing_compiler, ) for index, field in enumerate(item.field): if is_map(field, item): @@ -190,6 +223,7 @@ def read_protobuf_type( parent=message_data, proto_obj=field, path=path + [2, index], + typing_compiler=output_package.typing_compiler, ) elif is_oneof(field): _make_one_of_field_compiler( @@ -201,11 +235,16 @@ def read_protobuf_type( parent=message_data, proto_obj=field, path=path + [2, index], + typing_compiler=output_package.typing_compiler, ) elif isinstance(item, EnumDescriptorProto): # Enum EnumDefinitionCompiler( - source_file=source_file, parent=output_package, proto_obj=item, path=path + source_file=source_file, + parent=output_package, + proto_obj=item, + path=path, + typing_compiler=output_package.typing_compiler, ) diff --git a/src/betterproto/plugin/typing_compiler.py b/src/betterproto/plugin/typing_compiler.py new file mode 100644 index 000000000..937c7bfc1 --- /dev/null +++ b/src/betterproto/plugin/typing_compiler.py @@ -0,0 +1,167 @@ +import abc +from collections import defaultdict +from dataclasses import ( + dataclass, + field, +) +from typing import ( + Dict, + Iterator, + Optional, + Set, +) + + +class TypingCompiler(metaclass=abc.ABCMeta): + @abc.abstractmethod + def optional(self, type: str) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def list(self, type: str) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def dict(self, key: str, value: str) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def union(self, *types: str) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def iterable(self, type: str) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def async_iterable(self, type: str) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def async_iterator(self, type: str) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def imports(self) -> Dict[str, Optional[Set[str]]]: + """ + Returns either the direct import as a key with none as value, or a set of + values to import from the key. + """ + raise NotImplementedError() + + def import_lines(self) -> Iterator: + imports = self.imports() + for key, value in imports.items(): + if value is None: + yield f"import {key}" + else: + yield f"from {key} import (" + for v in sorted(value): + yield f" {v}," + yield ")" + + +@dataclass +class DirectImportTypingCompiler(TypingCompiler): + _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) + + def optional(self, type: str) -> str: + self._imports["typing"].add("Optional") + return f"Optional[{type}]" + + def list(self, type: str) -> str: + self._imports["typing"].add("List") + return f"List[{type}]" + + def dict(self, key: str, value: str) -> str: + self._imports["typing"].add("Dict") + return f"Dict[{key}, {value}]" + + def union(self, *types: str) -> str: + self._imports["typing"].add("Union") + return f"Union[{', '.join(types)}]" + + def iterable(self, type: str) -> str: + self._imports["typing"].add("Iterable") + return f"Iterable[{type}]" + + def async_iterable(self, type: str) -> str: + self._imports["typing"].add("AsyncIterable") + return f"AsyncIterable[{type}]" + + def async_iterator(self, type: str) -> str: + self._imports["typing"].add("AsyncIterator") + return f"AsyncIterator[{type}]" + + def imports(self) -> Dict[str, Optional[Set[str]]]: + return {k: v if v else None for k, v in self._imports.items()} + + +@dataclass +class TypingImportTypingCompiler(TypingCompiler): + _imported: bool = False + + def optional(self, type: str) -> str: + self._imported = True + return f"typing.Optional[{type}]" + + def list(self, type: str) -> str: + self._imported = True + return f"typing.List[{type}]" + + def dict(self, key: str, value: str) -> str: + self._imported = True + return f"typing.Dict[{key}, {value}]" + + def union(self, *types: str) -> str: + self._imported = True + return f"typing.Union[{', '.join(types)}]" + + def iterable(self, type: str) -> str: + self._imported = True + return f"typing.Iterable[{type}]" + + def async_iterable(self, type: str) -> str: + self._imported = True + return f"typing.AsyncIterable[{type}]" + + def async_iterator(self, type: str) -> str: + self._imported = True + return f"typing.AsyncIterator[{type}]" + + def imports(self) -> Dict[str, Optional[Set[str]]]: + if self._imported: + return {"typing": None} + return {} + + +@dataclass +class NoTyping310TypingCompiler(TypingCompiler): + _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) + + def optional(self, type: str) -> str: + return f"{type} | None" + + def list(self, type: str) -> str: + return f"list[{type}]" + + def dict(self, key: str, value: str) -> str: + return f"dict[{key}, {value}]" + + def union(self, *types: str) -> str: + return " | ".join(types) + + def iterable(self, type: str) -> str: + self._imports["typing"].add("Iterable") + return f"Iterable[{type}]" + + def async_iterable(self, type: str) -> str: + self._imports["typing"].add("AsyncIterable") + return f"AsyncIterable[{type}]" + + def async_iterator(self, type: str) -> str: + self._imports["typing"].add("AsyncIterator") + return f"AsyncIterator[{type}]" + + def imports(self) -> Dict[str, Optional[Set[str]]]: + return {k: v if v else None for k, v in self._imports.items()} diff --git a/src/betterproto/templates/header.py.j2 b/src/betterproto/templates/header.py.j2 new file mode 100644 index 000000000..011eb7d3c --- /dev/null +++ b/src/betterproto/templates/header.py.j2 @@ -0,0 +1,54 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# sources: {{ ', '.join(output_file.input_filenames) }} +# plugin: python-betterproto +# This file has been @generated +{% for i in output_file.python_module_imports|sort %} +import {{ i }} +{% endfor %} +{% set type_checking_imported = False %} + +{% if output_file.pydantic_dataclasses %} +from typing import TYPE_CHECKING +{% set type_checking_imported = True %} + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from pydantic.dataclasses import dataclass +{%- else -%} +from dataclasses import dataclass +{% endif %} + +{% if output_file.datetime_imports %} +from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} + +{% endif%} +{% set typing_imports = output_file.typing_compiler.imports() %} +{% if typing_imports %} +{% for line in output_file.typing_compiler.import_lines() %} +{{ line }} +{% endfor %} +{% endif %} + +{% if output_file.pydantic_imports %} +from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} + +{% endif %} + +import betterproto +{% if output_file.services %} +from betterproto.grpc.grpclib_server import ServiceBase +import grpclib +{% endif %} + +{% for i in output_file.imports|sort %} +{{ i }} +{% endfor %} + +{% if output_file.imports_type_checking_only and not type_checking_imported %} +from typing import TYPE_CHECKING + +if TYPE_CHECKING: +{% for i in output_file.imports_type_checking_only|sort %} {{ i }} +{% endfor %} +{% endif %} diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 5b6715605..13dbce7ac 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -1,53 +1,3 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# sources: {{ ', '.join(output_file.input_filenames) }} -# plugin: python-betterproto -# This file has been @generated -{% for i in output_file.python_module_imports|sort %} -import {{ i }} -{% endfor %} - -{% if output_file.pydantic_dataclasses %} -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from dataclasses import dataclass -else: - from pydantic.dataclasses import dataclass -{%- else -%} -from dataclasses import dataclass -{% endif %} - -{% if output_file.datetime_imports %} -from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} - -{% endif%} -{% if output_file.typing_imports %} -from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} - -{% endif %} - -{% if output_file.pydantic_imports %} -from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} - -{% endif %} - -import betterproto -{% if output_file.services %} -from betterproto.grpc.grpclib_server import ServiceBase -import grpclib -{% endif %} - -{% for i in output_file.imports|sort %} -{{ i }} -{% endfor %} - -{% if output_file.imports_type_checking_only %} -from typing import TYPE_CHECKING - -if TYPE_CHECKING: -{% for i in output_file.imports_type_checking_only|sort %} {{ i }} -{% endfor %} -{% endif %} - {% if output_file.enums %}{% for enum in output_file.enums %} class {{ enum.py_name }}(betterproto.Enum): {% if enum.comment %} @@ -116,14 +66,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] + , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }} {%- endif -%} , * - , timeout: Optional[float] = None - , deadline: Optional["Deadline"] = None - , metadata: Optional["MetadataLike"] = None - ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: + , timeout: {{ output_file.typing_compiler.optional("float") }} = None + , deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None + , metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None + ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} {{ method.comment }} @@ -191,9 +141,9 @@ class {{ service.py_name }}Base(ServiceBase): {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"] + , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }} {%- endif -%} - ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: + ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} {{ method.comment }} @@ -225,7 +175,7 @@ class {{ service.py_name }}Base(ServiceBase): {% endfor %} - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}: return { {% for method in service.methods %} "{{ method.route }}": grpclib.const.Handler( diff --git a/tests/generate.py b/tests/generate.py index 9ce375f85..91bdbb8a2 100755 --- a/tests/generate.py +++ b/tests/generate.py @@ -108,6 +108,7 @@ async def generate_test_case_output( print( f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m" ) + print(ref_err.decode()) if verbose: if ref_out: @@ -126,6 +127,7 @@ async def generate_test_case_output( print( f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m" ) + print(plg_err.decode()) if verbose: if plg_out: @@ -146,6 +148,7 @@ async def generate_test_case_output( print( f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m" ) + print(plg_err_pyd.decode()) if verbose: if plg_out_pyd: diff --git a/tests/test_get_ref_type.py b/tests/test_get_ref_type.py index a0b73a2c5..7b529bd27 100644 --- a/tests/test_get_ref_type.py +++ b/tests/test_get_ref_type.py @@ -4,6 +4,15 @@ get_type_reference, parse_source_type_name, ) +from betterproto.plugin.typing_compiler import DirectImportTypingCompiler + + +@pytest.fixture +def typing_compiler() -> DirectImportTypingCompiler: + """ + Generates a simple Direct Import Typing Compiler for testing. + """ + return DirectImportTypingCompiler() @pytest.mark.parametrize( @@ -32,11 +41,18 @@ ], ) def test_reference_google_wellknown_types_non_wrappers( - google_type: str, expected_name: str, expected_import: str + google_type: str, + expected_name: str, + expected_import: str, + typing_compiler: DirectImportTypingCompiler, ): imports = set() name = get_type_reference( - package="", imports=imports, source_type=google_type, pydantic=False + package="", + imports=imports, + source_type=google_type, + typing_compiler=typing_compiler, + pydantic=False, ) assert name == expected_name @@ -71,11 +87,18 @@ def test_reference_google_wellknown_types_non_wrappers( ], ) def test_reference_google_wellknown_types_non_wrappers_pydantic( - google_type: str, expected_name: str, expected_import: str + google_type: str, + expected_name: str, + expected_import: str, + typing_compiler: DirectImportTypingCompiler, ): imports = set() name = get_type_reference( - package="", imports=imports, source_type=google_type, pydantic=True + package="", + imports=imports, + source_type=google_type, + typing_compiler=typing_compiler, + pydantic=True, ) assert name == expected_name @@ -99,10 +122,15 @@ def test_reference_google_wellknown_types_non_wrappers_pydantic( ], ) def test_referenceing_google_wrappers_unwraps_them( - google_type: str, expected_name: str + google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler ): imports = set() - name = get_type_reference(package="", imports=imports, source_type=google_type) + name = get_type_reference( + package="", + imports=imports, + source_type=google_type, + typing_compiler=typing_compiler, + ) assert name == expected_name assert imports == set() @@ -135,223 +163,321 @@ def test_referenceing_google_wrappers_unwraps_them( ], ) def test_referenceing_google_wrappers_without_unwrapping( - google_type: str, expected_name: str + google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler ): name = get_type_reference( - package="", imports=set(), source_type=google_type, unwrap=False + package="", + imports=set(), + source_type=google_type, + typing_compiler=typing_compiler, + unwrap=False, ) assert name == expected_name -def test_reference_child_package_from_package(): +def test_reference_child_package_from_package( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( - package="package", imports=imports, source_type="package.child.Message" + package="package", + imports=imports, + source_type="package.child.Message", + typing_compiler=typing_compiler, ) assert imports == {"from . import child"} assert name == '"child.Message"' -def test_reference_child_package_from_root(): +def test_reference_child_package_from_root(typing_compiler: DirectImportTypingCompiler): imports = set() - name = get_type_reference(package="", imports=imports, source_type="child.Message") + name = get_type_reference( + package="", + imports=imports, + source_type="child.Message", + typing_compiler=typing_compiler, + ) assert imports == {"from . import child"} assert name == '"child.Message"' -def test_reference_camel_cased(): +def test_reference_camel_cased(typing_compiler: DirectImportTypingCompiler): imports = set() name = get_type_reference( - package="", imports=imports, source_type="child_package.example_message" + package="", + imports=imports, + source_type="child_package.example_message", + typing_compiler=typing_compiler, ) assert imports == {"from . import child_package"} assert name == '"child_package.ExampleMessage"' -def test_reference_nested_child_from_root(): +def test_reference_nested_child_from_root(typing_compiler: DirectImportTypingCompiler): imports = set() name = get_type_reference( - package="", imports=imports, source_type="nested.child.Message" + package="", + imports=imports, + source_type="nested.child.Message", + typing_compiler=typing_compiler, ) assert imports == {"from .nested import child as nested_child"} assert name == '"nested_child.Message"' -def test_reference_deeply_nested_child_from_root(): +def test_reference_deeply_nested_child_from_root( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( - package="", imports=imports, source_type="deeply.nested.child.Message" + package="", + imports=imports, + source_type="deeply.nested.child.Message", + typing_compiler=typing_compiler, ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} assert name == '"deeply_nested_child.Message"' -def test_reference_deeply_nested_child_from_package(): +def test_reference_deeply_nested_child_from_package( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( package="package", imports=imports, source_type="package.deeply.nested.child.Message", + typing_compiler=typing_compiler, ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} assert name == '"deeply_nested_child.Message"' -def test_reference_root_sibling(): +def test_reference_root_sibling(typing_compiler: DirectImportTypingCompiler): imports = set() - name = get_type_reference(package="", imports=imports, source_type="Message") + name = get_type_reference( + package="", + imports=imports, + source_type="Message", + typing_compiler=typing_compiler, + ) assert imports == set() assert name == '"Message"' -def test_reference_nested_siblings(): +def test_reference_nested_siblings(typing_compiler: DirectImportTypingCompiler): imports = set() - name = get_type_reference(package="foo", imports=imports, source_type="foo.Message") + name = get_type_reference( + package="foo", + imports=imports, + source_type="foo.Message", + typing_compiler=typing_compiler, + ) assert imports == set() assert name == '"Message"' -def test_reference_deeply_nested_siblings(): +def test_reference_deeply_nested_siblings(typing_compiler: DirectImportTypingCompiler): imports = set() name = get_type_reference( - package="foo.bar", imports=imports, source_type="foo.bar.Message" + package="foo.bar", + imports=imports, + source_type="foo.bar.Message", + typing_compiler=typing_compiler, ) assert imports == set() assert name == '"Message"' -def test_reference_parent_package_from_child(): +def test_reference_parent_package_from_child( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( - package="package.child", imports=imports, source_type="package.Message" + package="package.child", + imports=imports, + source_type="package.Message", + typing_compiler=typing_compiler, ) assert imports == {"from ... import package as __package__"} assert name == '"__package__.Message"' -def test_reference_parent_package_from_deeply_nested_child(): +def test_reference_parent_package_from_deeply_nested_child( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( package="package.deeply.nested.child", imports=imports, source_type="package.deeply.nested.Message", + typing_compiler=typing_compiler, ) assert imports == {"from ... import nested as __nested__"} assert name == '"__nested__.Message"' -def test_reference_ancestor_package_from_nested_child(): +def test_reference_ancestor_package_from_nested_child( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( package="package.ancestor.nested.child", imports=imports, source_type="package.ancestor.Message", + typing_compiler=typing_compiler, ) assert imports == {"from .... import ancestor as ___ancestor__"} assert name == '"___ancestor__.Message"' -def test_reference_root_package_from_child(): +def test_reference_root_package_from_child(typing_compiler: DirectImportTypingCompiler): imports = set() name = get_type_reference( - package="package.child", imports=imports, source_type="Message" + package="package.child", + imports=imports, + source_type="Message", + typing_compiler=typing_compiler, ) assert imports == {"from ... import Message as __Message__"} assert name == '"__Message__"' -def test_reference_root_package_from_deeply_nested_child(): +def test_reference_root_package_from_deeply_nested_child( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( - package="package.deeply.nested.child", imports=imports, source_type="Message" + package="package.deeply.nested.child", + imports=imports, + source_type="Message", + typing_compiler=typing_compiler, ) assert imports == {"from ..... import Message as ____Message__"} assert name == '"____Message__"' -def test_reference_unrelated_package(): +def test_reference_unrelated_package(typing_compiler: DirectImportTypingCompiler): imports = set() - name = get_type_reference(package="a", imports=imports, source_type="p.Message") + name = get_type_reference( + package="a", + imports=imports, + source_type="p.Message", + typing_compiler=typing_compiler, + ) assert imports == {"from .. import p as _p__"} assert name == '"_p__.Message"' -def test_reference_unrelated_nested_package(): +def test_reference_unrelated_nested_package( + typing_compiler: DirectImportTypingCompiler, +): imports = set() - name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message") + name = get_type_reference( + package="a.b", + imports=imports, + source_type="p.q.Message", + typing_compiler=typing_compiler, + ) assert imports == {"from ...p import q as __p_q__"} assert name == '"__p_q__.Message"' -def test_reference_unrelated_deeply_nested_package(): +def test_reference_unrelated_deeply_nested_package( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( - package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message" + package="a.b.c.d", + imports=imports, + source_type="p.q.r.s.Message", + typing_compiler=typing_compiler, ) assert imports == {"from .....p.q.r import s as ____p_q_r_s__"} assert name == '"____p_q_r_s__.Message"' -def test_reference_cousin_package(): +def test_reference_cousin_package(typing_compiler: DirectImportTypingCompiler): imports = set() - name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message") + name = get_type_reference( + package="a.x", + imports=imports, + source_type="a.y.Message", + typing_compiler=typing_compiler, + ) assert imports == {"from .. import y as _y__"} assert name == '"_y__.Message"' -def test_reference_cousin_package_different_name(): +def test_reference_cousin_package_different_name( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( - package="test.package1", imports=imports, source_type="cousin.package2.Message" + package="test.package1", + imports=imports, + source_type="cousin.package2.Message", + typing_compiler=typing_compiler, ) assert imports == {"from ...cousin import package2 as __cousin_package2__"} assert name == '"__cousin_package2__.Message"' -def test_reference_cousin_package_same_name(): +def test_reference_cousin_package_same_name( + typing_compiler: DirectImportTypingCompiler, +): imports = set() name = get_type_reference( - package="test.package", imports=imports, source_type="cousin.package.Message" + package="test.package", + imports=imports, + source_type="cousin.package.Message", + typing_compiler=typing_compiler, ) assert imports == {"from ...cousin import package as __cousin_package__"} assert name == '"__cousin_package__.Message"' -def test_reference_far_cousin_package(): +def test_reference_far_cousin_package(typing_compiler: DirectImportTypingCompiler): imports = set() name = get_type_reference( - package="a.x.y", imports=imports, source_type="a.b.c.Message" + package="a.x.y", + imports=imports, + source_type="a.b.c.Message", + typing_compiler=typing_compiler, ) assert imports == {"from ...b import c as __b_c__"} assert name == '"__b_c__.Message"' -def test_reference_far_far_cousin_package(): +def test_reference_far_far_cousin_package(typing_compiler: DirectImportTypingCompiler): imports = set() name = get_type_reference( - package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message" + package="a.x.y.z", + imports=imports, + source_type="a.b.c.d.Message", + typing_compiler=typing_compiler, ) assert imports == {"from ....b.c import d as ___b_c_d__"} diff --git a/tests/test_module_validation.py b/tests/test_module_validation.py new file mode 100644 index 000000000..9cae272bb --- /dev/null +++ b/tests/test_module_validation.py @@ -0,0 +1,111 @@ +from typing import ( + List, + Optional, + Set, +) + +import pytest + +from betterproto.plugin.module_validation import ModuleValidator + + +@pytest.mark.parametrize( + ["text", "expected_collisions"], + [ + pytest.param( + ["import os"], + None, + id="single import", + ), + pytest.param( + ["import os", "import sys"], + None, + id="multiple imports", + ), + pytest.param( + ["import os", "import os"], + {"os"}, + id="duplicate imports", + ), + pytest.param( + ["from os import path", "import os"], + None, + id="duplicate imports with alias", + ), + pytest.param( + ["from os import path", "import os as os_alias"], + None, + id="duplicate imports with alias", + ), + pytest.param( + ["from os import path", "import os as path"], + {"path"}, + id="duplicate imports with alias", + ), + pytest.param( + ["import os", "class os:"], + {"os"}, + id="duplicate import with class", + ), + pytest.param( + ["import os", "class os:", " pass", "import sys"], + {"os"}, + id="duplicate import with class and another", + ), + pytest.param( + ["def test(): pass", "class test:"], + {"test"}, + id="duplicate class and function", + ), + pytest.param( + ["def test(): pass", "def test(): pass"], + {"test"}, + id="duplicate functions", + ), + pytest.param( + ["def test(): pass", "test = 100"], + {"test"}, + id="function and variable", + ), + pytest.param( + ["def test():", " test = 3"], + None, + id="function and variable in function", + ), + pytest.param( + [ + "def test(): pass", + "'''", + "def test(): pass", + "'''", + "def test_2(): pass", + ], + None, + id="duplicate functions with multiline string", + ), + pytest.param( + ["def test(): pass", "# def test(): pass"], + None, + id="duplicate functions with comments", + ), + pytest.param( + ["from test import (", " A", " B", " C", ")"], + None, + id="multiline import", + ), + pytest.param( + ["from test import (", " A", " B", " C", ")", "from test import A"], + {"A"}, + id="multiline import with duplicate", + ), + ], +) +def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]): + line_iterator = iter(text) + validator = ModuleValidator(line_iterator) + valid = validator.validate() + if expected_collisions is None: + assert valid + else: + assert set(validator.collisions.keys()) == expected_collisions + assert not valid diff --git a/tests/test_typing_compiler.py b/tests/test_typing_compiler.py new file mode 100644 index 000000000..3d1083c72 --- /dev/null +++ b/tests/test_typing_compiler.py @@ -0,0 +1,80 @@ +import pytest + +from betterproto.plugin.typing_compiler import ( + DirectImportTypingCompiler, + NoTyping310TypingCompiler, + TypingImportTypingCompiler, +) + + +def test_direct_import_typing_compiler(): + compiler = DirectImportTypingCompiler() + assert compiler.imports() == {} + assert compiler.optional("str") == "Optional[str]" + assert compiler.imports() == {"typing": {"Optional"}} + assert compiler.list("str") == "List[str]" + assert compiler.imports() == {"typing": {"Optional", "List"}} + assert compiler.dict("str", "int") == "Dict[str, int]" + assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}} + assert compiler.union("str", "int") == "Union[str, int]" + assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}} + assert compiler.iterable("str") == "Iterable[str]" + assert compiler.imports() == { + "typing": {"Optional", "List", "Dict", "Union", "Iterable"} + } + assert compiler.async_iterable("str") == "AsyncIterable[str]" + assert compiler.imports() == { + "typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"} + } + assert compiler.async_iterator("str") == "AsyncIterator[str]" + assert compiler.imports() == { + "typing": { + "Optional", + "List", + "Dict", + "Union", + "Iterable", + "AsyncIterable", + "AsyncIterator", + } + } + + +def test_typing_import_typing_compiler(): + compiler = TypingImportTypingCompiler() + assert compiler.imports() == {} + assert compiler.optional("str") == "typing.Optional[str]" + assert compiler.imports() == {"typing": None} + assert compiler.list("str") == "typing.List[str]" + assert compiler.imports() == {"typing": None} + assert compiler.dict("str", "int") == "typing.Dict[str, int]" + assert compiler.imports() == {"typing": None} + assert compiler.union("str", "int") == "typing.Union[str, int]" + assert compiler.imports() == {"typing": None} + assert compiler.iterable("str") == "typing.Iterable[str]" + assert compiler.imports() == {"typing": None} + assert compiler.async_iterable("str") == "typing.AsyncIterable[str]" + assert compiler.imports() == {"typing": None} + assert compiler.async_iterator("str") == "typing.AsyncIterator[str]" + assert compiler.imports() == {"typing": None} + + +def test_no_typing_311_typing_compiler(): + compiler = NoTyping310TypingCompiler() + assert compiler.imports() == {} + assert compiler.optional("str") == "str | None" + assert compiler.imports() == {} + assert compiler.list("str") == "list[str]" + assert compiler.imports() == {} + assert compiler.dict("str", "int") == "dict[str, int]" + assert compiler.imports() == {} + assert compiler.union("str", "int") == "str | int" + assert compiler.imports() == {} + assert compiler.iterable("str") == "Iterable[str]" + assert compiler.imports() == {"typing": {"Iterable"}} + assert compiler.async_iterable("str") == "AsyncIterable[str]" + assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}} + assert compiler.async_iterator("str") == "AsyncIterator[str]" + assert compiler.imports() == { + "typing": {"Iterable", "AsyncIterable", "AsyncIterator"} + }