Distributed + DirichletMulticlassClassification gets cryptic shape mismatch errors for some datasets #322

Open
rg936672 opened this issue Jul 16, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@rg936672
Contributor

What's the problem?

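Applying `Distributed` (3 experts) on top of `DirichletMulticlassClassification` (4 classes) raises a `RuntimeError` about mismatched shapes when the controller is fitted on a synthetic dataset with fewer than 40 training points. The error is raised deep inside GPyTorch and does not say which shapes disagree.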

How can we reproduce the issue?

import numpy as np
from gpytorch.likelihoods import DirichletClassificationLikelihood

from vanguard.classification import DirichletMulticlassClassification
from vanguard.datasets.classification import MulticlassGaussianClassificationDataset
from vanguard.distribute import Distributed
from vanguard.kernels import ScaledRBFKernel
from vanguard.vanilla import GaussianGPController

@Distributed(n_experts=3, rng=np.random.default_rng(1234))
@DirichletMulticlassClassification(num_classes=4)
class MyController(GaussianGPController):
    pass


# seems to fail for any `num_train_points` less than 40
dataset = MulticlassGaussianClassificationDataset(
    num_train_points=20, num_test_points=4, num_classes=4, rng=np.random.default_rng(1234)
)

controller = MyController(
    train_x=dataset.train_x,
    train_y=dataset.train_y,
    y_std=dataset.train_y_std,
    kernel_class=ScaledRBFKernel,
    kernel_kwargs={"batch_shape": (4,)},
    likelihood_class=DirichletClassificationLikelihood,
    likelihood_kwargs={"learn_additional_noise": True},
    rng=np.random.default_rng(1234),
)

controller.fit(1)

Python version

3.12

Package version

2.1.0

Operating system

Windows Server 2022 Datacenter 21H2

Other packages

No response

Relevant log output

Traceback (most recent call last):
  File "example.py", line 31, in <module>
    controller.fit(1)
  File "<PROJECT DIR>\vanguard\distribute\decorator.py", line 162, in fit
    loss = super().fit(n_sgd_iters, gradient_every=gradient_every)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PROJECT DIR>\vanguard\base\gpcontroller.py", line 134, in fit
    loss = self._sgd_round(n_iters=n_sgd_iters, gradient_every=gradient_every)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PROJECT DIR>\vanguard\base\basecontroller.py", line 382, in _sgd_round
    raise err
  File "<PROJECT DIR>\vanguard\base\basecontroller.py", line 365, in _sgd_round
    loss = self._single_optimisation_step(train_x, train_y, retain_graph=iter_num < n_iters - 1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PROJECT DIR>\vanguard\base\basecontroller.py", line 406, in _single_optimisation_step
    loss = self._loss(x, y)
           ^^^^^^^^^^^^^^^^
  File "<PROJECT DIR>\vanguard\classification\dirichlet.py", line 211, in _loss
    return super()._loss(train_x, train_y).sum()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PROJECT DIR>\vanguard\base\basecontroller.py", line 427, in _loss
    return -self._mll(output, train_y.squeeze(dim=-1))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PROJECT DIR>\vanguard\base\basecontroller.py", line 597, in __call__
    result = super().__call__(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\mlls\exact_marginal_log_likelihood.py", line 66, in forward
    output = self.likelihood(function_dist, *params, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\likelihoods\gaussian_likelihood.py", line 469, in __call__
    return super().__call__(input, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\likelihoods\likelihood.py", line 76, in __call__
    return self.marginal(input, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\likelihoods\gaussian_likelihood.py", line 461, in marginal
    return super().marginal(function_dist, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\likelihoods\gaussian_likelihood.py", line 359, in marginal
    return super().marginal(function_dist, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\likelihoods\gaussian_likelihood.py", line 116, in marginal
    noise_covar = self._shaped_noise_covar(mean.shape, *params, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\gpytorch\likelihoods\gaussian_likelihood.py", line 345, in _shaped_noise_covar
    res = res + self.second_noise_covar(*params, shape=shape, **kwargs)
          ~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  File "<VIRTUAL ENV>\Lib\site-packages\linear_operator\operators\diag_linear_operator.py", line 35, in __add__
    return self.add_diagonal(other._diag)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\linear_operator\operators\diag_linear_operator.py", line 130, in add_diagonal
    shape = torch.broadcast_shapes(self._diag.shape, diag.shape)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<VIRTUAL ENV>\Lib\site-packages\torch\functional.py", line 132, in broadcast_shapes
    raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
RuntimeError: Shape mismatch: objects cannot be broadcast to a single shape
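
For anyone reading the traceback: the final frame is `torch.broadcast_shapes`, called from `add_diagonal` with the shapes of two noise diagonals. The rule it enforces can be sketched in plain Python as below. This is only an illustration of standard broadcasting semantics: `broadcast_shapes_sketch` is a hypothetical helper, not torch's implementation, and the shapes are made up, since the log does not show which dimensions actually disagree in the failing run.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Return the broadcast shape, or raise RuntimeError if incompatible.

    Hypothetical helper illustrating the standard broadcasting rule;
    it is not torch's implementation.
    """
    result = []
    # Align shapes on their trailing dimension; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            # Two different sizes, neither of which is 1: cannot broadcast.
            raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Illustrative shapes only (assumed, not taken from the failing run).
print(broadcast_shapes_sketch((4, 7), (7,)))  # compatible shapes
try:
    broadcast_shapes_sketch((4, 7), (4, 6))  # trailing dims differ, neither is 1
except RuntimeError as err:
    print(err)
```

If the two diagonals differ in their last dimension, as in the second call, the same error message would be produced; whether that is what happens here is an assumption that would need checking by printing the two shapes before the `add_diagonal` call.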

rg936672 added the bug (Something isn't working) and new (Something yet to be discussed by development team) labels on Jul 16, 2024
bk958178 removed the new (Something yet to be discussed by development team) label on Jul 31, 2024