diff --git a/kronfluence/computer/factor_computer.py b/kronfluence/computer/factor_computer.py index db8a76d..051fa1a 100644 --- a/kronfluence/computer/factor_computer.py +++ b/kronfluence/computer/factor_computer.py @@ -307,7 +307,7 @@ def fit_covariance_matrices( if torch.is_tensor(obj) or ( hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device( "cuda"): - print(type(obj), obj.size()) + print(type(obj), obj.size(), obj.device) except: pass @@ -345,7 +345,7 @@ def fit_covariance_matrices( if torch.is_tensor(obj) or ( hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device( "cuda"): - print(type(obj), obj.size()) + print(type(obj), obj.size(), obj.device) except: pass