Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Draft autoscheduler #831

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from enum import Enum

from . import scheduler # noqa: F401
from ._version import __version__, __version_tuple__ # noqa: F401

__array_api_version__ = "2022.12"
Expand Down
40 changes: 40 additions & 0 deletions sparse/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from .finch_logic import (
Aggregate,
Alias,
Deferred,
Field,
Immediate,
MapJoin,
Plan,
Produces,
Query,
Reformat,
Relabel,
Reorder,
Subquery,
Table,
)
from .optimize import optimize, propagate_map_queries
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk

__all__ = [
"Aggregate",
"Alias",
"Deferred",
"Field",
"Immediate",
"MapJoin",
"Plan",
"Produces",
"Query",
"Reformat",
"Relabel",
"Reorder",
"Subquery",
"Table",
"optimize",
"propagate_map_queries",
"PostOrderDFS",
"PostWalk",
"PreWalk",
]
137 changes: 137 additions & 0 deletions sparse/scheduler/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from collections.abc import Hashable
from textwrap import dedent
from typing import Any

from .finch_logic import (
Alias,
Deferred,
Field,
Immediate,
LogicNode,
MapJoin,
Query,
Reformat,
Relabel,
Reorder,
Subquery,
Table,
)


def get_or_insert(dictionary: dict[Hashable, Any], key: Hashable, default: Any) -> Any:
if key in dictionary:
return dictionary[key]
dictionary[key] = default
return default

Check warning on line 25 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L22-L25

Added lines #L22 - L25 were not covered by tests


def get_structure(node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode]) -> LogicNode:
match node:
case Field(name):
return get_or_insert(fields, name, Immediate(len(fields) + len(aliases)))
case Alias(name):
return get_or_insert(aliases, name, Immediate(len(fields) + len(aliases)))
case Subquery(Alias(name) as lhs, arg):
if name in aliases:
return aliases[name]
return Subquery(get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases))
case Table(tns, idxs):
return Table(Immediate(type(tns.val)), tuple(get_structure(idx, fields, aliases) for idx in idxs))
case any if any.is_tree():
return any.from_arguments(*[get_structure(arg, fields, aliases) for arg in any.get_arguments()])
case _:
return node

Check warning on line 43 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L29-L43

Added lines #L29 - L43 were not covered by tests


class PointwiseLowerer:
def __init__(self):
self.bound_idxs = []

Check warning on line 48 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L48

Added line #L48 was not covered by tests

def __call__(self, ex):
match ex:
case MapJoin(Immediate(val), args):
return f":({val}({','.join([self(arg) for arg in args])}))"
case Reorder(Relabel(Alias(name), idxs_1), idxs_2):
self.bound_idxs.append(idxs_1)
return f":({name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])"
case Reorder(Immediate(val), _):
return val
case Immediate(val):
return val
case _:
raise Exception(f"Unrecognized logic: {ex}")

Check warning on line 62 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L51-L62

Added lines #L51 - L62 were not covered by tests


def compile_pointwise_logic(ex: LogicNode) -> tuple:
ctx = PointwiseLowerer()
code = ctx(ex)
return (code, ctx.bound_idxs)

Check warning on line 68 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L66-L68

Added lines #L66 - L68 were not covered by tests


def compile_logic_constant(ex: LogicNode) -> str:
match ex:
case Immediate(val):
return val
case Deferred(ex, type_):
return f":({ex}::{type_})"
case _:
raise Exception(f"Invalid constant: {ex}")

Check warning on line 78 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L72-L78

Added lines #L72 - L78 were not covered by tests


def intersect(x1: tuple, x2: tuple) -> tuple:
return tuple(x for x in x1 if x in x2)

Check warning on line 82 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L82

Added line #L82 was not covered by tests


def with_subsequence(x1: tuple, x2: tuple) -> tuple:
res = list(x2)
indices = [idx for idx, val in enumerate(x2) if val in x1]
for idx, i in enumerate(indices):
res[i] = x1[idx]
return tuple(res)

Check warning on line 90 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L86-L90

Added lines #L86 - L90 were not covered by tests


class LogicLowerer:
def __init__(self, mode: str = "fast"):
self.mode = mode

def __call__(self, ex: LogicNode):
match ex:
case Query(Alias(name), Table(tns, _)):
return f":({name} = {compile_logic_constant(tns)})"

Check warning on line 100 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L98-L100

Added lines #L98 - L100 were not covered by tests

case Query(Alias(_) as lhs, Reformat(tns, Reorder(Relabel(Alias(_) as arg, idxs_1), idxs_2))):
loop_idxs = [idx.name for idx in with_subsequence(intersect(idxs_1, idxs_2), idxs_2)]
lhs_idxs = [idx.name for idx in idxs_2]
(rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2))
body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})"
for idx in loop_idxs:
if Field(idx) in rhs_idxs:
body = f":(for {idx} = _ \n {body} end)"
elif idx in lhs_idxs:
body = f":(for {idx} = 1:1 \n {body} end)"

Check warning on line 111 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L102-L111

Added lines #L102 - L111 were not covered by tests

result = f"""\

Check warning on line 113 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L113

Added line #L113 was not covered by tests
quote
{lhs.name} = {compile_logic_constant(tns)}
@finch mode = {self.mode} begin
{lhs.name} .= {tns.fill_value}
{body}
return {lhs.name}
end
end
"""
return dedent(result)

Check warning on line 123 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L123

Added line #L123 was not covered by tests

# TODO: ...

case _:
raise Exception(f"Unrecognized logic: {ex}")

Check warning on line 128 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L127-L128

Added lines #L127 - L128 were not covered by tests


class LogicCompiler:
def __init__(self):
self.ll = LogicLowerer()

def __call__(self, prgm):
# prgm = format_queries(prgm, True) # noqa: F821
return self.ll(prgm)

Check warning on line 137 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L137

Added line #L137 was not covered by tests
27 changes: 27 additions & 0 deletions sparse/scheduler/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .compiler import LogicCompiler
from .rewrite_tools import gensym


class LogicExecutor:
def __init__(self, ctx, verbose=False):
self.ctx: LogicCompiler = ctx
self.codes = {}
self.verbose = verbose

def __call__(self, prgm):
prgm_structure = prgm
if prgm_structure not in self.codes:
thunk = logic_executor_code(self.ctx, prgm)
self.codes[prgm_structure] = eval(thunk), thunk

Check warning on line 15 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L12-L15

Added lines #L12 - L15 were not covered by tests

f, code = self.codes[prgm_structure]
if self.verbose:
print(code)
return f(prgm)

Check warning on line 20 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L17-L20

Added lines #L17 - L20 were not covered by tests


def logic_executor_code(ctx, prgm):
# jc = JuliaContext()
code = ctx(prgm)
fname = gensym("compute")
return f""":(function {fname}(prgm) \n {code} \n end)"""

Check warning on line 27 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L25-L27

Added lines #L25 - L27 were not covered by tests
Loading
Loading