diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index 07b707211..83efc8d9d 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -39,15 +39,14 @@ def __init__(
         page_cache_size: int = 128,
         # Need to look at the model more for this.
         end_token: int = 2,
+        dump_bins: bool = False,
     ):
         self.model = model
         self.tokenizer = tokenizer
-        if self.model.config.kv_cache_type == "paged":
-            self.shared_cache_state = model.cache.allocate(page_cache_size)
-            self.free_pages = list(range(1, page_cache_size))
-        else:
-            self.shared_cache_state = None
+        self.shared_cache_state = model.cache.allocate(page_cache_size)
+        self.free_pages = list(range(1, page_cache_size))
         self.end_token = end_token
+        self.dump_bins = dump_bins

     @property
     def block_seq_stride(self) -> int:
@@ -64,13 +63,9 @@ def begin_batch(self, prompts: list[str]):
             cache_state = self.shared_cache_state
         else:
             cache_state = self.model.cache.allocate(bs=len(prompts))
-        return Batch(self, token_ids, seq_lens, cache_state)
+        return Batch(self, token_ids, seq_lens, cache_state, dump_bins=self.dump_bins)

     def alloc_page(self) -> int:
-        if self.model.config.kv_cache_type == "direct":
-            # We don't allocate block ids for the direct cache.
-            return 0
-
         return self.free_pages.pop()

     def release_page(self, index: int):
@@ -86,6 +81,7 @@ def __init__(
         token_ids: torch.Tensor,
         seq_lens: torch.Tensor,
         cache_state: list[torch.Tensor],
+        dump_bins: bool = False,
     ):
         self.bs = token_ids.shape[0]
         assert seq_lens.shape[0] == self.bs
@@ -95,6 +91,7 @@ def __init__(
         self.cache_state = cache_state
         self.results: list[list[int]] = [[] for _ in range(self.bs)]
         self.done_result_indices: set[int] = set()
+        self.dump_bins = dump_bins

         # Assemble the batch.
         seq_stride = self.parent.block_seq_stride

@@ -160,6 +157,23 @@ def prefill(self):
             attention_mask = replicate(attention_mask, tp)
             seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)

+        if self.dump_bins:
+            torch.save(
+                token_ids,
+                f"prefill_token_ids_{'_'.join([str(x) for x in token_ids.shape])}.bin",
+            )
+            torch.save(
+                torch.tensor(token_ids.shape[0]).to(torch.int64),
+                f"prefill_seq_lens_1.bin",
+            )
+            torch.save(
+                seq_block_ids_tensor,
+                f"prefill_seq_block_ids_{'_'.join([str(x) for x in seq_block_ids_tensor.shape])}.bin",
+            )
+            torch.save(
+                self.cache_state[0].to(torch.float8_e4m3fnuz),
+                f"prefill_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
+            )
         logits = model.prefill(
             token_ids,
             attention_mask=attention_mask,
@@ -204,6 +218,27 @@ def decode(self):
             seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
             decode_attention_mask = replicate(decode_attention_mask, tp)

+        if self.dump_bins:
+            torch.save(
+                self.next_tokens,
+                f"decode_next_tokens_{'_'.join([str(x) for x in self.next_tokens.shape])}.bin",
+            )
+            torch.save(
+                start_positions,
+                f"decode_start_positions_{'_'.join([str(x) for x in start_positions.shape])}.bin",
+            )
+            torch.save(
+                seq_block_ids_tensor,
+                f"decode_seq_block_ids_tensor_{'_'.join([str(x) for x in seq_block_ids_tensor.shape])}.bin",
+            )
+            torch.save(
+                torch.tensor(self.next_tokens.shape[0]).to(torch.int64),
+                f"decode_seq_lens_1.bin",
+            )
+            torch.save(
+                self.cache_state[0].to(torch.float8_e4m3fnuz),
+                f"decode_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
+            )
         logits = model.decode(
             self.next_tokens,
             attention_mask=decode_attention_mask,
@@ -238,6 +273,11 @@ def main():
         "--save_intermediates_path",
         help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors",
     )
+    parser.add_argument(
+        "--dump-bins",
+        help="dump input tensors to bin files",
+        action="store_true",
+    )
     cli.add_input_dataset_options(parser)
     cli.add_tokenizer_options(parser)
     cli.add_quantization_options(parser)
@@ -274,7 +314,7 @@ def main():
         intermediates_saver = SaveModuleResultTensorsPatch()
         intermediates_saver.patch_child_modules(model)

-    generator = TorchGenerator(model, tokenizer)
+    generator = TorchGenerator(model, tokenizer, dump_bins=args.dump_bins)
     print(f":: Prompting:")
     for prompt in prompts:
diff --git a/sharktank/sharktank/kernels/bitcast.py b/sharktank/sharktank/kernels/bitcast.py
index 66850008f..a1552a92b 100644
--- a/sharktank/sharktank/kernels/bitcast.py
+++ b/sharktank/sharktank/kernels/bitcast.py
@@ -31,6 +31,7 @@
 ]

 _ftype_to_ctype_table = {
+    torch.bfloat16: torch.complex32,
     torch.float16: torch.complex32,
     torch.float32: torch.complex64,
 }
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 6cf79402e..37b279daf 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -49,7 +49,7 @@ def from_gguf_props(p: dict[str, Any]):
     name_prefix = p.get("general.architecture", "llama")
     default_expert_count = 0
     default_expert_used_count = 0
-    default_rope_freq_base = 10000.0
+    default_rope_freq_base = 500000.0
     default_rope_dimension_count = 128
     attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count")
     rope_dimension_count = _optional_int_prop(
diff --git a/sharktank/sharktank/layers/ffn_block.py b/sharktank/sharktank/layers/ffn_block.py
index fb8ecc2bb..a4e6a8bca 100644
--- a/sharktank/sharktank/layers/ffn_block.py
+++ b/sharktank/sharktank/layers/ffn_block.py
@@ -25,15 +25,20 @@ def __init__(
         theta: Theta,
         is_gated: bool = True,
         activation_fn: Callable[[AnyTensor], AnyTensor] = F.silu,
+        fake_quant: bool = False,
     ):
         super().__init__(theta)
         self.is_gated = is_gated
         self.activation_fn = activation_fn
         if self.is_gated:
-            self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
-        self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
-        self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
+            self.add_module(
+                "ffn_gate", LinearLayer(theta("ffn_gate"), fake_quant=fake_quant)
+            )
+        self.add_module("ffn_up", LinearLayer(theta("ffn_up"), fake_quant=fake_quant))
+        self.add_module(
+            "ffn_down", LinearLayer(theta("ffn_down"), fake_quant=fake_quant)
+        )

     def forward(
         self,
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index acd9b8a37..a1f1366ab 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -78,7 +78,6 @@ def forward(self, x):
             x = qdq_input.quantize(x).unpack().dequant()

         y = ops.linear(x, weight, bias)

-        # Unconditionally dequantize.
         if isinstance(y, QuantizedTensor):
             y = y.unpack().dequant()
@@ -88,7 +87,7 @@ def forward(self, x):
         # level to do this, but for now its here.
         if not isinstance(y, QuantizedTensor):
             if y.dtype == torch.float8_e4m3fnuz:
-                y = ops.to(y, torch.float16)
+                y = ops.to(y, torch.bfloat16)
                 return y
         if qdq_output is not None:
             y = qdq_output.quantize(y).unpack().dequant()
diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py
index a93649b29..ff4d6e873 100644
--- a/sharktank/sharktank/layers/norm.py
+++ b/sharktank/sharktank/layers/norm.py
@@ -34,7 +34,7 @@ def __init__(
     def forward(self, x: torch.Tensor):
         orig_dtype = x.dtype
         x = ops.to(x, self.dtype)
-        norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
+        norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon, orig_dtype=orig_dtype)
         # Will automatically upcast to the dtype of the weight, which is
         # often in higher precision. Downcast back to expected.
         norm = ops.to(norm, orig_dtype)
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 0c937f39d..69b011cc4 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -100,9 +100,7 @@ def forward(
         cache_state: list[torch.Tensor] = None,
     ):
         assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)
-
         x = self.attn_norm(h)
-
         bs, batch_seq_len, feature_dim = x.shape
         assert feature_dim == self.head_count * self.head_dim

@@ -128,22 +126,7 @@ def forward(

         # Used by fp8_e4m3fnuz model
         if self.cache_quantizer is not None:
-            # For fake quant, store the fp16 qdq value in the cache
-            if self.fake_quant:
-                xk = (
-                    self.cache_quantizer.quantize(xk)
-                    .unpack()
-                    .dequant()
-                    .to(torch.float16)
-                )
-                xv = (
-                    self.cache_quantizer.quantize(xv)
-                    .unpack()
-                    .dequant()
-                    .to(torch.float16)
-                )
-            # For real quant, store the quantized fp8 value in the cache
-            else:
+            if not self.fake_quant:
                 # TODO: this seems like a bastardization of our quantized tensor api
                 # Probably want to add support for using quantized tensors more directly
                 xk = self.cache_quantizer.quantize(xk).unpack().qs
@@ -175,11 +158,14 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
         # Fake quant is already dequantized when stored in the cache.
         if self.cache_quantizer and not self.fake_quant:
             xk = self.cache_quantizer.dequantize_raw_tensor(
-                xk, torch.float16, name="xk_deq"
+                xk, torch.bfloat16, name="xk_deq"
             )
             xv = self.cache_quantizer.dequantize_raw_tensor(
-                xv, torch.float16, name="xv_deq"
+                xv, torch.bfloat16, name="xv_deq"
             )
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(torch.bfloat16)
+
         # Transpose into [bs, heads, sl, dim]
         xq = xq.transpose(1, 2)
         keys = xk.transpose(1, 2)
@@ -223,7 +209,6 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.flatten(2, 3)
-        # Project.
         attn_output = self.attn_output(attn_output)
         attn_output = self.attn_output_norm(attn_output)

diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index a5003cd46..a31913833 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -18,6 +18,10 @@
 from ...utils.create_cache import *
 from ... import ops

+
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+
 __all__ = [
     "PagedLlamaModelV1",
 ]
@@ -82,7 +86,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):

         self.add_module(
             "token_embedding",
-            TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
+            TokenEmbeddingLayer(theta("token_embd"), dtype=self.activation_dtype),
         )
         self.add_module(
             "attention_embedding",
@@ -93,6 +97,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 device=self.device,
                 use_hf=self.use_hf,
                 tensor_parallelism_size=config.tensor_parallelism_size,
+                dtype=config.activation_dtype,
             ),
         )
         self.add_module(
@@ -258,6 +263,7 @@ def __init__(
             "ffn",
             FFN(
                 theta=theta,
+                fake_quant=fake_quant,
             ),
         )
         self.add_module(
diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
index 052593748..6ca19d5bd 100644
--- a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
+++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -113,7 +113,7 @@ def apply_per_layer_quant(

     # It looks dumb but, this step is required for numerical correctness against quark.
     # weight = weight.view(torch.float8_e4m3fn)
-    weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.float16)
+    weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.bfloat16)

     weight_quant_zero_point = layer_theta.optional_tensor("weight_zero_point")
     if weight_quant_zero_point == None:
@@ -148,6 +148,7 @@ def quantize_weight(
     weight_quant = weight_quantizer.quantize(weight, name=weight_name)
     updated_tensors[weight_quant.name] = weight_quant

+    # In older quark models the qkv layer is fused. Unfuse.
if "qkv" in layer_name: # The qkv layer is fused in the quark model, decompose back into individual q, k , and v weights q_weight, k_weight, v_weight = torch.split(weight, split_sizes) @@ -275,8 +276,6 @@ def single_replace( updated_tensors: dict[str, InferenceTensor], ): data = quant_theta(layer_name).tensor("weight").as_torch() - if data.dtype == torch.bfloat16: - data = data.to(torch.float32) updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index f83fae089..bfae8e273 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -400,11 +400,14 @@ def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) -> # RMS norm @rms_norm.override(AllOfType(Tensor, InferenceTensor)) -def rms_norm_default(x, weight, *, epsilon: float) -> Tensor: +def rms_norm_default( + x, weight, *, epsilon: float, orig_dtype: Union[None, torch.dtype] +) -> Tensor: + if orig_dtype is None: + orig_dtype = x.dtype variance = x.pow(2).mean(-1, keepdim=True) output = x * elementwise(torch.rsqrt, variance + epsilon) - # The cast here is to match the hf implementation, affects numerics - output = elementwise(torch.mul, weight, to(output, weight.dtype)) + output = elementwise(torch.mul, weight, to(output, orig_dtype)) return output diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index b66d3be1d..f88684273 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -52,9 +52,7 @@ def qlinear_tensor_scaled( if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: if x_layout.qs.dtype == torch.float8_e4m3fnuz: # assume quark - return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to( - torch.float16 - ) + return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True) else: return NotImplemented diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 82ace179c..a698ccb06 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -769,18 +769,25 @@ def _module_register_buffer_trampoline( @overridable -def rms_norm(x: AnyTensor, weight: AnyTensor, *, epsilon: float) -> AnyTensor: +def rms_norm( + x: AnyTensor, weight: AnyTensor, *, epsilon: float, orig_dtype: torch.dtype +) -> AnyTensor: """Computes the full, unbiased RMS normalization of an input.""" raise NotImplementedError @rms_norm.trampoline def _rms_norm_trampoline( - d: SignatureDispatcher, x: AnyTensor, weight: AnyTensor, *, epsilon: float + d: SignatureDispatcher, + x: AnyTensor, + weight: AnyTensor, + *, + epsilon: float, + orig_dtype: torch.dtype, ): tensors = (x, weight) for override in d.find_overrides(tensors): - result = override(x, weight, epsilon=epsilon) + result = override(x, weight, epsilon=epsilon, orig_dtype=orig_dtype) if result is not NotImplemented: return override, result else: diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 7ccae9deb..75071e286 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -3,7 +3,6 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
 import os
 import sys
 import subprocess
@@ -94,6 +93,8 @@ def __init__(
         block_seq_stride: int,
         iree_hal_target_device: str,
         use_attention_mask: bool = False,
+        activation_dtype: str = "float16",
+        attention_dtype: str = "float16",
     ):
         self.sharktank_dir = str(
             Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent
@@ -106,6 +107,8 @@ def __init__(
         self.tensor_parallelism_size = tensor_parallelism_size
         self.block_seq_stride = block_seq_stride
         self.use_attention_mask = use_attention_mask
+        self.activation_dtype = activation_dtype
+        self.attention_dtype = attention_dtype

     def timeit(func):
         def wrapper(*args, **kwargs):
@@ -183,6 +186,8 @@ def export_to_mlir(
             f"--output-config={json_path}",
             f"--bs={str(self.batch_size)}",
             f"--block-seq-stride={self.block_seq_stride}",
+            f"--attention-dtype={self.attention_dtype}",
+            f"--activation-dtype={self.activation_dtype}",
         ]
         if skip_decode:
             export_args.append("--skip-decode")
diff --git a/sharktank/sharktank/utils/patching.py b/sharktank/sharktank/utils/patching.py
index aee70fd1e..62136318c 100644
--- a/sharktank/sharktank/utils/patching.py
+++ b/sharktank/sharktank/utils/patching.py
@@ -65,7 +65,7 @@ def after_forward(self, module_name: str, module: torch.nn.Module, results):
                 del self.tensors[module_name]
                 self.duplicate_tensors[module_name] = 0
                 self.tensors[f"{module_name}#0"] = orig_dup
-            elif module_name in self.duplicate_tensors:
+            if module_name in self.duplicate_tensors:
                 index = self.duplicate_tensors[module_name] + 1
                 self.duplicate_tensors[module_name] = index
                 self.tensors[f"{module_name}#{index}"] = result_tensor
diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
index 1a362415a..c78ee9e29 100644
--- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -100,6 +100,8 @@ def setUp(self):
             attention_kernel="torch",
             tensor_parallelism_size=self.tensor_parallelism_size,
             block_seq_stride=32,
+            activation_dtype="bfloat16",
+            attention_dtype="float8_e4m3fnuz",
         )
         self.prefill_args_bs4_128_stride_32_f16 = (
             self.artifacts_dir / "prefill_args_bs4_128_stride_32"
@@ -279,7 +281,11 @@ def testBenchmark8B_f16_Non_Decomposed_Input_Len_2048(self):
             cwd=self.repo_root,
         )

-    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+    @pytest.mark.xfail(
+        reason="Benchmark inputs not configured yet.",
+        strict=False,
+        raises=IreeBenchmarkException,
+    )
     def testBenchmark8B_fp8_Non_Decomposed(self):
         output_file_name = self.dir_path_8b / "fp8_torch"
         output_mlir = self.llama8b_fp8_torch_sdpa_artifacts.create_file(
@@ -307,7 +313,7 @@ def testBenchmark8B_fp8_Non_Decomposed(self):
             hip_device_id=self.iree_device,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path_fp8,
-            args=self.iree_run_prefill_args,
+            args=self.iree_run_prefill_args_fp8,
             cwd=self.repo_root,
         )
         # benchmark decode
@@ -315,7 +321,7 @@ def testBenchmark8B_fp8_Non_Decomposed(self):
             hip_device_id=self.iree_device,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path_fp8,
-            args=self.iree_run_decode_args,
+            args=self.iree_run_decode_args_fp8,
             cwd=self.repo_root,
         )

diff --git a/sharktank/tests/models/llama/quark_parity_test.py b/sharktank/tests/models/llama/quark_parity_test.py
new file mode 100644
index 000000000..1ffdffd30
--- /dev/null
+++ b/sharktank/tests/models/llama/quark_parity_test.py
@@ -0,0 +1,64 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+from safetensors import safe_open
+import torch
+import unittest
+import pytest
+
+
+@pytest.mark.skip(reason="need to generate values to compare against")
+class QuarkParityTest(unittest.TestCase):
+    def test_compare_against_quark(self):
+        def both(key, index=None):
+            o = ours[key]
+            t = theirs[key]
+            if index is None:
+                return o, t
+            else:
+                return o[index], t[index]
+
+        mapping = dict()
+        for i in range(32):
+            hf = f"model.layers.{i}"
+            gg = f"attn_blocks.{i}"
+            base_pairs = [
+                [f"{hf}.input_layernorm", f"{gg}.attn.attn_norm"],
+                [f"{hf}.self_attn.k_proj", f"{gg}.attn.attn_k"],
+                [f"{hf}.self_attn.q_proj", f"{gg}.attn.attn_q"],
+                [f"{hf}.self_attn.v_proj", f"{gg}.attn.attn_v"],
+                [f"{hf}.self_attn.o_proj", f"{gg}.attn.attn_output"],
+                [f"{hf}.post_attention_layernorm", f"{gg}.ffn_norm"],
+                [f"{hf}.mlp.down_proj", f"{gg}.ffn.ffn_down"],
+                [f"{hf}.mlp.gate_proj", f"{gg}.ffn.ffn_gate"],
+                [f"{hf}.mlp.up_proj", f"{gg}.ffn.ffn_up"],
+            ]
+            for a, b in base_pairs:
+                mapping[a] = b
+                mapping[a + "_input_0"] = b + "_input_0"
+
+        ours = dict()
+        with safe_open("../ours_newest_prefill.safetensors", "pytorch") as st:
+            for key in st.keys():
+                ours[key] = st.get_tensor(key)
+
+        theirs = dict()
+        with safe_open("../theirs2.safetensors", "pytorch") as st:
+            for key in st.keys():
+                if key in mapping:
+                    theirs[mapping[key]] = st.get_tensor(key)
+
+        test_layers = [v for k, v in mapping.items()]
+        for lyr in test_layers:
+            name = lyr
+            if name in ours.keys() and name != "freqs":
+                o, t = both(name)
+                torch.testing.assert_close(o, t, atol=0, rtol=0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py
index c3ee2d339..ded0c0900 100644
--- a/sharktank/tests/ops/ops_test.py
+++ b/sharktank/tests/ops/ops_test.py
@@ -248,7 +248,7 @@ def _ref(self, x, weight, epsilon):

     def testTorchImpl(self):
         t1 = torch.rand(16, 128, dtype=torch.float32)
         t2 = torch.rand(16, 128, dtype=torch.float32)
-        result = ops.rms_norm(t1, t2, epsilon=1e-10)
+        result = ops.rms_norm(t1, t2, epsilon=1e-10, orig_dtype=torch.float32)
         actual = self._ref(t1, t2, epsilon=1e-10)
         torch.testing.assert_close(actual, result)

@@ -256,7 +256,7 @@ def testTorchPrimitiveWeightImpl(self):
         t1 = torch.rand(16, 128, dtype=torch.float32)
         t2 = torch.rand(16, 128, dtype=torch.float32)
         t2_pt = DefaultPrimitiveTensor(data=t2)
-        result = ops.rms_norm(t1, t2_pt, epsilon=1e-10)
+        result = ops.rms_norm(t1, t2_pt, epsilon=1e-10, orig_dtype=torch.float32)
         actual = self._ref(t1, t2, epsilon=1e-10)
         torch.testing.assert_close(actual, result)