Skip to content

Commit

Permalink
whatever progress
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Jan 6, 2025
1 parent 6cc1f1d commit f2f9481
Show file tree
Hide file tree
Showing 7 changed files with 396 additions and 53 deletions.
107 changes: 100 additions & 7 deletions src/exo/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from ..core.configs import reverse_config_lookup, Config
from .new_analysis_core import *
from ..core.proc_eqv import get_repr_proc
from .dataflow import LoopIR_to_DataflowIR, ScalarPropagation, D
from .dataflow import (
LoopIR_to_DataflowIR,
ScalarPropagation,
D,
adom_to_aexpr,
DataflowIR,
)


def lift_dexpr():
Expand Down Expand Up @@ -1662,7 +1668,85 @@ def bds(x, lo, hi):


def Check_DeleteConfigWrite(proc, stmts):
assert len(stmts) == 1
assert len(stmts) == 2

p = GetControlPredicates(proc, stmts).result()
slv = SMTSolver(verbose=False)
slv.push()
slv.assume(AMay(p))

# Remain the existing pre-checking
ap = PostEnv(proc, stmts).get_posteffs()
a = [E.Guard(AMay(p), stmts_effs(stmts))]

# extract effects
WrA, Mod = getsets([ES.WRITE_ALL, ES.MODIFY], a)
WrAp, RdAp = getsets([ES.WRITE_ALL, ES.READ_ALL], ap)
print(WrA)
print(RdAp)

# check that `stmts` does not modify any non-global data
# only_mod_glob = ADef(is_empty(LDiff(Mod, WrG)))
# is_ok = slv.verify(only_mod_glob)
# if not is_ok:
# slv.pop()
# raise SchedulingError(
# f"Cannot delete or insert statements at {stmts[0].srcinfo} "
# f"because they may modify non-configuration data"
# )

# Below are the actual checks
ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
ScalarPropagation(ir1)
print(ir1)
if isinstance(d_stmts[0][0], DataflowIR.For):
prev_nm = d_stmts[0][-1].lhs
post_nm = d_stmts[1][-1].lhs
else:
prev_nm = d_stmts[0][0].lhs
post_nm = d_stmts[1][0].lhs
prev_val = adom_to_aexpr(prev_nm, ir1.body.ctxt[prev_nm])
post_val = adom_to_aexpr(post_nm, ir1.body.ctxt[post_nm])
cfg_mod = {pt.name: pt for pt in get_point_exprs(WrA)}

# consider every global that might be modified
cfg_mod_visible = set()
for _, pt in cfg_mod.items():
pt_e = A.Var(pt.name, T.R, null_srcinfo())
is_written = is_elem(pt, WrA)
is_read_post = is_elem(pt, RdAp)
is_overwritten = is_elem(pt, WrAp)

prev_k = A.Var(prev_nm, T.int, null_srcinfo())
post_k = A.Var(post_nm, T.int, null_srcinfo())

is_unchanged = AImplies(AAnd(prev_val, post_val), AEq(prev_k, post_k))
print(is_unchanged)

# if the value of the global might be read,
# then it must not have been changed.
safe_write = AImplies(AMay(is_read_post), ADef(is_unchanged))
print(safe_write)
if not slv.verify(is_unchanged):
slv.pop()
raise SchedulingError(
f"Cannot change configuration value of {pt.name} "
f"at {stmts[0].srcinfo}; the new (and different) "
f"values might be read later in this procedure"
)
# the write is invisible if its definitely unchanged or definitely
# overwritten
invisible = ADef(AOr(is_unchanged, is_overwritten))
if not slv.verify(invisible):
cfg_mod_visible.add(pt.name)

slv.pop()
return cfg_mod_visible


"""
def Check_DeleteConfigWrite(proc, stmts):
assert len(stmts) == 2
p = GetControlPredicates(proc, stmts).result()
slv = SMTSolver(verbose=False)
Expand Down Expand Up @@ -1690,7 +1774,10 @@ def Check_DeleteConfigWrite(proc, stmts):
# Below are the actual checks
ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
ScalarPropagation(ir1)
config1_vals, config2_vals = GetValues(ir1, d_stmts).result()
prev_nm = d_stmts[0][0].lhs
post_nm = d_stmts[1][0].lhs
prev_val = adom_to_aexpr(prev_nm, ir1.body.ctxt[prev_nm])
post_val = adom_to_aexpr(post_nm, ir1.body.ctxt[post_nm])
cfg_mod = {pt.name: pt for pt in get_point_exprs(WrG)}
# consider every global that might be modified
Expand All @@ -1701,15 +1788,20 @@ def Check_DeleteConfigWrite(proc, stmts):
is_read_post = is_elem(pt, RdGp)
is_overwritten = is_elem(pt, WrGp)
akey = A.Var(pt.name.copy(), T.int, null_srcinfo()) # type and srcinfo not sure
aval1 = lift_dexpr(config1_vals[pt.name], key=akey)
aval2 = lift_dexpr(config2_vals[pt.name], key=akey)
prev_k = A.Var(prev_nm, T.int, null_srcinfo())
post_k = A.Var(post_nm, T.int, null_srcinfo())
print(repr(prev_k))
print(repr(post_k))
is_unchanged = AAnd(AImplies(aval1, aval2), AImplies(aval2, aval1))
# FIXME: Change this! wrong
is_unchanged = AImplies(AAnd(prev_val, post_val), AEq(prev_k, post_k))
print(pt_e)
print(is_unchanged)
# if the value of the global might be read,
# then it must not have been changed.
safe_write = AImplies(AMay(is_read_post), ADef(is_unchanged))
print(safe_write)
if not slv.verify(safe_write):
slv.pop()
raise SchedulingError(
Expand All @@ -1725,6 +1817,7 @@ def Check_DeleteConfigWrite(proc, stmts):
slv.pop()
return cfg_mod_visible
"""


# This equivalence check assumes that we can
Expand Down
Loading

0 comments on commit f2f9481

Please sign in to comment.