@DirichletKernelMulticlassClassification with @SetWarp results in "index out of bounds" errors #376

Open · rg936672 opened this issue Aug 20, 2024 · 1 comment
Labels: bug (Something isn't working)

Comments

rg936672 commented Aug 20, 2024

What's the problem?

Decorating a controller with both @SetWarp and @DirichletKernelMulticlassClassification makes controller.fit() fail with an IndexError (see the relevant log output). This seems to be because the warp function is applied to the training targets, which moves the class indices out of the expected range [0, num_classes). The other classification decorators work fine with @SetWarp, which is why I'm raising this as an issue: if all but one of them work, the one that doesn't is most likely at fault.

DirichletMulticlassClassification, BinaryClassification, and CategoricalClassification all seem to work fine with @SetWarp.
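A minimal plain-Python sketch of the suspected cause, with no Vanguard or torch dependency. It assumes (unconfirmed) that `SinhWarpFunction` behaves like a plain `sinh` at its initial parameters, and it mimics the `DiagLinearOperator(torch.ones(self.n_classes))[targets.long()]` line from `_label_tensor` with integer truncation into a list of one-hot rows; `one_hot_rows` and `label_rows` are hypothetical helpers. Under that assumption the labels 0 to 3 become roughly 0, 1.18, 3.63 and 10.02, which truncate to indices 0, 1, 3 and 10, and index 10 is exactly the out-of-bounds index reported in the log output.

```python
import math

NUM_CLASSES = 4


def one_hot_rows(num_classes):
    """Rows of an identity matrix, standing in for DiagLinearOperator(torch.ones(num_classes))."""
    return [[1.0 if i == j else 0.0 for j in range(num_classes)] for i in range(num_classes)]


def label_rows(targets, num_classes):
    """Hypothetical stand-in for `_label_tensor`: truncate each target to an int and use it as a row index."""
    rows = one_hot_rows(num_classes)
    # int() truncates toward zero, like Tensor.long() on a float tensor.
    return [rows[int(t)] for t in targets]


labels = [0.0, 1.0, 2.0, 3.0]
# Assumption: the warp acts like plain sinh on the integer class labels.
warped = [math.sinh(y) for y in labels]
indices = [int(w) for w in warped]
print(indices)  # [0, 1, 3, 10]

# Unwarped labels index the one-hot rows without trouble.
print(label_rows(labels, NUM_CLASSES)[3])  # [0.0, 0.0, 0.0, 1.0]

try:
    label_rows(warped, NUM_CLASSES)
except IndexError as err:
    # Python lists say "list index out of range"; torch reports the offending index (10) instead.
    print("IndexError:", err)
```

Integer class labels only survive a warp when it maps every integer in [0, num_classes) back onto that range, which a monotone unbounded function like sinh does not do.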

How can we reproduce the issue?

```python
from gpytorch.kernels.keops import RBFKernel
from gpytorch.means import ZeroMean

from tests.cases import get_default_rng
from vanguard.classification.kernel import DirichletKernelMulticlassClassification
from vanguard.classification.likelihoods import DirichletKernelClassifierLikelihood, GenericExactMarginalLogLikelihood
from vanguard.datasets.classification import MulticlassGaussianClassificationDataset
from vanguard.vanilla import GaussianGPController
from vanguard.warps import SetWarp, warpfunctions


def test():
    @SetWarp(warp_function=warpfunctions.SinhWarpFunction())
    @DirichletKernelMulticlassClassification(num_classes=4)
    class MyController(GaussianGPController):
        pass

    dataset = MulticlassGaussianClassificationDataset(num_train_points=20, num_test_points=8, num_classes=4, rng=get_default_rng())

    controller = MyController(
        train_x=dataset.train_x,
        train_y=dataset.train_y,
        y_std=0,
        mean_class=ZeroMean,
        kernel_class=RBFKernel,
        likelihood_class=DirichletKernelClassifierLikelihood,
        marginal_log_likelihood_class=GenericExactMarginalLogLikelihood,
    )

    controller.fit()  # raises IndexError (see log output)
```

Python version

3.12

Package version

2.1.0

Operating system

No response

Other packages

No response

Relevant log output

```text
FAILED                                             [100%]
tests\test_example.py:11 (test)
test_example.py:30: in test
    controller.fit()
..\vanguard\base\gpcontroller.py:134: in fit
    loss = self._sgd_round(n_iters=n_sgd_iters, gradient_every=gradient_every)
..\vanguard\warps\decorator.py:174: in _sgd_round
    loss = super()._sgd_round(n_iters=n_iters, gradient_every=gradient_every)
..\vanguard\base\basecontroller.py:365: in _sgd_round
    loss = self._single_optimisation_step(train_x, train_y, retain_graph=iter_num < n_iters - 1)
..\vanguard\base\basecontroller.py:406: in _single_optimisation_step
    loss = self._loss(x, y)
..\vanguard\warps\decorator.py:194: in _loss
    nmll = super()._loss(train_x, warped_train_y)
..\vanguard\base\basecontroller.py:423: in _loss
    output = self._gp_forward(train_x)
..\vanguard\base\basecontroller.py:451: in _gp_forward
    output = self._gp(x)
..\vanguard\base\basecontroller.py:597: in __call__
    result = super().__call__(*args, **kwargs)
..\vanguard\classification\models.py:139: in __call__
    return DummyKernelDistribution(self._label_tensor(self.train_targets), kernel_matrix)
..\vanguard\classification\models.py:104: in _label_tensor
    return DiagLinearOperator(torch.ones(self.n_classes))[targets.long()]
..\..\..\.virtualenvs\Vanguard-ShqWj01e\Lib\site-packages\linear_operator\operators\_linear_operator.py:2863: in __getitem__
    res = self._getitem(row_index, col_index, *batch_indices)
..\..\..\.virtualenvs\Vanguard-ShqWj01e\Lib\site-packages\linear_operator\operators\_linear_operator.py:316: in _getitem
    return res._getitem(row_index, col_index, *batch_indices)
..\..\..\.virtualenvs\Vanguard-ShqWj01e\Lib\site-packages\linear_operator\operators\interpolated_linear_operator.py:175: in _getitem
    left_interp_indices = left_interp_indices[(*batch_indices, row_index, _noop_index)]
E   IndexError: index 10 is out of bounds for dimension 0 with size 4
```
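The traceback shows the failure surfacing in `_label_tensor`, which indexes a `num_classes`-sized diagonal operator with `targets.long()` and never checks that the (here warped) targets are valid class indices. One possible mitigation is to validate the labels before indexing so that the combination fails with an explicit message rather than an opaque `IndexError` from deep inside `linear_operator`. The plain-Python helper below is a hypothetical sketch of such a check; `check_class_indices` is not part of Vanguard, and the sinh-shaped warp in the usage lines is an assumption.

```python
import math


def check_class_indices(targets, num_classes):
    """Hypothetical guard: return targets as ints if they are valid class indices.

    Raises ValueError naming the offending values if any target is non-finite,
    non-integral, or outside [0, num_classes).
    """
    bad = [
        t for t in targets
        # isfinite is checked first so int() is never called on NaN or inf.
        if not math.isfinite(t) or t != int(t) or not 0 <= t < num_classes
    ]
    if bad:
        raise ValueError(
            f"expected integer class labels in [0, {num_classes}), got {bad}; "
            "were the targets warped or otherwise transformed before indexing?"
        )
    return [int(t) for t in targets]


print(check_class_indices([0.0, 1.0, 2.0, 3.0], 4))  # [0, 1, 2, 3]

try:
    # Labels 0..3 after an assumed sinh-shaped warp.
    check_class_indices([math.sinh(y) for y in range(4)], 4)
except ValueError as err:
    print(err)
```

A check like this only improves the error message; the underlying fix would still need the Dirichlet kernel classifier to see unwarped labels, or the warp to be refused for this decorator.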
rg936672 added the bug (Something isn't working) and new (Something yet to be discussed by development team) labels on Aug 20, 2024.
tp832944 removed the new label on Aug 28, 2024.
tp832944 commented

Once fixed, disable the corresponding warning.
