Commit 373a03e: code upload

eddardd committed Dec 22, 2024 (parent commit: 61140b3)

Showing 9 changed files with 1,550 additions and 2 deletions.
32 changes: 30 additions & 2 deletions README.md
@@ -1,2 +1,30 @@
# gmm-otda
Optimal Transport for Domain Adaptation through Gaussian Mixture Models
# Optimal Transport for Domain Adaptation through Gaussian Mixture Models

This is the official repository for the paper [Optimal Transport for Domain Adaptation through Gaussian Mixture Models](https://openreview.net/forum?id=DCAeXwLenB), accepted at TMLR. Our paper uses the GMM optimal transport framework of Delon and Desolneux (2020) for domain adaptation, through two strategies:

- Mapping estimation, which maps points in the source domain to the target domain using the GMMs;
- Label propagation, which estimates labels for the target domain GMM components.

You can run our code with:

```
python visda.py --base_path=PATH_TO_DATA --features="vit" --clusters_per_class="4" --reg_e=0.1
```
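
The two strategies can also be used directly through the `GMMOTDA` class in `src/`. The snippet below is a minimal sketch, assuming `Xs`, `Ys` (one-hot) and `Xt` are torch tensors, and using `ot.emd` from the POT library as the OT solver; any callable with the signature `solver(w_src, w_tgt, C)` works:

```
from functools import partial

import ot  # POT: Python Optimal Transport (an assumption, not part of this repo)
from src import GMMOTDA, em_gmm, conditional_em_gmm

n_classes, clusters_per_class = 12, 4  # hypothetical values

gmm_otda = GMMOTDA(
    ot_solver=ot.emd,  # called internally as ot_solver(w_src, w_tgt, C)
    clustering_source=partial(conditional_em_gmm, n_clusters=clusters_per_class),
    clustering_target=partial(em_gmm, n_clusters=n_classes * clusters_per_class),
)
gmm_otda.fit(Xs, Ys, Xt)

# Strategy 1: mapping estimation (source samples mapped to the target domain)
weights, mapped_Xs, mapped_Ys = gmm_otda.transport_samples(Xs, Ys)

# Strategy 2: label propagation (soft labels for target samples)
Yt_hat = gmm_otda.predict_target_labels(Xt)
```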

# Citation

```
@article{
montesuma2024optimal,
title={Optimal Transport for Domain Adaptation through Gaussian Mixture Models},
author={Montesuma, Eduardo Fernandes and Mboula, Fred Maurice Ngol{\`e} and Souloumiac, Antoine},
journal={Transactions on Machine Learning Research},
year={2024},
url={https://openreview.net/forum?id=DCAeXwLenB}
}
```

# References

- Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
387 changes: 387 additions & 0 deletions Toy Example.ipynb

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions src/__init__.py
@@ -0,0 +1,19 @@
from .gmm import em_gmm
from .gmm import conditional_em_gmm
from .prob_utils import diag_gmm_log_probs
from .gmm_otda import GMMOTDA
from .vis import plot_cov_ellipse
from .models import (
ShallowNeuralNet,
WeightedShallowNeuralNet
)

__all__ = [
    "em_gmm",
    "conditional_em_gmm",
    "diag_gmm_log_probs",
    "plot_cov_ellipse",
    "GMMOTDA",
    "ShallowNeuralNet",
    "WeightedShallowNeuralNet",
]
47 changes: 47 additions & 0 deletions src/gmm.py
@@ -0,0 +1,47 @@
import torch
from sklearn.mixture import GaussianMixture


def em_gmm(X, Y, n_clusters, random_state=None, dtype=torch.float32):
"""Fits a GMM using the expectation-maximization algorithm."""
clustering = GaussianMixture(
n_components=n_clusters,
covariance_type='diag',
random_state=random_state).fit(X)
w = torch.from_numpy(clustering.weights_).to(dtype)
m = torch.from_numpy(clustering.means_).to(dtype)
v = torch.from_numpy(clustering.covariances_).to(dtype)

return w, m, v, None


def conditional_em_gmm(X, Y, n_clusters, random_state=None,
dtype=torch.float32):
"""Fits GMMs on the conditoinals of P(X|Y)
using expectation maximization"""
n_classes = Y.shape[1]
w_s, m_s, v_s, y_s = [], [], [], []
for c in Y.argmax(dim=1).unique():
ind = torch.where(Y.argmax(dim=1) == c)[0]
wc, mc, vc, _ = em_gmm(
X=X[ind],
Y=None,
n_clusters=n_clusters,
random_state=random_state,
dtype=dtype)
w_s.append(wc)
m_s.append(mc)
v_s.append(vc)
y_s.append(
torch.nn.functional.one_hot(
torch.Tensor([c] * len(mc)).long(),
num_classes=n_classes
).to(dtype)
)
w_s = torch.cat(w_s).to(dtype)
w_s /= w_s.sum()
m_s = torch.cat(m_s, dim=0).to(dtype)
v_s = torch.cat(v_s, dim=0).to(dtype)
y_s = torch.cat(y_s, dim=0).to(dtype)

return w_s, m_s, v_s, y_s
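
These routines return the `(weights, means, variances, labels)` tuples consumed by `GMMOTDA`. A minimal sketch of how they may be called, on hypothetical toy data (an illustration, not part of the repository):

```
import torch

from src.gmm import em_gmm, conditional_em_gmm

# Hypothetical toy data: 100 samples, 2 features, 3 classes (one-hot labels)
X = torch.randn(100, 2)
Y = torch.nn.functional.one_hot(
    torch.randint(0, 3, (100,)), num_classes=3).float()

# A single GMM over all samples; component labels are returned as None
w, m, v, _ = em_gmm(X, Y=None, n_clusters=4, random_state=0)

# One GMM per class: 3 classes x 4 clusters = 12 components, each carrying
# the one-hot label of the class it was fitted on
w_s, m_s, v_s, y_s = conditional_em_gmm(X, Y, n_clusters=4, random_state=0)
```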
295 changes: 295 additions & 0 deletions src/gmm_otda.py
@@ -0,0 +1,295 @@
import torch
import numpy as np
from .prob_utils import diag_gmm_log_probs


class GMMOTDA:
"""Gaussian Mixture Model-based Optimal Transport for Domain Adaptation"""

def __init__(self,
ot_solver,
clustering_source,
clustering_target,
min_var=0.01):
self.ot_solver = ot_solver
self.clustering_source = clustering_source
self.clustering_target = clustering_target
self.min_var = min_var

self.stds_src = None
self.weights_src = None
self.labels_src = None
self.centroids_src = None

self.stds_tgt = None
self.labels_tgt = None
self.weights_tgt = None
self.centroids_tgt = None
self.estimated_labels_tgt = None

self.ot_plan = None
self.n_dim = None
self.n_classes = None

self.fitted_gmm = False
self.fitted_ot = False

def fit_gmms(self, Xs, Ys, Xt, Yt=None):
"""Fit GMMs to source and target domain data"""
self.n_dim = Xs.shape[1]
self.n_classes = Ys.shape[1]

w_s, m_s, v_s, y_s = self.clustering_source(Xs, Ys)
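        # Floor the variances at min_var to avoid degenerate, near-singular
        # components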
v_s[v_s < self.min_var] = self.min_var
self.weights_src = w_s
self.stds_src = v_s ** 0.5
self.labels_src = y_s
self.centroids_src = m_s

w_t, m_t, v_t, y_t = self.clustering_target(Xt, Yt)
v_t[v_t < self.min_var] = self.min_var
self.labels_tgt = y_t
self.stds_tgt = v_t ** 0.5
self.weights_tgt = w_t
self.centroids_tgt = m_t

self.fitted_gmm = True

return self

def fit_ot(self):
"""Solves the GMM-OT problem"""
if not self.fitted_gmm:
raise ValueError("Expected 'fit_gmm' to be previously called.")

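        # Cost between components: squared Wasserstein-2 distance between
        # diagonal Gaussians, W2^2 = ||m_s - m_t||^2 + ||sqrt(v_s) - sqrt(v_t)||^2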
C = torch.cdist(self.centroids_src, self.centroids_tgt, p=2) ** 2 + \
torch.cdist(self.stds_src, self.stds_tgt, p=2) ** 2
self.ot_plan = self.ot_solver(
self.weights_src,
self.weights_tgt,
C
)

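        # Label propagation: each target component's label distribution is the
        # plan-weighted average of the source labels (each column of the plan
        # is normalized to sum to one)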
self.estimated_labels_tgt = torch.mm(
(self.ot_plan / self.ot_plan.sum(dim=0)[None, :]).T,
self.labels_src)

self.fitted_ot = True
return self

def fit(self, Xs, Ys, Xt, Yt=None):
"""Fit pipeline. First, fits GMMs, then, fits the GMM-OT problem"""
return self.fit_gmms(Xs, Ys, Xt, Yt).fit_ot()

def predict_target_components(self, X, return_proba=False):
"""Performs k = argmax P_T(K|X) (return_proba = False)
or computes the probabiliities P_T(K|X) (return_proba = True)."""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")
if not self.fitted_ot:
raise ValueError("Expected 'fit_ot' to be called previously.")

log_probs = diag_gmm_log_probs(
X=X,
weights=self.weights_tgt,
means=self.centroids_tgt,
stds=self.stds_tgt
)
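        # Component posterior: log P(K|x) = log(w_K p_K(x)) - log p(x), where
        # log p(x) is the logsumexp of the joint log-probabilities over K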
log_proba_components = (
log_probs - log_probs.logsumexp(dim=0)[None, :])
if return_proba:
return log_proba_components.exp()
return log_proba_components.argmax(dim=0)

def predict_source_components(self, X, return_proba=False):
"""Performs k = argmax P_S(K|X) (return_proba = False)
or computes the probabiliities P_S(K|X) (return_proba = True)."""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")

log_probs = diag_gmm_log_probs(
X=X,
weights=self.weights_src,
means=self.centroids_src,
stds=self.stds_src
)
log_proba_components = (
log_probs - log_probs.logsumexp(dim=0)[None, :])
if return_proba:
return log_proba_components.exp()
return log_proba_components.argmax(dim=0)

def sample_from_source(self, n):
"""Samples from the source domain GMM"""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")
Xsyn = []
Ysyn = []

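        # Ancestral sampling: draw a component k ~ Categorical(w_src), then
        # x ~ N(m_k, diag(v_k)); the sample inherits the component's label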
for _ in range(n):
k = np.random.choice(
np.arange(len(self.weights_src)),
p=(self.weights_src / self.weights_src.sum()).numpy())

_x = self.stds_src[k, :] * np.random.randn(self.n_dim) + \
self.centroids_src[k, :]
_y = torch.nn.functional.one_hot(
self.labels_src[k, :].argmax(),
num_classes=self.n_classes).float()

Xsyn.append(_x)
Ysyn.append(_y)
Xsyn = torch.stack(Xsyn).float()
Ysyn = torch.stack(Ysyn).float()
return Xsyn, Ysyn

def sample_from_target(self, n, use_estimated_labels=True):
"""Samples from the target domain GMM"""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")
if use_estimated_labels and not self.fitted_ot:
raise ValueError("Expected 'fit_ot' to be called previously.")
        if not use_estimated_labels and self.labels_tgt is None:
raise ValueError(
"If not using estimated labels, expects target GMM to"
" be fitted using labels."
)
Xsyn = []
Ysyn = []

for _ in range(n):
k = np.random.choice(
np.arange(len(self.weights_tgt)),
p=(self.weights_tgt / self.weights_tgt.sum()).numpy())

_x = self.stds_tgt[k, :] * np.random.randn(self.n_dim) + \
self.centroids_tgt[k, :]
if use_estimated_labels:
_y = torch.nn.functional.one_hot(
self.estimated_labels_tgt[k, :].argmax(),
num_classes=self.n_classes).float()
else:
_y = torch.nn.functional.one_hot(
self.labels_tgt[k, :].argmax(),
num_classes=self.n_classes).float()

Xsyn.append(_x)
Ysyn.append(_y)
Xsyn = torch.stack(Xsyn).float()
Ysyn = torch.stack(Ysyn).float()
return Xsyn, Ysyn

def predict_source_labels(self, X):
"""Predicts class labels using the source GMM"""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")
proba_components = self.predict_source_components(X, return_proba=True)
cluster_labels = torch.mm(
self.labels_src.T, proba_components).T
return cluster_labels

def predict_target_labels(self, X, use_estimated_labels=True):
"""Predicts class labels using the target GMM"""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")
if not self.fitted_ot:
raise ValueError("Expected 'fit_ot' to be called previously.")
proba_components = self.predict_target_components(X, return_proba=True)
if use_estimated_labels:
cluster_labels = torch.mm(
self.estimated_labels_tgt.T, proba_components).T
else:
cluster_labels = torch.mm(
self.labels_tgt.T, proba_components).T
return cluster_labels

def compute_source_nll(self, X):
"""Computes the Negative Log-Likelihood (NLL) using the
source domain GMM."""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")

log_probs = diag_gmm_log_probs(
X=X,
weights=self.weights_src,
means=self.centroids_src,
stds=self.stds_src
)
return - log_probs.logsumexp(dim=0).mean()

def compute_target_nll(self, X):
"""Computes the Negative Log-Likelihood (NLL) using the
target domain GMM."""
if not self.fitted_gmm:
            raise ValueError("Expected 'fit_gmms' to be called previously.")

log_probs = diag_gmm_log_probs(
X=X,
weights=self.weights_tgt,
means=self.centroids_tgt,
stds=self.stds_tgt
)
return - log_probs.logsumexp(dim=0).mean()

def transport_samples(self, X, Y, numel=None):
"""Computes the weighted map from the source to the target domain."""
if numel is None:
numel = self.ot_plan.shape[0] + self.ot_plan.shape[1] - 1
q = np.quantile(self.ot_plan.flatten(),
1 - numel / self.ot_plan.numel())
ind_s, ind_t = np.where(self.ot_plan > q)

transp_w = []
transp_X = []
transp_y = []

components_s = self.predict_source_components(X)
for i_s, i_t in zip(ind_s, ind_t):
idx = np.where(components_s == i_s)[0]
x = X[idx]
y = Y[idx]

w = self.ot_plan[i_s, i_t]
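
            # Affine map between diagonal Gaussians: T(x) = A * x + b pushes
            # N(m_s, diag(v_s)) onto N(m_t, diag(v_t)) coordinate-wise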
A = self.stds_tgt[i_t] / (self.stds_src[i_s] + 1e-9)
b = self.centroids_tgt[i_t] - self.centroids_src[i_s] * A

transp_w.append(torch.Tensor([w] * len(x)))
transp_X.append(x * A + b)
transp_y.append(y)
transp_w = torch.cat(transp_w, dim=0)
transp_X = torch.cat(transp_X, dim=0)
transp_y = torch.cat(transp_y, dim=0)

return transp_w, transp_X, transp_y

def rand_transport(self, X, Y, numel=None):
"""Computes the rand transport of (Delon and Desolneux, 2020) between
source and target domains."""
proba_components_src = self.predict_source_components(
X, return_proba=True).T
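        # Joint sampling probabilities: P(k1, k2 | x) =
        # (plan[k1, k2] / w_src[k1]) * P_S(k1 | x), which sums to one over
        # all component pairs (k1, k2)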
sampling_probs = torch.zeros([
len(X), len(self.weights_src), len(self.weights_tgt)])
for k1 in range(len(self.weights_src)):
for k2 in range(len(self.weights_tgt)):
sampling_probs[:, k1, k2] = (
(self.ot_plan[k1, k2] / self.weights_src[k1]) *
proba_components_src[:, k1])
sampling_probs = sampling_probs.numpy()

indices = np.arange(len(self.ot_plan.flatten()))
indices_PQ = np.array([
(k1, k2)
for k1 in range(self.ot_plan.shape[0])
for k2 in range(self.ot_plan.shape[1])])
mapped_x = []

for pi, xi in zip(sampling_probs, X):
idx = np.random.choice(indices, p=pi.flatten())
k1, k2 = indices_PQ[idx]

A = self.stds_tgt[k2] / (self.stds_src[k1] + 1e-9)
b = self.centroids_tgt[k2] - self.centroids_src[k1] * A

mapped_x.append(xi * A + b)
mapped_x = torch.stack(mapped_x)
return mapped_x, Y