From 52eefce8692b3c864f73fdd4528ba57ee57593e8 Mon Sep 17 00:00:00 2001
From: star-dust-ctrl <64853102+star-dust-ctrl@users.noreply.github.com>
Date: Sun, 26 May 2024 00:18:51 +0800
Subject: [PATCH] Add PDASSolver

---
 docs/source/autoapi/solver.rst |   3 +-
 pytest/test_workflow.py        |   3 +
 skscope/__init__.py            |   2 +
 skscope/base_solver.py         |  15 ++-
 skscope/solver.py              | 157 ++++++++++++++++++++++++++++++++-
 5 files changed, 174 insertions(+), 6 deletions(-)

diff --git a/docs/source/autoapi/solver.rst b/docs/source/autoapi/solver.rst
index 87442e2..1813439 100644
--- a/docs/source/autoapi/solver.rst
+++ b/docs/source/autoapi/solver.rst
@@ -15,7 +15,8 @@ Classes
     skscope.solver.FobaSolver
     skscope.solver.ForwardSolver
     skscope.solver.OMPSolver
+    skscope.solver.PDASSolver
 
 .. autoapimodule:: skscope.solver
-    :members: ScopeSolver, HTPSolver,IHTSolver,GraspSolver, FobaSolver, ForwardSolver, OMPSolver
+    :members: ScopeSolver, HTPSolver, IHTSolver, GraspSolver, FobaSolver, ForwardSolver, OMPSolver, PDASSolver
 
diff --git a/pytest/test_workflow.py b/pytest/test_workflow.py
index 97177ff..e75a4cf 100644
--- a/pytest/test_workflow.py
+++ b/pytest/test_workflow.py
@@ -13,6 +13,7 @@
     FobaSolver,
     ForwardSolver,
     OMPSolver,
+    PDASSolver,
 )
 import pytest
 from create_test_model import CreateTestModel
@@ -35,6 +36,7 @@
     foba_gdt_solver,
     ForwardSolver,
     OMPSolver,
+    PDASSolver,
 )
 solvers_ids = (
     "scope",
@@ -46,6 +48,7 @@
     "FOBA_gdt",
     "Forward",
     "OMP",
+    "PDAS",
 )
 
 
diff --git a/skscope/__init__.py b/skscope/__init__.py
index 21cc458..91aa7c8 100644
--- a/skscope/__init__.py
+++ b/skscope/__init__.py
@@ -16,6 +16,7 @@
     FobaSolver,
     ForwardSolver,
     OMPSolver,
+    PDASSolver,
 )
 from .base_solver import BaseSolver
 from .numeric_solver import convex_solver_LBFGS
@@ -34,4 +35,5 @@
     "PortfolioSelection",
     "NonlinearSelection",
     "RobustRegression",
+    "PDASSolver",
 ]
diff --git a/skscope/base_solver.py b/skscope/base_solver.py
index fcdcbbd..8350af1 100644
--- a/skscope/base_solver.py
+++ b/skscope/base_solver.py
@@ -283,7 +283,7 @@ def solve(
             for i in range(len(layers) - 1):
                 assert layers[i].out_features == layers[i + 1].in_features
             assert layers[-1].out_features == self.dimensionality
-            loss_, grad_ = BaseSolver._set_objective(objective, gradient, jit, layers)
+            loss_, grad_, hess_ = BaseSolver._set_objective(objective, gradient, jit, layers)
             p = layers[0].in_features
             for layer in layers[::-1]:
                 sparsity = layer.transform_sparsity(sparsity)
@@ -291,7 +291,9 @@
             preselect = layer.transform_preselect(preselect)
         else:
             p = self.dimensionality
-            loss_, grad_ = BaseSolver._set_objective(objective, gradient, jit)
+            loss_, grad_, hess_ = BaseSolver._set_objective(objective, gradient, jit)
+
+        self.hess_ = hess_  # expose the Hessian to solvers that need second-order information
 
         def loss_fn(params, data):
             value = loss_(params, data)
@@ -461,7 +463,14 @@ def grad_(params, data):
         if jit:
             grad_ = jax.jit(grad_)
 
-        return loss_, grad_
+        def hess_(params, data):
+            # full Hessian via jax.hessian; PDASSolver uses only its diagonal
+            return jax.hessian(loss_)(params, data)
+
+        if jit:
+            hess_ = jax.jit(hess_)
+
+        return loss_, grad_, hess_
 
     def _solve(
         self,
diff --git a/skscope/solver.py b/skscope/solver.py
index a46a1eb..ad6c101 100644
--- a/skscope/solver.py
+++ b/skscope/solver.py
@@ -627,7 +627,7 @@ def __set_objective_cpp(self, objective, gradient, hessian):
         return objective
 
     def __set_objective_py(self, objective, gradient, hessian, jit, layers=[]):
-        loss_, grad_ = BaseSolver._set_objective(objective, gradient, jit, layers)
+        loss_, grad_, hess_ = BaseSolver._set_objective(objective, gradient, jit, layers)
 
         # hess
         if hessian is None:
@@ -798,7 +798,7 @@ def _solve(
             group,
         )
         # init
-        params = init_params
+        params = init_params
         best_suppport_group_tuple = None
         best_loss = np.inf
         results = {}  # key: tuple of ordered support set, value: params
@@ -1641,3 +1641,156 @@ def __init__(
             random_state=random_state,
         )
         self.use_gradient = True
+
+
+
+class PDASSolver(BaseSolver):
+    r"""
+    Solve the best subset selection problem with subset size :math:`s` by the primal-dual active set (PDAS) algorithm.
+    Specifically, ``PDASSolver`` aims to tackle this problem: :math:`\min_{\beta \in R^p} l(\beta) \text{ s.t. } ||\beta||_0 = s`, where :math:`l(\beta)` is a convex objective function and :math:`s` is the sparsity level. Each element of :math:`\beta` can be seen as a variable, and the nonzero elements of :math:`\beta` are the selected variables.
+
+    Parameters
+    ----------
+    dimensionality : int
+        Dimension of the optimization problem, which is also the total number of variables that will be considered to select or not, denoted as :math:`p`.
+    sparsity : int or array of int, optional
+        The sparsity level, which is the number of nonzero elements of the optimal solution, denoted as :math:`s`.
+        Default is ``range(int(p/log(log(p))/log(p)))``.
+    sample_size : int, default=1
+        Sample size, denoted as :math:`n`.
+    preselect : array of int, default=[]
+        An array containing the indices of variables that must be selected.
+    numeric_solver : callable, optional
+        A solver for the convex optimization problem. ``PDASSolver`` will call this function to solve the convex optimization problem in each iteration.
+        It should have the same interface as ``skscope.convex_solver_LBFGS``.
+    max_iter : int, default=100
+        Maximum number of iterations allowed for convergence.
+    group : array of shape (dimensionality,), default=range(dimensionality)
+        The group index for each variable; it must be an incremental integer array starting from 0 without gaps.
+        The variables in the same group must be adjacent, and they will be selected together or not at all.
+        Here are some wrong examples: ``[0,2,1,2]`` (not incremental), ``[1,2,3,3]`` (not starting from 0), ``[0,2,2,3]`` (there is a gap).
+        It's worth mentioning that "a variable" here in fact means "a group of variables". For example, ``sparsity=[3]`` means that 3 groups of variables will be selected rather than 3 single variables,
+        and ``preselect=[0,3]`` means the 0th and 3rd groups must be selected.
+    cv : int, default=1
+        The number of folds used in cross-validation.
+        - If ``cv`` = 1, the sparsity level will be chosen by the information criterion.
+        - If ``cv`` > 1, the sparsity level will be chosen by the cross-validation method.
+    split_method : callable, optional
+        A function to get the part of data used in each fold of cross-validation.
+        Its interface should be ``(data, index) -> part_data`` where ``index`` is an array of int.
+    cv_fold_id : array of shape (sample_size,), optional
+        An array that indicates the folds in cross-validation; samples in the same fold should be assigned the same number.
+        The number of distinct elements should be equal to ``cv``.
+        Used only when ``cv`` > 1.
+    random_state : int, optional
+        The random seed used for cross-validation.
+
+    Attributes
+    ----------
+    params : array of shape (dimensionality,)
+        The sparse optimal solution.
+    objective_value : float
+        The value of the objective function on the solution.
+    support_set : array of int
+        The indices of selected variables, sorted in ascending order.
+
+    References
+    ----------
+    Wen, C., Zhang, A., Quan, S., & Wang, X. (2020).
+    BeSS: An R Package for Best Subset Selection in Linear, Logistic and Cox Proportional Hazards Models. Journal of Statistical Software, 94(4), 1-24.
+
+    """
+
+    def __init__(
+        self,
+        dimensionality,
+        sparsity=None,
+        sample_size=1,
+        *,
+        preselect=[],
+        numeric_solver=convex_solver_LBFGS,
+        max_iter=100,
+        group=None,
+        cv=1,
+        cv_fold_id=None,
+        split_method=None,
+        random_state=None,
+    ):
+        super().__init__(
+            dimensionality=dimensionality,
+            sparsity=sparsity,
+            sample_size=sample_size,
+            preselect=preselect,
+            numeric_solver=numeric_solver,
+            max_iter=max_iter,
+            group=group,
+            cv=cv,
+            cv_fold_id=cv_fold_id,
+            split_method=split_method,
+            random_state=random_state,
+        )
+
+    def _solve(
+        self,
+        sparsity,
+        loss_fn,
+        value_and_grad,
+        init_support_set,
+        init_params,
+        data,
+        preselect,
+        group,
+    ):
+        if sparsity <= preselect.size:
+            return super()._solve(
+                sparsity,
+                loss_fn,
+                value_and_grad,
+                init_support_set,
+                init_params,
+                data,
+                preselect,
+                group,
+            )
+
+        support_set_group = np.union1d(preselect, init_support_set)
+        group_num = len(np.unique(group))
+        group_indices = [np.where(group == i)[0] for i in range(group_num)]
+        if support_set_group.size < sparsity:
+            # pad the initial active set with randomly chosen groups
+            diff_num = sparsity - support_set_group.size
+            all_support = np.arange(group_num)
+            diff_support = np.setdiff1d(all_support, support_set_group)
+            rng = np.random.default_rng(seed=self.random_state)
+            support_set_group = np.union1d(
+                rng.choice(diff_support, diff_num, replace=False), support_set_group
+            )
+        support = np.concatenate([group_indices[i] for i in support_set_group])
+
+        for _ in range(self.max_iter):
+            # primal step: refit the parameters on the current active set
+            params = np.zeros_like(init_params)
+            loss, params = self._numeric_solver(
+                loss_fn, value_and_grad, params, support, data
+            )
+            # dual step: gradient and Hessian diagonal on the inactive set
+            g = value_and_grad(params, data)[1]
+            h = np.diag(np.array(self.hess_(params, data)))
+            g[support] = 0
+            gamma = -g / h
+            # sacrifice of each variable: h * (beta + gamma)^2 / 2
+            delta = 0.5 * h * np.square(params + gamma)
+            score = np.array(
+                [np.sum(delta[group_indices[i]]) for i in range(group_num)]
+            )
+            score[preselect] = np.inf  # preselected groups are always kept
+            support_set_group_new = np.sort(np.argpartition(score, -sparsity)[-sparsity:])
+            support_new = np.concatenate(
+                [group_indices[i] for i in support_set_group_new]
+            )
+            # stop when the active set no longer changes
+            if np.array_equal(support_set_group, support_set_group_new):
+                return params, support_new
+            support_set_group = support_set_group_new
+            support = support_new
+        return params, support_new
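
Reviewer note: the third value now returned by ``BaseSolver._set_objective`` is a Hessian
callable built with ``jax.hessian``, cached on the solver as ``self.hess_``. A minimal
standalone sketch of what that closure computes, assuming a least-squares objective (the
data and names below are illustrative, not part of the patch):

    import jax
    import jax.numpy as jnp
    import numpy as np

    X = np.random.randn(20, 5)
    y = np.random.randn(20)

    def loss_(params, data=None):
        # a convex least-squares objective, as a user would pass to solve()
        return jnp.sum(jnp.square(y - X @ params))

    # mirrors the hess_ closure added in BaseSolver._set_objective
    hess_ = jax.jit(jax.hessian(loss_))

    H = hess_(jnp.zeros(5), None)
    # for least squares the Hessian is the constant matrix 2 * X^T X;
    # PDASSolver reads off only np.diag(H)
    assert np.allclose(H, 2 * X.T @ X, atol=1e-4)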
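
Reviewer note: a usage sketch for the new solver, assuming the standard ``skscope``
``solve()`` interface; the synthetic dataset, seed, and expected recovery below are
illustrative only:

    import jax.numpy as jnp
    import numpy as np
    from skscope import PDASSolver

    # synthetic regression data: p = 10 variables, 3 of them truly active
    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 10))
    beta_true = np.zeros(10)
    beta_true[:3] = 1.0
    y = X @ beta_true + 0.1 * rng.standard_normal(100)

    def objective(params):
        # convex objective written with jax.numpy so that gradients and
        # Hessians can be traced automatically
        return jnp.sum(jnp.square(y - X @ params))

    solver = PDASSolver(dimensionality=10, sparsity=3, sample_size=100)
    estimated_params = solver.solve(objective, jit=True)
    print(solver.support_set)  # expected to recover the three active indices

Because PDAS re-selects the whole active set from the Hessian-weighted sacrifices each
iteration, it can swap several variables at once, unlike purely greedy forward methods.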