diff --git a/src/beignet/polynomial/__div.py b/src/beignet/polynomial/__div.py index ea8e171718..5d09158328 100644 --- a/src/beignet/polynomial/__div.py +++ b/src/beignet/polynomial/__div.py @@ -3,8 +3,6 @@ import torch from torch import Tensor -from .__nonzero import _nonzero - def _div( func: Callable, @@ -31,7 +29,25 @@ def _div( def f(x: Tensor) -> Tensor: indicies = torch.flip(x, [0]) - indicies = _nonzero(indicies, size=1) + indicies = torch.nonzero(indicies, as_tuple=False) + + if indicies.shape[0] > 1: + indicies = indicies[:1] + + if indicies.shape[0] < 1: + indicies = torch.concatenate( + [ + indicies, + torch.full( + [ + 1 - indicies.shape[0], + indicies.shape[1], + ], + 0, + ), + ], + 0, + ) return x.shape[0] - 1 - indicies[0][0] diff --git a/src/beignet/polynomial/__init__.py b/src/beignet/polynomial/__init__.py index 13a36a9ce2..27970e0ef5 100644 --- a/src/beignet/polynomial/__init__.py +++ b/src/beignet/polynomial/__init__.py @@ -14,7 +14,6 @@ from .__get_domain import _get_domain from .__map_domain import _map_domain from .__map_parameters import _map_parameters -from .__nonzero import _nonzero from .__normed_hermite_e_n import _normed_hermite_e_n from .__normed_hermite_n import _normed_hermite_n from .__nth_slice import _nth_slice @@ -289,22 +288,17 @@ __all__ = [ "_c_series_to_z_series", "_div", - "_evaluate", "_fit", - "_flattened_vandermonde", "_from_roots", "_get_domain", "_map_domain", "_map_parameters", - "_nonzero", "_normed_hermite_e_n", "_normed_hermite_n", "_nth_slice", "_pad_along_axis", "_pow", - "_trim_coefficients", "_trim_sequence", - "_vandermonde", "_z_series_mul", "_z_series_to_c_series", "chebyshev_polynomial_to_polynomial",