From 4df273a3062e78606a222ae55276db1470f2a54d Mon Sep 17 00:00:00 2001 From: Daniel Platz Date: Sun, 8 Dec 2024 01:47:57 +0100 Subject: [PATCH] fixed ValueError cause of negative strides --- src/mokka/mapping/torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mokka/mapping/torch.py b/src/mokka/mapping/torch.py index 4d0f26d..fe20377 100644 --- a/src/mokka/mapping/torch.py +++ b/src/mokka/mapping/torch.py @@ -134,7 +134,7 @@ def get_constellation(self, *args): """ # Test bits B = generate_all_bits(self.m.item()).copy() - bits = torch.from_numpy(B).to(self.weights.device) + bits = torch.from_numpy(B.copy()).to(self.weights.device) logger.debug("bits device: %s", bits.device) out = self.forward(bits) return out @@ -304,7 +304,7 @@ def get_constellation(self, *args): mod_args = torch.tensor(args, dtype=torch.float32) mod_args = mod_args.repeat(2 ** self.m.item(), 1).split(1, dim=-1) B = generate_all_bits(self.m.item()).copy() - bits = torch.from_numpy(B).to(self.map1.weight.device) + bits = torch.from_numpy(B.copy()).to(self.map1.weight.device) logger.debug("bits device: %s", bits.device) out = self.forward(bits, *mod_args).flatten() return out @@ -413,7 +413,7 @@ def get_constellation(self, *args): :returns: tensor of constellation points """ # Test bits - bits = torch.from_numpy(generate_all_bits(self.m.item())).to( + bits = torch.from_numpy(generate_all_bits(self.m.item()).copy()).to( self.real_weights.device ) logger.debug("bits device: %s", bits.device)