Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

norm and pca_lowrank op info test from issue #7528 #8384

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
norm and pca_lowrank op info test
vyom1611 committed Nov 14, 2024
commit 165cf00037489545ad505e93a367ea9afa729a74
8 changes: 4 additions & 4 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
@@ -85,10 +85,10 @@
"nn.functional.upsample_nearest",
"nonzero",
"nonzero_static",
"norm",
#"norm",
"normal",
"ormqr",
"pca_lowrank",
#"pca_lowrank",
"pinverse",
"polar",
"polygamma",
@@ -156,14 +156,14 @@
"linalg.eigvalsh": (5e1, 3e0),
"linalg.pinv": (8e-1, 2e0),
"linalg.svd": (1e0, 1e0),
"matrix_exp": (2e-1, 2e-4)}
"matrix_exp": (2e-1, 2e-4),
"pca_lowrank" :(1e-6, 1e-5) }

def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_output=True):
if isinstance(output1, torch.Tensor):
testcase.assertIsInstance(output2, torch.Tensor)
output2_cpu = output2.detach().cpu()
if output1.layout != torch.strided:
# We only compare dense tensors. We dont currently support sparse tensors
output1 = output1.to_dense()
if check_output:
torch.testing.assert_close(
51 changes: 51 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
@@ -4434,3 +4434,54 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
translation=translation,
antialias=antialias,
)

@op(torch.ops.aten.norm)
def _norm(input, p='fro', dim=None, keepdim=False, dtype=None):
if p == 'fro':
p = 2
if dim is None:
return jnp.linalg.norm(input, ord=p)
return jnp.linalg.norm(input, ord=p, axis=dim, keepdims=keepdim)

@op(torch.pca_lowrank)
def _pca_lowrank(A, q=None, center=True, niter=2):
*batch_dims, m, n = A.shape
A = A.astype(jnp.float32)

if q is None:
q = min(6, m, n)

if center:
mean = jnp.mean(A, axis=-2, keepdims=True)
A = A - mean
key = jax.random.PRNGKey(0)

omega = jax.random.normal(key, (n, q), dtype=A.dtype)
Y = jnp.matmul(A, omega)
Q, _ = jnp.linalg.qr(Y, mode='reduced')

for _ in range(niter):
Z = jnp.matmul(A.T, Q)
Q, _ = jnp.linalg.qr(jnp.matmul(A, Z), mode='reduced')

B = jnp.matmul(Q.T, A)
U_tilde, S, Vt = jnp.linalg.svd(B, full_matrices=False)
U = jnp.matmul(Q, U_tilde)
# Enforce consistent signs by making the first element of each U column positive
signs = jnp.sign(U[0, :])
signs = jnp.where(signs == 0, 1, signs) # Handle zeros to avoid NaNs
U *= signs
Vt *= signs[:, None]
V = Vt.T

U = U[..., :, :q]
S = S[..., :q]
V = V[..., :, :q]

U = U.astype(A.dtype)
S = S.astype(A.dtype)
V = V.astype(A.dtype)
return U, S, V