Skip to content

Commit

Permalink
Fix codemod and add more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Jul 16, 2024
1 parent e554bd8 commit 8e1fe7d
Show file tree
Hide file tree
Showing 6 changed files with 663 additions and 558 deletions.
1,067 changes: 514 additions & 553 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ asgiref = "^3.2"
ddtrace = ">=1.6.4"
email-validator = {version = ">=1.1.3,<3.0.0", optional = false}
freezegun = "^1.2.1"
libcst = {version = ">=0.4.7", optional = false}
libcst = {version = ">=1.0.0", optional = false}
MarkupSafe = "2.1.3"
nox = "^2023.4.22"
nox-poetry = "^1.0.3"
Expand Down
2 changes: 1 addition & 1 deletion strawberry/cli/commands/upgrade/_run_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _execute_transform_wrap(
# TODO: maybe capture warnings?
with open(os.devnull, "w") as null: # noqa: PTH123
with contextlib.redirect_stderr(null):
return _execute_transform(**job)
return _execute_transform(**job, scratch={})


def _get_progress_and_pool(
Expand Down
50 changes: 49 additions & 1 deletion strawberry/codemods/update_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@ def __init__(self, context: CodemodContext) -> None:
def _update_imports(
self, node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
imports = ["field", "union"]
imports = [
"field",
"union",
"auto",
"unset",
"arguments",
"lazy_type",
"object_type",
"private",
"enum",
]

for import_name in imports:
if m.matches(
Expand All @@ -37,6 +47,29 @@ def _update_imports(

return updated_node

def _update_types_types_imports(
self, node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
if m.matches(
node,
m.ImportFrom(
module=m.Attribute(
value=m.Attribute(value=m.Name("strawberry"), attr=m.Name("types")),
attr=m.Name("types"),
)
),
):
updated_node = updated_node.with_changes(
module=cst.Attribute(
value=cst.Attribute(
value=cst.Name("strawberry"), attr=cst.Name("types")
),
attr=cst.Name("base"),
),
)

return updated_node

def _update_strawberry_type_imports(
self, node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
Expand All @@ -51,6 +84,11 @@ def _update_strawberry_type_imports(
for name in node.names
)

has_has_object_definition = any(
m.matches(name, m.ImportAlias(name=m.Name("has_object_definition")))
for name in node.names
)

updated_node = updated_node.with_changes(
module=cst.Attribute(
value=cst.Attribute(
Expand All @@ -64,17 +102,27 @@ def _update_strawberry_type_imports(
self.context, "strawberry.types.base", "get_object_definition"
)

self.remove_imports_visitor.remove_unused_import(
self.context, "strawberry.types.base", "has_object_definition"
)

if has_get_object_definition:
self.add_imports_visitor.add_needed_import(
self.context, "strawberry.types", "get_object_definition"
)

if has_has_object_definition:
self.add_imports_visitor.add_needed_import(
self.context, "strawberry.types", "has_object_definition"
)

return updated_node

def leave_ImportFrom(
self, node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
updated_node = self._update_imports(updated_node, updated_node)
updated_node = self._update_types_types_imports(updated_node, updated_node)
updated_node = self._update_strawberry_type_imports(updated_node, updated_node)

return updated_node
9 changes: 8 additions & 1 deletion strawberry/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from .base import get_object_definition, has_object_definition
from .execution import ExecutionContext, ExecutionResult
from .info import Info

__all__ = ["ExecutionContext", "ExecutionResult", "Info"]
__all__ = [
"ExecutionContext",
"ExecutionResult",
"Info",
"get_object_definition",
"has_object_definition",
]
91 changes: 90 additions & 1 deletion tests/codemods/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_update_import_strawberry_type_object_definition(self) -> None:
StrawberryType,
WithStrawberryObjectDefinition,
get_object_definition,
has_object_definition,
)
"""

Expand All @@ -59,7 +60,7 @@ def test_update_import_strawberry_type_object_definition(self) -> None:
StrawberryOptional,
StrawberryType,
WithStrawberryObjectDefinition)
from strawberry.types import get_object_definition
from strawberry.types import get_object_definition, has_object_definition
"""

self.assertCodemod(before, after)
Expand All @@ -85,3 +86,91 @@ def test_update_import_union(self) -> None:
"""

self.assertCodemod(before, after)

def test_update_import_auto(self) -> None:
before = """
from strawberry.auto import auto
"""

after = """
from strawberry.types.auto import auto
"""

self.assertCodemod(before, after)

def test_update_import_unset(self) -> None:
before = """
from strawberry.unset import UNSET
"""

after = """
from strawberry.types.unset import UNSET
"""

self.assertCodemod(before, after)

def test_update_import_arguments(self) -> None:
before = """
from strawberry.arguments import StrawberryArgument
"""

after = """
from strawberry.types.arguments import StrawberryArgument
"""

self.assertCodemod(before, after)

def test_update_import_lazy_type(self) -> None:
before = """
from strawberry.lazy_type import LazyType
"""

after = """
from strawberry.types.lazy_type import LazyType
"""

self.assertCodemod(before, after)

def test_update_import_object_type(self) -> None:
before = """
from strawberry.object_type import StrawberryObjectDefinition
"""

after = """
from strawberry.types.object_type import StrawberryObjectDefinition
"""

self.assertCodemod(before, after)

def test_update_import_enum(self) -> None:
before = """
from strawberry.enum import StrawberryEnum
"""

after = """
from strawberry.types.enum import StrawberryEnum
"""

self.assertCodemod(before, after)

def test_update_types_types(self) -> None:
before = """
from strawberry.types.types import StrawberryObjectDefinition
"""

after = """
from strawberry.types.base import StrawberryObjectDefinition
"""

self.assertCodemod(before, after)

def test_update_is_private(self) -> None:
before = """
from strawberry.private import is_private
"""

after = """
from strawberry.types.private import is_private
"""

self.assertCodemod(before, after)

0 comments on commit 8e1fe7d

Please sign in to comment.