Skip to content

Commit

Permalink
Add decoder method to trVAE
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Mar 14, 2021
1 parent 038ea88 commit ee61640
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions pyroved/models/trvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,20 @@ def encode(self, x_new: torch.Tensor, **kwargs: int) -> torch.Tensor:
z_scale = z[:, self.z_dim:]
return z_loc, z_scale

def decode(self, z: torch.Tensor, y: torch.Tensor = None) -> torch.Tensor:
"""
Decodes a batch of latent coordnates
"""
if y is not None:
z = torch.cat([z.to(self.device), y.to(self.device)], -1)
z = [z]
if self.coord > 0:
grid = self.grid.expand(z.shape[0], *self.grid.shape)
z = z.append(grid.to(self.device))
with torch.no_grad():
loc = self.decoder_net(*z)
return loc

def manifold2d(self, d: int, plot: bool = True,
**kwargs: Union[str, int]) -> torch.Tensor:
"""
Expand Down

0 comments on commit ee61640

Please sign in to comment.