Skip to content

Commit

Permalink
progress
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Jul 23, 2024
1 parent e9d4acf commit d8606ef
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 32 deletions.
50 changes: 49 additions & 1 deletion luisa_lang/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional, TypeVar
import ast
import textwrap
from typing import Optional, Tuple, TypeVar
import sourceinspect

T = TypeVar("T")

Expand All @@ -7,3 +10,48 @@ def unwrap(opt: Optional[T]) -> T:
if opt is None:
raise ValueError("unwrap None")
return opt


def increment_lineno_and_col_offset(
node: ast.AST, lineno: int, col_offset: int
) -> ast.AST:
"""
Increment the line number and end line number of each node in the tree
starting at *node* by *n*. This is useful to "move code" to a different
location in a file.
"""
for child in ast.walk(node):
if "lineno" in child._attributes:
child.lineno = getattr(child, "lineno", 0) + lineno
if (
"end_lineno" in child._attributes
and (end_lineno := getattr(child, "end_lineno", 0)) is not None
):
child.end_lineno = end_lineno + lineno
if "col_offset" in child._attributes:
child.col_offset = getattr(child, "col_offset", 0) + col_offset
if (
"end_col_offset" in child._attributes
and (end_col_offset := getattr(child, "end_col_offset", 0)) is not None
):
child.end_col_offset = end_col_offset + col_offset
return node


def dedent_and_retrieve_indentation(lines: str) -> Tuple[str, int]:
"""
Dedent the lines and return the indentation level of the first line.
"""
if not lines:
return "", 0
return textwrap.dedent("".join(lines)), len(lines[0]) - len(lines[0].lstrip())


def retrieve_ast_and_filename(f: object) -> Tuple[ast.AST, str]:
source_file = sourceinspect.getsourcefile(f)
if source_file is None:
source_file = "<unknown>"
source_lines, lineno = sourceinspect.getsourcelines(f)
src, indent = dedent_and_retrieve_indentation(source_lines)
tree = increment_lineno_and_col_offset(ast.parse(src), lineno - 1, indent + 1)
return tree, source_file
13 changes: 13 additions & 0 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def __eq__(self, value: object) -> bool:
and value.signed == self.signed
)

def __repr__(self) -> str:
return f"IntType({self.bits}, {self.signed})"


class FloatType(ScalarType):
bits: int
Expand All @@ -121,6 +124,9 @@ def align(self) -> int:
def __eq__(self, value: object) -> bool:
return isinstance(value, FloatType) and value.bits == self.bits

def __repr__(self) -> str:
return f"FloatType({self.bits})"


# INT8: Final[IntType] = IntType(8, True)
# INT16: Final[IntType] = IntType(16, True)
Expand Down Expand Up @@ -171,6 +177,9 @@ def __eq__(self, value: object) -> bool:
and value.count == self.count
)

def __repr__(self) -> str:
return f"VectorType({self.element}, {self.count})"


class ArrayType(Type):
element: Type
Expand All @@ -192,6 +201,9 @@ def __eq__(self, value: object) -> bool:
and value.element == self.element
and value.count == self.count
)

def __repr__(self) -> str:
return f"ArrayType({self.element}, {self.count})"


class StructType(Type):
Expand Down Expand Up @@ -545,6 +557,7 @@ def __init__(self) -> None:
self.types = {}
self.functions = {}


class FuncMetadata:
pass

Expand Down
68 changes: 53 additions & 15 deletions luisa_lang/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)
from luisa_lang._math_type_exports import *
from luisa_lang._markers import _builtin_type, _builtin, _intrinsic_impl
import luisa_lang.hir as hir
import luisa_lang.parse as parse
import ast

_T = TypeVar("_T")
Expand All @@ -29,23 +31,40 @@ class _ObjKind(Enum):
KERNEL = auto()


def _dsl_decorator_impl(obj: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T:
def _dsl_func_impl(f: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T:
import sourceinspect
obj_src = sourceinspect.getsource(obj)
obj_ast = ast.parse(obj_src)
from luisa_lang._utils import retrieve_ast_and_filename
import inspect

assert inspect.isfunction(f), f"{f} is not a function"
obj_ast, obj_file = retrieve_ast_and_filename(f)
assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module"

ctx = hir.GlobalContext.get()

func_globals: Any = getattr(f, "__globals__", {})
parsing_ctx = parse.ParsingContext(func_globals)
print(ast.dump(obj_ast))
func_def = obj_ast.body[0]
if not isinstance(func_def, ast.FunctionDef):
raise RuntimeError("Function definition expected.")
func_parser = parse.FuncParser(func_def, parsing_ctx)
if kind == _ObjKind.FUNC:

def dummy(*args, **kwargs):
raise RuntimeError("DSL function should only be called in DSL context.")

return cast(_T, dummy)
else:
return cast(_T, f)


def _dsl_decorator_impl(obj: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T:

if kind == _ObjKind.STRUCT:
return obj
elif kind == _ObjKind.FUNC or kind == _ObjKind.KERNEL:
assert callable(obj)
func_globals = getattr(obj, "__globals__", None)
if kind == _ObjKind.FUNC:

def dummy(*args, **kwargs):
raise RuntimeError("DSL function should only be called in DSL context.")

return cast(_T, dummy)
else:
return cast(_T, obj)
return _dsl_func_impl(obj, kind, attrs)
raise NotImplementedError()


Expand Down Expand Up @@ -109,16 +128,35 @@ def swap(a: int, b: int):
a, b = b, a
```
"""

def impl(f: _F) -> _F:
return _dsl_decorator_impl(f, _ObjKind.FUNC, kwargs)

if len(args) == 1 and len(kwargs) == 0:
f = args[0]
return f
return impl(f)

def decorator(f):
return f
return impl(f)

return decorator


@_builtin_type
class Array(Generic[_T]):
def __init__(self, size: u32 | u64) -> None:
return _intrinsic_impl()

def __getitem__(self, index: int | u32 | u64) -> _T:
return _intrinsic_impl()

def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
return _intrinsic_impl()

def __len__(self) -> u32 | u64:
return _intrinsic_impl()


@_builtin_type
class Buffer(Generic[_T]):
def __getitem__(self, index: int | u32 | u64) -> _T:
Expand Down
Loading

0 comments on commit d8606ef

Please sign in to comment.