Skip to content

Commit

Permalink
run_loop_fusion_test.py: added basic automation and comparison
Browse files Browse the repository at this point in the history
test_loop_fusion.py: Cleaned up format
  • Loading branch information
HannanNaeem committed Jan 31, 2024
1 parent 9eb72b4 commit 3563f68
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
40 changes: 40 additions & 0 deletions tests/run_loop_fusion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import io
import subprocess
import time
from contextlib import redirect_stdout


def main():
cwd = os.getcwd()
subprocess.run(["rm", "-rf", "pk_cpp"], cwd=cwd)

try:
del os.environ["PK_LOOP_FUSE"]
except Exception as e:
print("TRIED DELETING ENV VARIABLE, BUT FAILED:\n", e)

result_vanilla = subprocess.run(["python", "test_loop_fusion.py"], cwd=cwd, capture_output=True, text=True)

# get output
vanilla_out = result_vanilla.stdout

# Again but with env variable
os.environ["PK_LOOP_FUSE"] = "1"
# remove old compilations
subprocess.run(["rm", "-rf", "pk_cpp"], cwd=cwd)

print("rerunning... ")
result_fused = subprocess.run(["python", "test_loop_fusion.py"], cwd=cwd, capture_output=True, text=True)

fused_out = result_fused.stdout

if vanilla_out != fused_out:
print("[X] MISMATCHED OUTPUTS:")
print("\t[-] WITHOUT FUSION:")
print(vanilla_out)
print("\n\t[+] WITH FUSION:")
print(fused_out)

if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions tests/test_loop_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def nested_triples_noprint(tid, v): # removing prints in between loops should al

for k in range(3):
a : int = 2
pk.printf("%d\n", a)
pk.printf("%d ", a)


for j in range(3):
Expand All @@ -82,7 +82,7 @@ def nested_triples_noprint(tid, v): # removing prints in between loops should al

for k in range(3):
a : int = 4
pk.printf("%d\n", a)
pk.printf("%d ", a)


for j in range(3):
Expand All @@ -91,7 +91,7 @@ def nested_triples_noprint(tid, v): # removing prints in between loops should al

for k in range(3):
a : int = 6
pk.printf("%d\n", a)
pk.printf("%d ", a)

@pk.workunit
def view_manip_inbetween(tid, v): # I guess manually inspect c++?
Expand Down

0 comments on commit 3563f68

Please sign in to comment.