Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support creating int1 tensors in interpreter #6058

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

saagarjha
Copy link
Contributor

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@saagarjha saagarjha requested a review from ptillet as a code owner February 28, 2025 03:32

@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", map(tl.dtype, tl.dtype.SINT_TYPES + tl.dtype.UINT_TYPES + tl.dtype.STANDARD_FP_TYPES))
def test_dtypes(device, dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to differentiate this test with test_dtype. Either merge tests or change one of the function names.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, renamed it to test_dtype_tensor (willing to take suggestions on better names if you have any). I'm not entirely sure what test_dtype is supposed to do so I left it alone for now.

def test_dtype_tensor(device, dtype):

@triton.jit
def dtype_tensor_kernel(dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def dtype_tensor_kernel(dtype):
def dtype_tensor_kernel(dtype: tl.constexpr):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants