diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 811f84991998..394ed7be52b0 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -85,10 +85,10 @@ "nn.functional.upsample_nearest", "nonzero", "nonzero_static", - "norm", + #"norm", "normal", "ormqr", - "pca_lowrank", + #"pca_lowrank", "pinverse", "polar", "polygamma", @@ -156,14 +156,14 @@ "linalg.eigvalsh": (5e1, 3e0), "linalg.pinv": (8e-1, 2e0), "linalg.svd": (1e0, 1e0), - "matrix_exp": (2e-1, 2e-4)} + "matrix_exp": (2e-1, 2e-4), + "pca_lowrank" :(1e-6, 1e-5) } def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_output=True): if isinstance(output1, torch.Tensor): testcase.assertIsInstance(output2, torch.Tensor) output2_cpu = output2.detach().cpu() if output1.layout != torch.strided: - # We only compare dense tensors. We dont currently support sparse tensors output1 = output1.to_dense() if check_output: torch.testing.assert_close( diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a8685f68e86b..af8e22c1c172 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4434,3 +4434,54 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor translation=translation, antialias=antialias, ) + +@op(torch.ops.aten.norm) +def _norm(input, p='fro', dim=None, keepdim=False, dtype=None): + if p == 'fro': + p = 2 + if dim is None: + return jnp.linalg.norm(input, ord=p) + return jnp.linalg.norm(input, ord=p, axis=dim, keepdims=keepdim) + +@op(torch.pca_lowrank) +def _pca_lowrank(A, q=None, center=True, niter=2): + *batch_dims, m, n = A.shape + A = A.astype(jnp.float32) + + if q is None: + q = min(6, m, n) + + if center: + mean = jnp.mean(A, axis=-2, keepdims=True) + A = A - mean + key = jax.random.PRNGKey(0) + + omega = jax.random.normal(key, (n, q), dtype=A.dtype) + Y = jnp.matmul(A, omega) + Q, _ = jnp.linalg.qr(Y, mode='reduced') + + for _ in range(niter): + Z = jnp.matmul(A.T, Q) + Q, _ = jnp.linalg.qr(jnp.matmul(A, Z), mode='reduced') + + B = jnp.matmul(Q.T, A) + U_tilde, S, Vt = jnp.linalg.svd(B, full_matrices=False) + U = jnp.matmul(Q, U_tilde) + # Enforce consistent signs by making the first element of each U column positive + signs = jnp.sign(U[0, :]) + signs = jnp.where(signs == 0, 1, signs) # Handle zeros to avoid NaNs + U *= signs + Vt *= signs[:, None] + V = Vt.T + + U = U[..., :, :q] + S = S[..., :q] + V = V[..., :, :q] + + U = U.astype(A.dtype) + S = S.astype(A.dtype) + V = V.astype(A.dtype) + return U, S, V + + +