Skip to content

Commit

Permalink
add more test case
Browse files Browse the repository at this point in the history
  • Loading branch information
jc-bytedance committed Jul 31, 2023
1 parent 776ba90 commit c346b56
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion test/kernel/test_tensor_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def foo(a: int32[M, N]) -> int32:
np.testing.assert_equal(f(a), foo(a))

def test_simple_load(self):

M = sympy.Symbol('M', positive=True)
N = sympy.Symbol('N', positive=True)

Expand Down Expand Up @@ -111,3 +110,49 @@ def foo(a: int32[M, N]) -> int32[2, 2]:
foo(a)
f = compile_linalg(p)
np.testing.assert_equal(f(a), foo(a))

def test_constant_slice_tensor_return2(self):
M = sympy.Symbol('M', positive=True)
N = sympy.Symbol('N', positive=True)

def foo(a: int32[M, N]):
return a[:2, :2]

p = KernelParser(foo)
p.parse()
print()
print("=" * 30, "linalg_code", "=" * 30, sep="")
print()
print(p.linalg_code())
print()
print("=" * 30, "compile and run", "=" * 30, sep="")
print()
a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
foo(a)
f = compile_linalg(p)
np.testing.assert_equal(f(a), foo(a))


"""
def test_constant_slice_tensor_return3(self):
M = sympy.Symbol('M', positive=True)
N = sympy.Symbol('N', positive=True)
def foo(a: int32[M, N], b: int32):
return a[b:b + 1, b:b * 2]
p = KernelParser(foo)
p.parse()
print()
print("=" * 30, "linalg_code", "=" * 30, sep="")
print()
print(p.linalg_code())
print()
print("=" * 30, "compile and run", "=" * 30, sep="")
print()
a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
b = 1
foo(a, 1)
f = compile_linalg(p)
np.testing.assert_equal(f(a), foo(a, 1))
"""

0 comments on commit c346b56

Please sign in to comment.