Skip to content

Commit

Permalink
Tests: fix naming errors in fusion tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar committed Nov 13, 2023
1 parent 35a66c8 commit d75d323
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def add_views_annotated(i: int, view_1: pk.View1D[int], view_2: pk.View1D[int]):
view_1[i] += view_2[i]

@pk.workunit
def init_view_condition(i, view_1, var):
def init_view_condition(i, view, var):
value: int = 0
if var == 0:
value = 11
else:
value = 22

view_1[i] = value
view[i] = value

@pk.workunit
def init_view_types(i, view, init):
Expand All @@ -45,7 +45,7 @@ def test_fusion(self):
v1 = pk.View((self.iterations,), int)
v2 = pk.View((self.iterations,), int)

pk.parallel_for(self.iterations, [init_view, init_view_annotated], args_1={"view": v1, "init": value1}, args_2={"view": v2, "init": value2})
pk.parallel_for(self.iterations, [init_view, init_view_annotated], args_0={"view": v1, "init": value1}, args_1={"view": v2, "init": value2})
self.assertEqual(v1[0], value1)
self.assertEqual(v2[0], value2)

Expand All @@ -59,14 +59,14 @@ def test_fusion_after_call(self):
pk.parallel_for(self.iterations, add_views, view_1=v1, view_2=v2)
pk.parallel_for(self.iterations, add_views_annotated, view_1=v1, view_2=v2)

pk.parallel_for(self.iterations, [add_views, add_views_annotated], args_1={"view_1": v1, "view_2": v2}, args_2={"view_1": v1, "view_2": v2})
pk.parallel_for(self.iterations, [add_views, add_views_annotated], args_0={"view_1": v1, "view_2": v2}, args_1={"view_1": v1, "view_2": v2})
self.assertEqual(v1[0], 4)

def test_fusion_condition(self):
v1 = pk.View((self.iterations,), int)
v2 = pk.View((self.iterations,), int)

pk.parallel_for(self.iterations, [init_view_condition, init_view_condition], args_1={"view": v1, "init": 0}, args_2={"view": v2, "init": 1})
pk.parallel_for(self.iterations, [init_view_condition, init_view_condition], args_0={"view": v1, "var": 0}, args_1={"view": v2, "var": 1})
self.assertEqual(v1[0], 11)
self.assertEqual(v2[0], 22)

Expand All @@ -76,13 +76,13 @@ def test_fusion_change_types(self):
v1 = pk.View((self.iterations,), int)
v2 = pk.View((self.iterations,), float)

pk.parallel_for(self.iterations, [init_view_types, init_view_types], args_1={"view": v1, "init": value1}, args_2={"view": v2, "init": value2})
pk.parallel_for(self.iterations, [init_view_types, init_view_types], args_0={"view": v1, "init": value1}, args_1={"view": v2, "init": value2})
self.assertEqual(v1[0], value1)
self.assertEqual(v2[0], value2)

pk.parallel_for(self.iterations, [init_view_types, init_view_types], args_1={"view": v2, "init": value2}, args_2={"view": v1, "init": value1})
pk.parallel_for(self.iterations, [init_view_types, init_view_types], args_0={"view": v2, "init": value2}, args_1={"view": v1, "init": value1})
self.assertEqual(v1[0], value1)
self.assertEqual(v2[0], value2)

if __name__ == '__main__':
unittest.main()
unittest.main()

0 comments on commit d75d323

Please sign in to comment.