diff --git a/inference/interact/fbrs/inference/predictors/brs_functors.py b/inference/interact/fbrs/inference/predictors/brs_functors.py index 92a5d99..222b61c 100644 --- a/inference/interact/fbrs/inference/predictors/brs_functors.py +++ b/inference/interact/fbrs/inference/predictors/brs_functors.py @@ -72,7 +72,7 @@ def __call__(self, x): self._last_mask = current_mask loss.backward() - f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float) + f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.cfloat) return [f_val, f_grad]