From 3e5eb22e3c9e9b2c2aa35f53960c8c7e234cabc6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Thu, 21 Nov 2024 12:25:16 +0100
Subject: [PATCH] fix: suppress untyped storage torch warning

---
 changelog.md             |  4 ++++
 foldedtensor/__init__.py | 40 +++++++++++++++++++++++-----------------
 2 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/changelog.md b/changelog.md
index c964610..fd44420 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,9 @@
 # Changelog
 
+## Unreleased
+
+- Fix `storage` torch warning
+
 ## v0.3.5
 
 - Support hashing the `folded_tensor.length` field (via a UserList), which is convenient for caching
diff --git a/foldedtensor/__init__.py b/foldedtensor/__init__.py
index 10bf098..d7cecda 100644
--- a/foldedtensor/__init__.py
+++ b/foldedtensor/__init__.py
@@ -1,4 +1,5 @@
 import typing
+import warnings
 from collections import UserList
 from multiprocessing.reduction import ForkingPickler
 from typing import List, Optional, Sequence, Tuple, Union
@@ -420,23 +421,28 @@ def refold(self, *dims: Union[Sequence[Union[int, str]], int, str]):
 
 
 def reduce_foldedtensor(self: FoldedTensor):
-    return (
-        FoldedTensor,
-        (
-            self.data.as_tensor(),
-            self.lengths,
-            self.data_dims,
-            self.full_names,
-            self.indexer.clone()
-            if self.indexer.is_shared() and self.indexer.storage().is_cuda
-            else self.indexer,
-            None
-            if self._mask is not None
-            and self._mask.is_shared()
-            and self._mask.storage().is_cuda
-            else self._mask,
-        ),
-    )
+    with warnings.catch_warnings():
+        warnings.simplefilter(
+            "ignore",
+            category=UserWarning,
+        )
+        return (
+            FoldedTensor,
+            (
+                self.data.as_tensor(),
+                self.lengths,
+                self.data_dims,
+                self.full_names,
+                self.indexer.clone()
+                if self.indexer.is_shared() and self.indexer.storage().is_cuda
+                else self.indexer,
+                None
+                if self._mask is not None
+                and self._mask.is_shared()
+                and self._mask.storage().is_cuda
+                else self._mask,
+            ),
+        )
 
 
 ForkingPickler.register(FoldedTensor, reduce_foldedtensor)
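
Note on the change: recent PyTorch releases emit a UserWarning ("TypedStorage is deprecated") whenever Tensor.storage() is called, which reduce_foldedtensor does when deciding whether to clone a shared CUDA indexer or drop a shared CUDA mask before pickling. The patch wraps the returned reduce tuple in warnings.catch_warnings() with simplefilter("ignore", category=UserWarning), so the warning is silenced only inside that scope. Below is a minimal sketch of the same suppression pattern, isolated from foldedtensor; the helper name storage_is_cuda and the example tensor are illustrative only, not part of the library's API.

import warnings

import torch


def storage_is_cuda(t: torch.Tensor) -> bool:
    # Tensor.storage() triggers a "TypedStorage is deprecated" UserWarning on
    # recent torch versions; suppress it locally so this call stays quiet
    # without hiding unrelated warnings elsewhere in the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        return t.storage().is_cuda


print(storage_is_cuda(torch.arange(6)))  # False for a CPU tensor, and no warning is printed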