Skip to content

Commit

Permalink
Improve subclass support for orig in object.__new__
Browse files Browse the repository at this point in the history
  • Loading branch information
ionite34 committed Feb 5, 2023
1 parent be8b10a commit 2b38bc3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 20 deletions.
28 changes: 20 additions & 8 deletions src/einspect/structs/py_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def __init__(self, tp_new: newfunc, wrap_type: Type[_T]):
# Cast tp_new to remove Structure binding
self._tp_new = cast(tp_new, newfunc)
self._type = wrap_type
self.__name__ = "__new__"

def __repr__(self):
return (
Expand Down Expand Up @@ -257,11 +258,22 @@ def __call__(self, *args: tuple, **kwds: dict):
# to instead check for type identity.
# This is so orig().__new__ can be called within a custom __new__.
# Semantically, this is the same as the original check.
if staticbase and staticbase[0] != PyTypeObject.from_object(self._type):
raise TypeError(
f"{self.__name__}.__new__({subtype.__name__}): "
f"is not safe, use {self.__name__}.__new__()"
)

args = args[1:]
return self._tp_new(subtype, args, kwds)
# Also bypass this check if the type is a heap type, and _type is object / type.
if staticbase and (staticbase_obj := staticbase[0]):
if (staticbase_obj.tp_flags & TpFlags.HEAPTYPE) and (
self._type is object or self._type is type
):
staticbase_obj = None
if staticbase_obj and staticbase_obj != PyTypeObject.from_object(
self._type
):
raise TypeError(
f"{self._type.__name__}.__new__({subtype.__name__}): "
f"is not safe, use {staticbase[0].tp_name}.__new__()"
)

# object.__new__ takes no arguments, so don't pass any if we're calling it
if self._type is object:
return self._tp_new(subtype, (), {})

return self._tp_new(subtype, args[1:], kwds)
29 changes: 23 additions & 6 deletions src/einspect/type_orig.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Proxy for retrieving original methods and slot wrappers of types."""
from __future__ import annotations

from ctypes import cast
from types import BuiltinFunctionType
from typing import Any, Type, TypeVar

from einspect.structs.include.object_h import newfunc
from einspect.structs.py_type import PyTypeObject, TypeNewWrapper

_T = TypeVar("_T")
MISSING = object()

obj_tp_new = cast(PyTypeObject.from_object(object).tp_new, newfunc)
obj_getattr = object.__getattribute__
type_hash = type.__hash__
str_eq = str.__eq__
Expand All @@ -19,17 +23,25 @@
_slots_cache: dict[type, dict[str, Any]] = {}


def add_cache(type_: Type[_T], name: str, method: Any) -> None:
def add_cache(type_: Type[_T], name: str, method: Any) -> Any:
"""Add a method to the cache."""
type_methods = dict_setdefault(_slots_cache, type_, {})

# For `__new__` methods, use special TypeNewWrapper to use modified safety check
if name == "__new__":
tp_new = PyTypeObject(type_).tp_new
method = TypeNewWrapper(tp_new=tp_new, wrap_type=type_)
# Check if we're trying to set a previous impl method
# If so, avoid the loop by using object.__new__
if not isinstance(method, BuiltinFunctionType):
method = get_cache(object, "__new__")
else:
tp_new = PyTypeObject.from_object(type_).tp_new
obj = obj_tp_new(TypeNewWrapper, (), {})
obj.__init__(tp_new, type_)
method = obj

# Only allow adding once, ignore if already added
dict_setdefault(type_methods, name, method)
return method


def in_cache(type_: type, name: str) -> bool:
Expand All @@ -44,6 +56,10 @@ def get_cache(type_: type, name: str) -> Any:
return dict_getitem(type_methods, name)


add_cache(object, "__new__", object.__new__)
add_cache(type, "__new__", type.__new__)


class orig:
"""
Proxy to access a type's original attributes.
Expand All @@ -53,7 +69,9 @@ class orig:
"""

def __new__(cls, type_: Type[_T]) -> Type[_T]:
obj = object.__new__(cls)
# To avoid a circular call loop when orig is called within
# impl of object.__new__, we use the raw tp_new of object here.
obj = obj_tp_new(cls, (), {})
obj.__type = type_
return obj # type: ignore

Expand All @@ -78,5 +96,4 @@ def __getattribute__(self, name: str):
pass
# Get the attribute from the original type and cache it
attr = getattr(_type, name)
add_cache(_type, name, attr)
return attr
return add_cache(_type, name, attr)
14 changes: 8 additions & 6 deletions src/einspect/views/view_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def wrapper(func: _Fn) -> _Fn:
name = func.__name__

for type_ in targets:
with TypeView(type_).alloc_mode(alloc) as t_view:
t_view = TypeView(type_)
with t_view.alloc_mode(alloc):
t_view[name] = func

return func
Expand Down Expand Up @@ -196,21 +197,22 @@ def __setitem__(self, key: str | tuple[str, ...], value: Any) -> None:
# For all alloc mode, allocate now
if self._alloc_mode == "all":
self.alloc_slot()
for k in keys:
for name in keys:
# Cache original implementation
base = self.base
if not in_cache(base, k):
if not in_cache(base, name):
with suppress(AttributeError):
add_cache(base, k, getattr(base, k))
attr = getattr(base, name)
add_cache(base, name, attr)
# Check if this is a slots attr (skip all since we already allocated)
if self._alloc_mode != "all" and (
slot := get_slot(k, prefer=self._alloc_mode)
slot := get_slot(name, prefer=self._alloc_mode)
):
# Allocate sub-struct if needed
self._try_alloc(slot)

with self.as_mutable():
self._pyobject.setattr_safe(k, value)
self._pyobject.setattr_safe(name, value)

# <-- Begin Managed::Properties (structs::py_type.PyTypeObject) -->

Expand Down

0 comments on commit 2b38bc3

Please sign in to comment.