From 5c42f6c823ea3ff976b84c21c3a981db3042dee6 Mon Sep 17 00:00:00 2001 From: alfatum Date: Mon, 18 Dec 2023 11:14:51 +0300 Subject: [PATCH] wrong device in ttv_tensor.py bug fixed --- pyCP_APR/torch_backend/ttv_sptensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyCP_APR/torch_backend/ttv_sptensor.py b/pyCP_APR/torch_backend/ttv_sptensor.py index a31cc8f..5062325 100644 --- a/pyCP_APR/torch_backend/ttv_sptensor.py +++ b/pyCP_APR/torch_backend/ttv_sptensor.py @@ -26,8 +26,8 @@ def ttv(M, vecs, dims=[]): product of KRUSKAL tensor X with a (column) vector vecs. """ - dims = tr.arange(M.Dimensions) - vidx = tr.arange(M.Dimensions) + dims = tr.arange(M.Dimensions).to(M.device) + vidx = tr.arange(M.Dimensions).to(M.device) combined = tr.cat((dims, vidx)) uniques, counts = combined.unique(return_counts=True)