Skip to content

Commit

Permalink
Efficient kernels for MoE
Browse files Browse the repository at this point in the history
ghstack-source-id: b8e776f8416cd5edce2c82d15dd734bb22f98c78
Pull Request resolved: https://github.com/fairinternal/xformers/pull/712

__original_commit__ = fairinternal/xformers@4515be06d2659b16fd1e0556c75d37c3a9434ddf
  • Loading branch information
danthe3rd authored and xFormers Bot committed Jul 10, 2023
1 parent 317c039 commit f7aeb35
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import numpy as np
import torch


def assert_allclose(
out: torch.Tensor,
ref: torch.Tensor,
out: Optional[torch.Tensor],
ref: Optional[torch.Tensor],
msg: str = "failed",
atol: float = 1e-8,
rtol: float = 1e-5,
) -> None:
assert out is not None, f"{msg}: output Tensor is None"
assert ref is not None, f"{msg}: reference Tensor is None"
assert out.shape == ref.shape, f"Shape: {out.shape} (expected: {ref.shape})"
if out.numel() == 0:
return
flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
max_pos = flatten_diff.argmax()
max_location = np.unravel_index(int(max_pos), out.shape)
Expand Down

0 comments on commit f7aeb35

Please sign in to comment.