Skip to content

Commit

Permalink
ENH refactor NMF and add CD solver
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDLT committed Sep 21, 2015
1 parent b9284a7 commit ceeef70
Show file tree
Hide file tree
Showing 11 changed files with 20,616 additions and 366 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
/sklearn/cluster/_k_means.c -diff
/sklearn/datasets/_svmlight_format.c -diff
/sklearn/decomposition/_online_lda.c -diff
/sklearn/decomposition/cdnmf_fast.c -diff
/sklearn/ensemble/_gradient_boosting.c -diff
/sklearn/feature_extraction/_hashing.c -diff
/sklearn/linear_model/cd_fast.c -diff
Expand Down
24 changes: 9 additions & 15 deletions benchmarks/bench_plot_nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from sklearn.externals.six.moves import xrange


def alt_nnmf(V, r, max_iter=1000, tol=1e-3, R=None):
def alt_nnmf(V, r, max_iter=1000, tol=1e-3, init='random'):
'''
A, S = nnmf(X, r, tol=1e-3, R=None)
Expand All @@ -33,8 +33,8 @@ def alt_nnmf(V, r, max_iter=1000, tol=1e-3, R=None):
tol : double
tolerance threshold for early exit (when the update factor is within
tol of 1., the function exits)
R : integer, optional
random seed
init : string
Method used to initialize the procedure.
Returns
-------
Expand All @@ -52,12 +52,7 @@ def alt_nnmf(V, r, max_iter=1000, tol=1e-3, R=None):
# Nomenclature in the function follows Lee & Seung
eps = 1e-5
n, m = V.shape
if R == "svd":
W, H = _initialize_nmf(V, r)
elif R is None:
R = np.random.mtrand._rand
W = np.abs(R.standard_normal((n, r)))
H = np.abs(R.standard_normal((r, m)))
W, H = _initialize_nmf(V, r, init, random_state=0)

for i in xrange(max_iter):
updateH = np.dot(W.T, V) / (np.dot(np.dot(W.T, W), H) + eps)
Expand All @@ -78,17 +73,15 @@ def report(error, time):


def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
it = 0
timeset = defaultdict(lambda: [])
err = defaultdict(lambda: [])

max_it = len(samples_range) * len(features_range)
for n_samples in samples_range:
for n_features in features_range:
print("%2d samples, %2d features" % (n_samples, n_features))
print('=======================')
X = np.abs(make_low_rank_matrix(n_samples, n_features,
effective_rank=rank, tail_strength=0.2))
effective_rank=rank, tail_strength=0.2))

gc.collect()
print("benchmarking nndsvd-nmf: ")
Expand Down Expand Up @@ -122,7 +115,7 @@ def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
gc.collect()
print("benchmarking random-nmf")
tstart = time()
m = NMF(n_components=30, init=None, max_iter=1000,
m = NMF(n_components=30, init='random', max_iter=1000,
tol=tolerance).fit(X)
tend = time() - tstart
timeset['random-nmf'].append(tend)
Expand All @@ -132,7 +125,7 @@ def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
gc.collect()
print("benchmarking alt-random-nmf")
tstart = time()
W, H = alt_nnmf(X, r=30, R=None, tol=tolerance)
W, H = alt_nnmf(X, r=30, init='random', tol=tolerance)
tend = time() - tstart
timeset['alt-random-nmf'].append(tend)
err['alt-random-nmf'].append(np.linalg.norm(X - np.dot(W, H)))
Expand All @@ -151,7 +144,8 @@ def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
timeset, err = benchmark(samples_range, features_range)

for i, results in enumerate((timeset, err)):
fig = plt.figure('scikit-learn Non-Negative Matrix Factorization benchmark results')
fig = plt.figure('scikit-learn Non-Negative Matrix Factorization'
'benchmark results')
ax = fig.gca(projection='3d')
for c, (label, timings) in zip('rbgcm', sorted(results.iteritems())):
X, Y = np.meshgrid(samples_range, features_range)
Expand Down
48 changes: 37 additions & 11 deletions doc/modules/decomposition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -656,12 +656,12 @@ into two matrices :math:`W` and :math:`H` of non-negative elements,
by optimizing for the squared Frobenius norm:

.. math::
\arg\min_{W,H} ||X - WH||^2 = \sum_{i,j} X_{ij} - {WH}_{ij}
\arg\min_{W,H} \frac{1}{2} ||X - WH||_{Fro}^2 = \frac{1}{2} \sum_{i,j} (X_{ij} - {WH}_{ij})^2
This norm is an obvious extension of the Euclidean norm to matrices.
(Other optimization objectives have been suggested in the NMF literature,
in particular Kullback-Leibler divergence,
but these are not currently implemented.)
This norm is an obvious extension of the Euclidean norm to matrices. (Other
optimization objectives have been suggested in the NMF literature, in
particular Kullback-Leibler divergence, but these are not currently
implemented.)

Unlike :class:`PCA`, the representation of a vector is obtained in an additive
fashion, by superimposing the components, without subtracting. Such additive
Expand Down Expand Up @@ -695,13 +695,34 @@ the mean of all elements of the data), and NNDSVDar (in which the zeros are set
to random perturbations less than the mean of the data divided by 100) are
recommended in the dense case.

:class:`NMF` can also be initialized with random non-negative matrices, by
passing an integer seed or a ``RandomState`` to :attr:`init`.
:class:`NMF` can also be initialized with correctly scaled random non-negative
matrices by setting :attr:`init="random"`. An integer seed or a
``RandomState`` can also be passed to :attr:`random_state` to control
reproducibility.

In :class:`NMF`, sparseness can be enforced by setting the attribute
:attr:`sparseness` to ``"data"`` or ``"components"``. Sparse components lead to
localized features, and sparse data leads to a more efficient representation of
the data.
In :class:`NMF`, L1 and L2 priors can be added to the loss function in order
to regularize the model. The L2 prior uses the Frobenius norm, while the L1
prior uses an elementwise L1 norm. As in :class:`ElasticNet`, we control the
combination of L1 and L2 with the :attr:`l1_ratio` (:math:`\rho`) parameter,
and the intensity of the regularization with the :attr:`alpha`
(:math:`\alpha`) parameter. Then the priors terms are:

.. math::
\alpha \rho ||W||_1 + \alpha \rho ||H||_1
+ \frac{\alpha(1-\rho)}{2} ||W||_{Fro} ^ 2
+ \frac{\alpha(1-\rho)}{2} ||H||_{Fro} ^ 2
and the regularized objective function is:

.. math::
\frac{1}{2}||X - WH||_{Fro}^2
+ \alpha \rho ||W||_1 + \alpha \rho ||H||_1
+ \frac{\alpha(1-\rho)}{2} ||W||_{Fro} ^ 2
+ \frac{\alpha(1-\rho)}{2} ||H||_{Fro} ^ 2
:class:`NMF` regularizes both W and H. The public function
:func:`non_negative_factorization` allows a finer control through the
:attr:`regularization` attribute, and may regularize only W, only H, or both.

.. topic:: Examples:

Expand All @@ -727,6 +748,11 @@ the data.
<http://scgroup.hpclab.ceid.upatras.gr/faculty/stratis/Papers/HPCLAB020107.pdf>`_
C. Boutsidis, E. Gallopoulos, 2008

* `"Fast local algorithms for large scale nonnegative matrix and tensor
factorizations."
<http://www.bsp.brain.riken.jp/publications/2009/Cichocki-Phan-IEICE_col.pdf>`_
A. Cichocki, P. Anh-Huy, 2009


.. _LatentDirichletAllocation:

Expand Down
11 changes: 10 additions & 1 deletion doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ New features
datasets. By `Danny Sullivan`_ and `Tom Dupre la Tour`_.
(`#4738 <https://github.com/scikit-learn/scikit-learn/pull/4738>`_)

- The new solver ``cd`` implements a Coordinate Descent in
:class:`decomposition.NMF`. Previous solver based on Projected Gradient is
still available setting new parameter ``solver`` to ``pg``, but is
deprecated and will be removed in 0.19, along with
:class:`decompositionProjectedGradientNMF` and parameters``sparseness``,
``eta``, ``beta`` and ``nls_max_iter``. New parameters ``alpha`` and
``l1_ratio`` control L1 and L2 regularizations, and ``shuffle`` adds a
shuffling step in ``cd`` solver.
By `Tom Dupre la Tour`_ and `Mathieu Blondel`_.

Enhancements
............
- :class:`manifold.TSNE` now supports approximate optimization via the
Expand Down Expand Up @@ -192,7 +202,6 @@ Enhancements
- Added ``sample_weight`` support to :class:`linear_model.LogisticRegression` for
the ``lbfgs``, ``newton-cg``, and ``sag`` solvers. By `Valentin Stolbunov`_.


Bug fixes
.........

Expand Down
3 changes: 1 addition & 2 deletions examples/decomposition/plot_faces_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def plot_gallery(title, images, n_col=n_col, n_row=n_row):
True),

('Non-negative components - NMF',
decomposition.NMF(n_components=n_components, init='nndsvda', beta=5.0,
tol=5e-3, sparseness='components'),
decomposition.NMF(n_components=n_components, init='nndsvda', tol=5e-3),
False),

('Independent components - FastICA',
Expand Down
Loading

0 comments on commit ceeef70

Please sign in to comment.