Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lzj1769 committed Aug 20, 2021
1 parent b46a846 commit 9de86fc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
17 changes: 5 additions & 12 deletions scopen/MF.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,21 +198,18 @@ def _compute_regularization(alpha, l1_ratio, regularization):
return l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H


def _loss(X, W, H, square_root=False):
def _loss(X, W, H):
"""Compute the Frobenius *squared* norm of X - dot(W, H).
Parameters
----------
X : float or array-like, shape (n_samples, n_features)
Numpy masked arrays or arrays containing NaN are accepted.
W : float or dense array-like, shape (n_samples, n_components)
H : float or dense array-like, shape (n_components, n_features)
square_root : boolean, default False
If True, return np.sqrt(2 * res)
For beta == 2, it corresponds to the Frobenius norm.
Returns
-------
res : float
Beta divergence of X and np.dot(X, H)
Frobenius norm of X and np.dot(X, H)
"""
# The method can be called with scalars
if not sp.issparse(X):
Expand All @@ -232,10 +229,7 @@ def _loss(X, W, H, square_root=False):

assert not np.isnan(res)
assert res >= 0
if square_root:
return np.sqrt(res * 2)
else:
return res
return np.sqrt(res * 2)


def _initialize_nmf(X, n_components, init=None, eps=1e-6,
Expand Down Expand Up @@ -522,7 +516,7 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
f"violation: {_violation: .8f}")

elif verbose == 2:
err = _loss(X, W, Ht.T, square_root=True)
err = _loss(X, W, Ht.T)
print(f"{datetime.now().strftime('%m/%d/%Y %H:%M:%S')}, iteration: {n_iter: }, "
f"violation: {_violation: .8f}, error: {err: .8f}")

Expand Down Expand Up @@ -688,7 +682,6 @@ def non_negative_factorization(X, W=None, H=None, n_components=None,
W, H, n_iter = _fit_coordinate_descent(X, W, H, tol, max_iter,
l1_reg_W, l1_reg_H,
l2_reg_W, l2_reg_H,
update_H=True,
verbose=verbose,
shuffle=shuffle,
random_state=random_state)
Expand Down Expand Up @@ -876,7 +869,7 @@ def fit_transform(self, X, y=None, W=None, H=None):
random_state=self.random_state, verbose=self.verbose,
shuffle=self.shuffle)

self.reconstruction_err_ = _loss(X, W, H, square_root=True)
self.reconstruction_err_ = _loss(X, W, H)

self.n_components_ = H.shape[0]
self.components_ = H
Expand Down
6 changes: 4 additions & 2 deletions scopen/Main.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def estimate_rank(data, args):
for n_components in n_components_list:
arguments = (data, n_components, args.alpha,
args.max_iter, args.verbose,
args.random_state. args.init)
args.random_state, args.init)

res = run_nmf(arguments)
w_hat_dict[n_components] = res[0]
Expand All @@ -139,7 +139,9 @@ def estimate_rank(data, args):
elif args.nc > 1:
arguments_list = list()
for n_components in n_components_list:
arguments = (data, n_components, args.alpha, args.max_iter, args.verbose, args.random_state, args.init)
arguments = (data, n_components, args.alpha,
args.max_iter, args.verbose,
args.random_state, args.init)
arguments_list.append(arguments)

with Pool(processes=args.nc) as pool:
Expand Down

0 comments on commit 9de86fc

Please sign in to comment.