From c554384c899ebec81dabfe02fa0e9caf3b597ded Mon Sep 17 00:00:00 2001 From: Eli Osherovich Date: Sat, 16 Jan 2021 15:09:51 +0200 Subject: [PATCH] Removed scipy dependency from image_data_generator. --- .../image/image_data_generator.py | 34 ++++++------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/keras_preprocessing/image/image_data_generator.py b/keras_preprocessing/image/image_data_generator.py index df3fca5b..4f4af202 100644 --- a/keras_preprocessing/image/image_data_generator.py +++ b/keras_preprocessing/image/image_data_generator.py @@ -4,15 +4,6 @@ import numpy as np -try: - import scipy - # scipy.linalg cannot be accessed until explicitly imported - from scipy import linalg - - # scipy.ndimage cannot be accessed until explicitly imported -except ImportError: - scipy = None - from .affine_transformations import (apply_affine_transform, apply_brightness_shift, apply_channel_shift, flip_axis) @@ -315,7 +306,7 @@ def __init__(self, self.mean = None self.std = None - self.principal_components = None + self.zca_whitening_matrix = None if isinstance(zoom_range, (float, int)): self.zoom_range = [1 - zoom_range, 1 + zoom_range] @@ -731,10 +722,10 @@ def standardize(self, x): 'been fit on any training data. Fit it ' 'first by calling `.fit(numpy_data)`.') if self.zca_whitening: - if self.principal_components is not None: - flatx = np.reshape(x, (-1, np.prod(x.shape[-3:]))) - whitex = np.dot(flatx, self.principal_components) - x = np.reshape(whitex, x.shape) + if self.zca_whitening_matrix is not None: + flat_x = x.reshape(-1, np.prod(x.shape[-3:])) + white_x = flat_x @ self.zca_whitening_matrix + x = np.reshape(white_x, x.shape) else: warnings.warn('This ImageDataGenerator specifies ' '`zca_whitening`, but it hasn\'t ' @@ -977,12 +968,9 @@ def fit(self, x, x /= (self.std + 1e-6) if self.zca_whitening: - if scipy is None: - raise ImportError('Using zca_whitening requires SciPy. ' - 'Install SciPy.') - flat_x = np.reshape( - x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) - sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] - u, s, _ = linalg.svd(sigma) - s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon) - self.principal_components = (u * s_inv).dot(u.T) + n = len(x) + flat_x = np.reshape(x, (n, -1)) + + u, s, _ = np.linalg.svd(flat_x.T, full_matrices=False) + s_inv = np.sqrt(n) / (s + self.zca_epsilon) + self.zca_whitening_matrix = (u * s_inv).dot(u.T)