Skip to content

Commit

Permalink
Update missing dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Mar 29, 2020
1 parent 32ee3b3 commit 9b34912
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pycave/bayes/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch
import torch.nn as nn
import torch.distributions as dist
from sklearn.cluster import KMeans
import pyblaze.nn as xnn
from pyblaze.utils.stdio import ProgressBar
from pyblaze.utils.torch import to_one_hot
from sklearn.cluster import KMeans
from .utils import log_normal, log_responsibilities, max_likeli_means, max_likeli_covars
from .utils import log_normal, log_responsibilities, max_likeli_means, max_likeli_covars, \
to_one_hot

class GMMConfig(xnn.Config):
"""
Expand Down
19 changes: 19 additions & 0 deletions pycave/bayes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,22 @@ def power_iteration(A, eps=1e-7, max_iterations=100):
break

return v


def to_one_hot(X, n):
"""
Creates a one-hot matrix from a set of indices.
Parameters:
-----------
- X: torch.Tensor [N, D]
The indices to convert into one-hot vectors.
- n: int
The number of entries in the one-hot vectors.
Returns:
--------
- torch.Tensor [N, D, n]
The one-hot matrix.
"""
return torch.eye(n, device=X.device)[X]

0 comments on commit 9b34912

Please sign in to comment.