diff --git a/pycave/bayes/gmm.py b/pycave/bayes/gmm.py index 3b45f6b..f4db607 100644 --- a/pycave/bayes/gmm.py +++ b/pycave/bayes/gmm.py @@ -3,11 +3,11 @@ import torch import torch.nn as nn import torch.distributions as dist +from sklearn.cluster import KMeans import pyblaze.nn as xnn from pyblaze.utils.stdio import ProgressBar -from pyblaze.utils.torch import to_one_hot -from sklearn.cluster import KMeans -from .utils import log_normal, log_responsibilities, max_likeli_means, max_likeli_covars +from .utils import log_normal, log_responsibilities, max_likeli_means, max_likeli_covars, \ + to_one_hot class GMMConfig(xnn.Config): """ diff --git a/pycave/bayes/utils.py b/pycave/bayes/utils.py index 210ad9b..fc1db0d 100644 --- a/pycave/bayes/utils.py +++ b/pycave/bayes/utils.py @@ -193,3 +193,22 @@ def power_iteration(A, eps=1e-7, max_iterations=100): break return v + + +def to_one_hot(X, n): + """ + Creates a one-hot matrix from a set of indices. + + Parameters: + ----------- + - X: torch.Tensor [N, D] + The indices to convert into one-hot vectors. + - n: int + The number of entries in the one-hot vectors. + + Returns: + -------- + - torch.Tensor [N, D, n] + The one-hot matrix. + """ + return torch.eye(n, device=X.device)[X]