Skip to content

Commit

Permalink
add some infrasture stuff for replacing nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 29, 2024
1 parent a4c2fe7 commit d416c34
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 22 deletions.
5 changes: 5 additions & 0 deletions luisa_lang/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,8 @@ def get_union_args(union: Any) -> List[type]:
if hasattr(union, "__args__") or isinstance(union, types.UnionType):
return list(union.__args__)
return []

def checked_cast(t: type[T], obj: Any) -> T:
if not isinstance(obj, t):
raise TypeError(f"expected {t}, got {type(obj)}")
return obj
207 changes: 186 additions & 21 deletions luisa_lang/hir/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
Callable,
List,
Optional,
Set,
Tuple,
Dict,
Union,
cast,
)
from typing_extensions import override
from luisa_lang._utils import Span
Expand Down Expand Up @@ -174,21 +176,6 @@ def __hash__(self) -> int:
return hash((FloatType, self.bits))


# INT8: Final[IntType] = IntType(8, True)
# INT16: Final[IntType] = IntType(16, True)
# INT32: Final[IntType] = IntType(32, True)
# INT64: Final[IntType] = IntType(64, True)

# UINT8: Final[IntType] = IntType(8, False)
# UINT16: Final[IntType] = IntType(16, False)
# UINT32: Final[IntType] = IntType(32, False)
# UINT64: Final[IntType] = IntType(64, False)

# FLOAT16: Final[FloatType] = FloatType(16)
# FLOAT32: Final[FloatType] = FloatType(32)
# FLOAT64: Final[FloatType] = FloatType(64)


class VectorType(Type):
element: ScalarType
count: int
Expand Down Expand Up @@ -443,22 +430,101 @@ def align(self) -> int:
raise RuntimeError("FunctionType has no align")


class Use:
value: 'Node'
user: 'Node'

def __init__(self, value: 'Node', user: 'Node') -> None:
self.value = value
self.user = user


class UseList:
uses: List[Use]
user_to_value: Dict['Node', 'Node']
value_to_user: Dict['Node', 'Node']

def __init__(self) -> None:
self.uses = []
self.user_to_value = {}
self.value_to_user = {}

def append(self, use: Use) -> None:
self.uses.append(use)
self.user_to_value[use.user] = use.value
self.value_to_user[use.value] = use.user


def rebuild_usedef_chain(roots: List['Node']) -> None:
# first find all reachable nodes
def find_reachable() -> List['Node']:
reachable = []
stack = list(roots)
while stack:
node = stack.pop()
if node in reachable:
continue
reachable.append(node)
stack.extend(node.children())
return reachable
# clear all uses
for node in find_reachable():
node.uses = []
# rebuild all uses
for node in find_reachable():
for child in node.children():
node.uses.append(Use(child, node))


class Node:
type: Optional[Type]
"""
Base class for all nodes in the HIR. A node could be a value, a reference, or a statement.
Nodes equality is based on their identity.
"""
uses: List[Use]
span: Optional[Span]

def __init__(self) -> None:
self.uses = []
self.span = None

def replace_child(self, old: 'Node', new: 'Node') -> None:
pass

def children(self) -> List['Node']:
"""Return a list of children nodes."""
return []

def __eq__(self, value: object) -> bool:
return value is self

def __hash__(self) -> int:
return id(self)

def replace_uses_with(self, new: 'Node') -> None:
for use in self.uses:
use.user.replace_child(self, new)


class TypedNode(Node):
"""
A node with a type, which can either be values or references.
"""
type: Optional[Type]

def __init__(
self, type: Optional[Type] = None, span: Optional[Span] = None
) -> None:
super().__init__()
self.type = type
self.span = span


class Ref(Node):
class Ref(TypedNode):
pass


class Value(Node):
class Value(TypedNode):
pass


Expand All @@ -479,6 +545,16 @@ def __init__(self, value: Value) -> None:
super().__init__(value.type, value.span)
self.value = value

@override
def replace_child(self, old: Node, new: Node) -> None:
assert isinstance(new, Value)
if old is self.value:
self.value = new

@override
def children(self) -> List[Node]:
return [self.value]


class Var(Ref):
name: str
Expand All @@ -502,6 +578,16 @@ def __init__(self, base: Ref, field: str, span: Optional[Span]) -> None:
self.base = base
self.field = field

@override
def replace_child(self, old: Node, new: Node) -> None:
assert isinstance(new, Ref)
if old is self.base:
self.base = new

@override
def children(self) -> List[Node]:
return [self.base]


class Index(Ref):
base: Ref
Expand All @@ -512,6 +598,19 @@ def __init__(self, base: Ref, index: Value, span: Optional[Span]) -> None:
self.base = base
self.index = index

@override
def replace_child(self, old: Node, new: Node) -> None:
if old is self.index:
assert isinstance(new, Value)
self.index = new
elif old is self.base:
assert isinstance(new, Ref)
self.base = new

@override
def children(self) -> List[Node]:
return [self.base, self.index]


class Load(Value):
ref: Ref
Expand All @@ -520,6 +619,16 @@ def __init__(self, ref: Ref) -> None:
super().__init__(ref.type, ref.span)
self.ref = ref

@override
def replace_child(self, old: Node, new: Node) -> None:
assert isinstance(new, Ref)
if old is self.ref:
self.ref = new

@override
def children(self) -> List[Node]:
return [self.ref]


class CallOpKind(Enum):
FUNC = auto()
Expand Down Expand Up @@ -555,6 +664,24 @@ def __init__(
self.kind = kind
self.resolved = resolved

@override
def replace_child(self, old: Node, new: Node) -> None:
assert isinstance(new, Value)
if old is self.op:
self.op = new
else:
for i, arg in enumerate(self.args):
if old is arg:
self.args[i] = new

@override
def children(self) -> List[Node]:
lst = []
if isinstance(self.op, Value):
lst.append(self.op)
lst.extend(self.args)
return cast(List[Node], lst)


class TypeInferenceError(Exception):
pass
Expand Down Expand Up @@ -589,10 +716,9 @@ def __init__(self, name: str, type_rule: TypeRule) -> None:
self.type_rule = type_rule


class Stmt:
span: Optional[Span]

class Stmt(Node):
def __init__(self, span: Optional[Span] = None) -> None:
super().__init__()
self.span = span


Expand All @@ -605,6 +731,12 @@ def __init__(self, var: Var, expected_type: Type, span: Optional[Span] = None) -
self.var = var
self.expected_type = expected_type

@override
def replace_child(self, old: Node, new: Node) -> None:
assert isinstance(new, Var)
if old is self.var:
self.var = new


class Assign(Stmt):
ref: Ref
Expand All @@ -617,6 +749,19 @@ def __init__(self, ref: Ref, expected_type: Optional[Type], value: Value, span:
self.value = value
self.expected_type = expected_type

@override
def replace_child(self, old: Node, new: Node) -> None:
if old is self.ref:
assert isinstance(new, Ref)
self.ref = new
elif old is self.value:
assert isinstance(new, Value)
self.value = new

@override
def children(self) -> List[Node]:
return [self.ref, self.value]


class Return(Stmt):
value: Optional[Value]
Expand All @@ -625,6 +770,16 @@ def __init__(self, value: Optional[Value], span: Optional[Span] = None) -> None:
super().__init__(span)
self.value = value

@override
def replace_child(self, old: Node, new: Node) -> None:
if old is self.value:
assert isinstance(new, Value)
self.value = new

@override
def children(self) -> List[Node]:
return [self.value] if self.value is not None else []


class Function:
name: str
Expand Down Expand Up @@ -666,6 +821,16 @@ def is_parametric(self) -> bool:
return True
return False

def rebuild_usedef_chain(self) -> None:
roots: Set[Node] = set()
for param in self.params:
roots.add(param)
for stmt in self.body:
roots.add(stmt)
for local in self.locals:
roots.add(local)
rebuild_usedef_chain(list(roots))


# K = TypeVar("K")
# V = TypeVar("V")
Expand Down
2 changes: 1 addition & 1 deletion luisa_lang/hir/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TypeInferencer:


def _infer_cache(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(inferencer: 'FuncTypeInferencer', node: hir.Node, *args) -> Optional[hir.Type]:
def wrapper(inferencer: 'FuncTypeInferencer', node: hir.TypedNode, *args) -> Optional[hir.Type]:
if id(node) in inferencer._cache:
return inferencer._cache[id(node)]
if node.type:
Expand Down

0 comments on commit d416c34

Please sign in to comment.