diff --git a/.travis.yml b/.travis.yml
index de8d5b15..963135b1 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -27,7 +27,12 @@ matrix:
       env: TESTS=documentation FLOATX=float64
 before_install:
   - # Setup Python environment with BLAS libraries
-  - wget -q http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
+  - |
+    if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then
+      wget -q http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
+    else
+      wget -q http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
+    fi
   - chmod +x miniconda.sh
   - ./miniconda.sh -b -p $HOME/miniconda
   - export PATH=$HOME/miniconda/bin:$PATH
diff --git a/blocks/algorithms/__init__.py b/blocks/algorithms/__init__.py
index e7b40674..2a159d39 100644
--- a/blocks/algorithms/__init__.py
+++ b/blocks/algorithms/__init__.py
@@ -696,19 +696,21 @@ class StepClipping(StepRule):

     """
     def __init__(self, threshold=None):
-        if threshold:
-            self.threshold = shared_floatx(threshold, "threshold")
-            add_role(self.threshold, ALGORITHM_HYPERPARAMETER)
+        if threshold is not None:
+            threshold = shared_floatx(threshold, "threshold")
+            add_role(threshold, ALGORITHM_HYPERPARAMETER)
+        self.threshold = threshold

     def compute_steps(self, previous_steps):
-        if not hasattr(self, 'threshold'):
-            return previous_steps
-        norm = l2_norm(previous_steps.values())
-        multiplier = tensor.switch(norm < self.threshold,
-                                   1, self.threshold / norm)
-        steps = OrderedDict(
-            (parameter, step * multiplier)
-            for parameter, step in previous_steps.items())
+        if self.threshold is None:
+            steps = previous_steps
+        else:
+            norm = l2_norm(previous_steps.values())
+            multiplier = tensor.switch(norm < self.threshold,
+                                       1, self.threshold / norm)
+            steps = OrderedDict(
+                (parameter, step * multiplier)
+                for parameter, step in previous_steps.items())
         return steps, []
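With this change, `threshold` is always a defined attribute and `None` explicitly disables clipping. A minimal usage sketch, not part of the patch, assuming only the public `blocks.algorithms` API:

    from blocks.algorithms import CompositeRule, Scale, StepClipping

    clipping = StepClipping(threshold=10.0)  # clip joint L2 norm at 10
    no_clipping = StepClipping()             # steps pass through unchanged
    assert no_clipping.threshold is None

    # Typical composition with a learning rate; `cost` and `parameters`
    # are hypothetical placeholders for a real model's objective/weights.
    # algorithm = GradientDescent(cost=cost, parameters=parameters,
    #                             step_rule=CompositeRule([clipping,
    #                                                      Scale(0.01)]))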
diff --git a/blocks/bricks/base.py b/blocks/bricks/base.py
index d96ab7fe..4f830627 100644
--- a/blocks/bricks/base.py
+++ b/blocks/bricks/base.py
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from abc import ABCMeta
 from collections import OrderedDict
 from six import wraps
@@ -287,11 +288,10 @@ def apply(self, bound_application, *args, **kwargs):
         last_brick = self.call_stack[-1] if self.call_stack else None
         if (last_brick and brick is not last_brick and
                 brick not in last_brick.children):
-            raise ValueError('Brick ' + str(self.call_stack[-1]) + ' tries '
-                             'to call brick ' + str(self.brick) + ' which '
-                             'is not in the list of its children. This could '
-                             'be caused because an @application decorator is '
-                             'missing.')
+            warnings.warn('Brick ' + str(self.call_stack[-1]) + ' tries '
+                          'to call brick ' + str(self.brick) + ' which '
+                          'is not in the list of its children. This could '
+                          'be caused by a missing @application decorator.')
         self.call_stack.append(brick)
         try:
             outputs = self.application_function(brick, *args, **kwargs)
diff --git a/blocks/bricks/conv.py b/blocks/bricks/conv.py
index 1e8e09f1..39496759 100644
--- a/blocks/bricks/conv.py
+++ b/blocks/bricks/conv.py
@@ -441,9 +441,9 @@ class ConvolutionalSequence(Sequence, Initializable, Feedforward):
     ----------
     layers : list
         List of convolutional bricks (i.e. :class:`Convolutional`,
-        :class:`ConvolutionalActivation`, or :class:`Pooling` bricks).
-        :class:`Activation` bricks that operate elementwise can also
-        be included.
+        :class:`ConvolutionalActivation`, or :class:`Pooling` bricks),
+        or application methods from such bricks. :class:`Activation`
+        bricks that operate elementwise can also be included.
     num_channels : int
         Number of input channels in the image. For the first layer this is
         normally 1 for grayscale images and 3 for color (RGB) images. For
@@ -494,16 +494,15 @@ class ConvolutionalSequence(Sequence, Initializable, Feedforward):
     def __init__(self, layers, num_channels, batch_size=None,
                  image_size=(None, None), border_mode=None,
                  tied_biases=None, **kwargs):
-        self.layers = layers
+        self.layers = [a if isinstance(a, Brick) else a.brick for a in layers]
         self.image_size = image_size
         self.num_channels = num_channels
         self.batch_size = batch_size
         self.border_mode = border_mode
         self.tied_biases = tied_biases

-        application_methods = [brick.apply for brick in layers]
         super(ConvolutionalSequence, self).__init__(
-            application_methods=application_methods, **kwargs)
+            application_methods=layers, **kwargs)

     def get_dim(self, name):
         if name == 'input_':
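`ConvolutionalSequence` now accepts bricks and application methods interchangeably, delegating normalization to `Sequence`. An illustrative sketch of the mixed form; the layer shapes here are invented for the example:

    from blocks.bricks import Rectifier
    from blocks.bricks.conv import (Convolutional, ConvolutionalSequence,
                                    MaxPooling)

    seq = ConvolutionalSequence(
        [Convolutional((3, 3), num_filters=8),  # brick: its .apply is used
         Rectifier(),                           # elementwise Activation brick
         MaxPooling((2, 2)).apply],             # explicit application method
        num_channels=3, image_size=(32, 32))
    seq.push_allocation_config()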
diff --git a/blocks/bricks/sequences.py b/blocks/bricks/sequences.py
index 720fd95c..24a78896 100644
--- a/blocks/bricks/sequences.py
+++ b/blocks/bricks/sequences.py
@@ -1,6 +1,6 @@
 """Bricks that compose together other bricks in linear sequences."""
 import copy
-from toolz import interleave
+from toolz import interleave, unique
 from picklable_itertools.extras import equizip

 from ..utils import pack
@@ -18,16 +18,15 @@ class Sequence(Brick):
     Parameters
     ----------
     application_methods : list
-        List of :class:`.BoundApplication` to apply
+        List of :class:`.BoundApplication` or :class:`.Brick` to apply.
+        For :class:`.Brick`s, the ``.apply`` method is used.

     """
     def __init__(self, application_methods, **kwargs):
-        self.application_methods = application_methods
-
-        seen = set()
-        children = [app.brick for app in application_methods
-                    if not (app.brick in seen or seen.add(app.brick))]
-        kwargs.setdefault('children', []).extend(children)
+        pairs = ((a.apply, a) if isinstance(a, Brick) else (a, a.brick)
+                 for a in application_methods)
+        self.application_methods, bricks = zip(*pairs)
+        kwargs.setdefault('children', []).extend(unique(bricks))
         super(Sequence, self).__init__(**kwargs)

     @application
@@ -123,19 +122,13 @@ def __init__(self, activations, dims, prototype=None, **kwargs):
             name = self.prototype.__class__.__name__.lower()
             linear.name = '{}_{}'.format(name, i)
             self.linear_transformations.append(linear)
-        # Interleave the transformations and activations
-        application_methods = []
-        for entity in interleave([self.linear_transformations, activations]):
-            if entity is None:
-                continue
-            if isinstance(entity, Brick):
-                application_methods.append(entity.apply)
-            else:
-                application_methods.append(entity)
         if not dims:
             dims = [None] * (len(activations) + 1)
         self.dims = dims
-        super(MLP, self).__init__(application_methods, **kwargs)
+        # Interleave the transformations and activations
+        applications = [a for a in interleave([self.linear_transformations,
+                                               activations]) if a is not None]
+        super(MLP, self).__init__(applications, **kwargs)

     @property
     def input_dim(self):
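`Sequence` now normalizes each entry: a bare brick contributes its `.apply` method, and `unique` keeps the children list free of duplicates. A minimal sketch of the two call styles, using only public bricks:

    from blocks.bricks import Linear, Sequence, Tanh

    seq = Sequence([Linear(input_dim=16, output_dim=8),        # a brick
                    Tanh(),                                    # a brick
                    Linear(input_dim=8, output_dim=4).apply])  # an application
    # Both forms yield the same pipeline; each brick appears only once in
    # seq.children even if several of its application methods are listed.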
diff --git a/blocks/config.py b/blocks/config.py
index d5e38ba0..db138b3e 100644
--- a/blocks/config.py
+++ b/blocks/config.py
@@ -60,7 +60,7 @@
     The maximum size of an object to store in an SQLite database in bytes.
     Objects beyond this size will trigger a warning. Defaults to 4 kilobyte.

-.. option:: temp_dir
+.. option:: temp_dir, BLOCKS_TEMPDIR

     The directory in which Blocks will create temporary files. If
     unspecified, the platform-dependent default chosen by the Python
@@ -183,5 +183,6 @@ def str_or_none(val):
                   default=os.path.expanduser('~/blocks_log.sqlite'),
                   env_var='BLOCKS_SQLITEDB')
 config.add_config('max_blob_size', type_=int, default=4096)
-config.add_config('temp_dir', type_=str_or_none, default=None)
+config.add_config('temp_dir', type_=str_or_none, default=None,
+                  env_var='BLOCKS_TEMPDIR')
 config.load_yaml()
diff --git a/blocks/extensions/__init__.py b/blocks/extensions/__init__.py
index f621f6e7..92ee1c2f 100644
--- a/blocks/extensions/__init__.py
+++ b/blocks/extensions/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import print_function

+import datetime
 import logging
 from abc import ABCMeta, abstractmethod
@@ -218,7 +219,7 @@ class SimpleExtension(TrainingExtension):
     """
     BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                   "before_epoch", "before_batch",
-                                  "on_resumption", "on_interrupt",
+                                  "on_resumption", "on_interrupt", "on_error",
                                   "after_epoch", "after_batch",
                                   "after_training"])
@@ -655,3 +656,44 @@ def do(self, which_callback, *args):
             total_time = self.prefix + 'time_{}_total'
             current_row[total_time.format(action)] = \
                 self.current[level][action]
+
+
+class Timestamp(SimpleExtension):
+    """Adds a human-readable (ISO 8601) timestamp to the log.
+
+    Parameters
+    ----------
+    log_record : str, optional
+        The record name to use. Defaults to 'timestamp'.
+    separator : str, optional
+        Separator between the date and time. ISO 8601 specifies 'T'.
+        Here, we default to ' ' (blank space) for human readability.
+
+    Notes
+    -----
+    By default, triggers after every epoch as well as before training
+    starts, after training finishes, when an error occurs or when training
+    is interrupted or resumed, as these are all generally useful
+    circumstances for which to have a timestamp. These can be disabled
+    by passing `False` as the appropriate keyword argument; see
+    :class:`SimpleExtension`.
+
+    """
+    DEFAULT_LOG_RECORD = 'timestamp'
+
+    def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ',
+                 **kwargs):
+        self.log_record = log_record
+        self.separator = separator
+        default_callbacks = ['before_training', 'after_epoch', 'on_error',
+                             'on_interrupt', 'on_resumption', 'after_training']
+        for callback in default_callbacks:
+            kwargs.setdefault(callback, True)
+        super(Timestamp, self).__init__(**kwargs)
+
+    def do(self, *args):
+        self.main_loop.log.current_row[self.log_record] = self.get_timestamp()
+
+    def get_timestamp(self):
+        # Separated into a method to override for ease of testing.
+        return datetime.datetime.now().isoformat(self.separator)
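A sketch of wiring the new extension into a main loop; the `MainLoop` arguments below are placeholders, not taken from this patch:

    from blocks.extensions import Printing, Timestamp

    extensions = [Timestamp(),              # writes log['timestamp']
                  Timestamp(log_record='iso_time',
                            separator='T'),  # strict ISO 8601 form
                  Printing()]
    # main_loop = MainLoop(algorithm=algorithm, data_stream=stream,
    #                      extensions=extensions)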
diff --git a/blocks/initialization.py b/blocks/initialization.py
index 2ed9d017..d5ce3c4f 100644
--- a/blocks/initialization.py
+++ b/blocks/initialization.py
@@ -6,7 +6,7 @@
 import theano
 from six import add_metaclass

-from blocks.utils import repr_attrs
+from blocks.utils import repr_attrs, pack


 @add_metaclass(ABCMeta)
@@ -255,3 +255,39 @@ def generate(self, rng, shape):
                                              replace=False)
             weights[i, random_indices] = values[i]
         return weights
+
+
+class SparseND(Sparse):
+    """Initialize only a fraction of the weights with configurable axes.
+
+    Parameters
+    ----------
+    axis : int or sequence
+        Which axis or axes are to be treated as a "unit" for the purpose
+        of the number of elements initialized. For example, an axis of
+        (0, 1) when initializing a 4D tensor `W` will treat the first two
+        axes of the weight tensor as a grid and initialize `num_init`
+        elements of `W[0, 0, :, :]`, another `num_init` elements of
+        `W[0, 1, :, :]`, and so on.
+
+    Notes
+    -----
+    See :class:`Sparse` for documentation of other arguments.
+
+    """
+    def __init__(self, axis, **kwargs):
+        self.axis = axis
+        super(SparseND, self).__init__(**kwargs)
+
+    def generate(self, rng, shape):
+        axis_ind = pack(self.axis)
+        other_ind = [i for i in range(len(shape)) if i not in axis_ind]
+        axis_shapes = [shape[i] for i in axis_ind]
+        other_shapes = [shape[i] for i in other_ind]
+        matrix = super(SparseND, self).generate(rng,
+                                                (numpy.prod(axis_shapes),
+                                                 numpy.prod(other_shapes)))
+        unflattened = matrix.reshape(tuple(axis_shapes) + tuple(other_shapes))
+        wrong_ind = axis_ind + other_ind
+        transp_ind = [wrong_ind.index(i) for i in range(len(shape))]
+        return unflattened.transpose(transp_ind)
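As an illustration of the axis semantics, a small sketch with invented shapes; each of the 8 * 4 "units" of a 4D kernel receives exactly `num_init` nonzero values in its 5x5 slice:

    import numpy
    from blocks.initialization import IsotropicGaussian, SparseND

    rng = numpy.random.RandomState(1)
    init = SparseND(axis=(0, 1), num_init=3,
                    weights_init=IsotropicGaussian(0.01))
    W = init.generate(rng, (8, 4, 5, 5))
    # Every 5x5 slice W[i, j] has exactly three nonzero entries.
    assert ((W.reshape(8, 4, -1) != 0).sum(axis=-1) == 3).all()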
diff --git a/blocks/monitoring/aggregation.py b/blocks/monitoring/aggregation.py
index 558ff5c8..244b57a7 100644
--- a/blocks/monitoring/aggregation.py
+++ b/blocks/monitoring/aggregation.py
@@ -1,4 +1,6 @@
 """Evaluate Theano variables on auxiliary data and during training."""
+from functools import partial
 import logging
 from abc import ABCMeta, abstractmethod

+from theano.ifelse import ifelse
@@ -29,6 +31,9 @@ class AggregationScheme(object):
         The variable that holds the desired value on a single batch.

     """
+    def __init__(self, variable):
+        self.variable = variable
+
     @abstractmethod
     def get_aggregator(self):
         """Return a new Aggregator for this variable."""
@@ -149,9 +154,6 @@ def mean(numerator, denominator=1.):

 class _DataIndependent(AggregationScheme):
     """Dummy aggregation scheme for values that don't depend on data."""
-    def __init__(self, variable):
-        self.variable = variable
-
     def get_aggregator(self):
         return Aggregator(aggregation_scheme=self,
                           initialization_updates=[],
@@ -161,9 +163,6 @@ def get_aggregator(self):

 class TakeLast(AggregationScheme):
     """Aggregation scheme which remembers only the last value."""
-    def __init__(self, variable):
-        self.variable = variable
-
     def get_aggregator(self):
         self.storage = shared_like(self.variable)
         return Aggregator(aggregation_scheme=self,
@@ -173,12 +172,73 @@ def get_aggregator(self):
                           readout_variable=self.storage)


-def take_last(variable):
+def _simple_aggregation(scheme, variable):
     variable = variable.copy(variable.name)
-    variable.tag.aggregation_scheme = TakeLast(variable)
+    variable.tag.aggregation_scheme = scheme(variable)
     return variable


+take_last = partial(_simple_aggregation, TakeLast)
+
+
+class Minimum(AggregationScheme):
+    """Aggregation scheme which remembers only the minimum value."""
+    def _build_aggregator(self, accumulate_update):
+        initialized = shared_like(0.)
+        accumulate = ifelse(initialized, accumulate_update, self.variable)
+        return Aggregator(aggregation_scheme=self,
+                          initialization_updates=[
+                              (self.storage, tensor.zeros_like(self.storage)),
+                              (initialized, tensor.zeros_like(initialized))
+                          ],
+                          accumulation_updates=[
+                              (self.storage, accumulate),
+                              (initialized, tensor.ones_like(initialized))
+                          ],
+                          readout_variable=self.storage)
+
+    def get_aggregator(self):
+        self.storage = shared_like(self.variable)
+        return self._build_aggregator(tensor.minimum(self.storage,
+                                                     self.variable))
+
+minimum = partial(_simple_aggregation, Minimum)
+
+
+class Maximum(Minimum):
+    """Aggregation scheme which remembers only the maximum value."""
+    def get_aggregator(self):
+        self.storage = shared_like(self.variable)
+        return self._build_aggregator(tensor.maximum(self.storage,
+                                                     self.variable))
+
+maximum = partial(_simple_aggregation, Maximum)
+
+
+class Concatenate(Minimum):
+    """Aggregation scheme which remembers values from all batches.
+
+    Parameters
+    ----------
+    variable : :class:`~tensor.TensorVariable`
+        The variable that holds the desired value on a single batch.
+
+    """
+    def __init__(self, variable):
+        # Add an extra axis to concatenate along. Must be non-broadcastable
+        # for concatenate to always work.
+        variable = (tensor.unbroadcast(tensor.shape_padleft(variable, 1), 0)
+                    .copy(variable.name))
+        super(Concatenate, self).__init__(variable)
+
+    def get_aggregator(self):
+        self.storage = shared_like(self.variable)
+        return self._build_aggregator(tensor.concatenate([self.storage,
+                                                          self.variable]))
+
+concatenate = partial(_simple_aggregation, Concatenate)
+
+
 @add_metaclass(ABCMeta)
 class MonitoredQuantity(object):
     """The base class for monitored-quantities.
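The `_simple_aggregation` helper plus `partial` gives each scheme a lowercase convenience function, mirroring the old `take_last`. A short sketch of tagging a monitored quantity, where `cost` stands in for any named Theano expression:

    from theano import tensor
    from blocks.monitoring.aggregation import concatenate, maximum, minimum

    x = tensor.vector('x')
    cost = x.sum().copy('cost')     # any variable with a name works
    lowest = minimum(cost)          # smallest per-batch value over the set
    highest = maximum(cost)         # largest per-batch value over the set
    per_batch = concatenate(cost)   # all per-batch values, stacked

A `DatasetEvaluator` handed these tagged variables reports one aggregated value per name, as the new tests below exercise.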
diff --git a/blocks/serialization.py b/blocks/serialization.py
index 1de376b3..149d5789 100644
--- a/blocks/serialization.py
+++ b/blocks/serialization.py
@@ -561,7 +561,7 @@ def _recreate_cuda_ndarray(_, content):


 def _recreate_pygpu_array(context_name, content):
-    context = theano.sandbox.gpuarray.get_context(context_name)
+    context = theano.gpuarray.get_context(context_name)
     return pygpu.gpuarray.array(content, context=context)

 _ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
diff --git a/tests/algorithms/test_algorithms.py b/tests/algorithms/test_algorithms.py
index 7c21107b..3764a370 100644
--- a/tests/algorithms/test_algorithms.py
+++ b/tests/algorithms/test_algorithms.py
@@ -311,6 +311,16 @@ def test_step_clipping():
     assert_allclose(clipped2[1].eval(), 4.0)


+def test_step_clipping_no_threshold_regression():
+    """Regression test for #1145: incorrect output when threshold=None."""
+    rule1 = StepClipping()
+    assert rule1.threshold is None
+    gradients = {0: shared_floatx(3.0), 1: shared_floatx(4.0)}
+    clipped1, updates = rule1.compute_steps(gradients)
+    assert len(updates) == 0
+    assert clipped1 == gradients
+
+
 def test_step_clipping_broadcastable():
     verify_broadcastable_handling(StepClipping(0.4))
diff --git a/tests/bricks/test_bricks.py b/tests/bricks/test_bricks.py
index b2b9ae50..adef36f4 100644
--- a/tests/bricks/test_bricks.py
+++ b/tests/bricks/test_bricks.py
@@ -427,14 +427,34 @@ def test_sequence():
     linear_2 = Linear(input_dim=8, output_dim=4,
                       weights_init=Constant(3), biases_init=Constant(4))
-    sequence = Sequence([linear_1.apply, linear_2.apply])
-    sequence.initialize()
-    y = sequence.apply(x)
-    x_val = numpy.ones((4, 16), dtype=theano.config.floatX)
-    assert_allclose(
-        y.eval({x: x_val}),
-        (x_val.dot(2 * numpy.ones((16, 8))) + numpy.ones((4, 8))).dot(
-            3 * numpy.ones((8, 4))) + 4 * numpy.ones((4, 4)))
+
+    def check(bricks):
+        sequence = Sequence(bricks)
+        sequence.initialize()
+        y = sequence.apply(x)
+        x_val = numpy.ones((4, 16), dtype=theano.config.floatX)
+        assert_allclose(
+            y.eval({x: x_val}),
+            (x_val.dot(2 * numpy.ones((16, 8))) + numpy.ones((4, 8))).dot(
+                3 * numpy.ones((8, 4))) + 4 * numpy.ones((4, 4)))
+
+    # Test with all application methods.
+    yield check, [linear_1.apply, linear_2.apply]
+
+    # Test with all bricks.
+    yield check, [linear_1, linear_2]
+
+    # Test with a mix of bricks and application methods.
+    yield check, [linear_1, linear_2.apply]
+    yield check, [linear_1.apply, linear_2]
+
+    # Test with an application method not called 'apply'.
+    class Dummy(Brick):
+        @application
+        def foobar(self, input_):
+            return input_
+
+    yield check, [linear_1.apply, linear_2.apply, Dummy().foobar]


 def test_sequence_variable_outputs():
diff --git a/tests/bricks/test_conv.py b/tests/bricks/test_conv.py
index a4e425e2..676ad865 100644
--- a/tests/bricks/test_conv.py
+++ b/tests/bricks/test_conv.py
@@ -392,8 +392,8 @@ def test_convolutional_sequence():
     pooling = MaxPooling(pooling_size=(pooling_size, pooling_size))
     conv2 = Convolutional((2, 2), 4, weights_init=Constant(1.))

-    seq = ConvolutionalSequence([conv, act, pooling, conv2, act], num_channels,
-                                image_size=(17, 13))
+    seq = ConvolutionalSequence([conv, act, pooling.apply, conv2.apply, act],
+                                num_channels, image_size=(17, 13))
     seq.push_allocation_config()
     assert conv.num_channels == 4
     assert conv2.num_channels == 5
diff --git a/tests/extensions/test_extensions.py b/tests/extensions/test_extensions.py
index f6b7f247..c9c2eafe 100644
--- a/tests/extensions/test_extensions.py
+++ b/tests/extensions/test_extensions.py
@@ -1,7 +1,8 @@
+import re
 from mock import Mock
 from numpy.testing import assert_raises

-from blocks.extensions import SimpleExtension, CompositeExtension
+from blocks.extensions import SimpleExtension, CompositeExtension, Timestamp
 from blocks.extensions.saveload import Checkpoint
 from blocks.extensions.predicates import OnLogRecord
@@ -146,3 +147,56 @@ def do(self, which_callback, *args):
     ext.main_loop = Mock()
     ext.dispatch('before_batch')
     ext.do.assert_called_once_with('before_batch')
+
+
+class InjectedTimestamp(Timestamp):
+    def __init__(self, **kwargs):
+        self.returns = ['foo', 'bar', 'baz']
+        super(InjectedTimestamp, self).__init__(**kwargs)
+
+    def get_timestamp(self):
+        if len(self.returns) > 0:
+            return self.returns.pop()
+        return super(InjectedTimestamp, self).get_timestamp()
+
+
+def test_timestamp():
+    def check(kwargs):
+        if 'log_record' in kwargs:
+            log_record = kwargs['log_record']
+        else:
+            log_record = Timestamp.DEFAULT_LOG_RECORD
+        ext = InjectedTimestamp(**kwargs)
+        ext.main_loop = Mock()
+        ext.main_loop.log.current_row = {}
+        ext.do('after_epoch')
+        assert ext.main_loop.log.current_row[log_record] == 'baz'
+        ext.do('after_epoch')
+        assert ext.main_loop.log.current_row[log_record] == 'bar'
+        ext.do('after_epoch')
+        assert ext.main_loop.log.current_row[log_record] == 'foo'
+        # Exercise original get_timestamp.
+        ext.do('after_epoch')
+        sep = kwargs.get('separator', ' ')
+        assert bool(re.match(''.join(['[0-9]{4}-[0-9]{2}-[0-9]{2}', sep,
+                                      '[0-9]{2}(\\:[0-9]{2}){2}',
+                                      '\\.[0-9]+']),
+                             ext.main_loop.log.current_row[log_record]))
+
+    yield check, {}
+    yield check, {'log_record': 'loggy mclogpants'}
+
+
+def test_timestamp_default_triggers():
+    def check(callback):
+        ext = InjectedTimestamp()
+        ext.main_loop = Mock()
+        ext.main_loop.log.current_row = {}
+        ext.dispatch(callback)
+        assert ext.main_loop.log.current_row.get('timestamp') == 'baz'
+
+    callbacks = ['before_training', 'after_epoch', 'on_error',
+                 'on_interrupt', 'on_resumption', 'after_training']
+
+    for callback in callbacks:
+        yield check, callback
diff --git a/tests/monitoring/test_aggregation.py b/tests/monitoring/test_aggregation.py
index ecbddcf2..1c72c8a1 100644
--- a/tests/monitoring/test_aggregation.py
+++ b/tests/monitoring/test_aggregation.py
@@ -6,7 +6,8 @@
 from blocks import bricks
 from blocks.bricks.base import application
 from blocks.graph import ComputationGraph
-from blocks.monitoring.aggregation import mean, Mean
+from blocks.monitoring.aggregation import (mean, Mean, Minimum, Maximum,
+                                           Concatenate)
 from blocks.utils import shared_floatx
 from collections import OrderedDict
@@ -91,6 +92,78 @@ def test_mean_aggregator():
                     numpy.array([35], dtype=theano.config.floatX))


+def test_min_max_aggregators():
+    num_examples = 4
+    batch_size = 2
+
+    features = numpy.array([[2, 3],
+                            [2, 9],
+                            [2, 4],
+                            [5, 1]], dtype=theano.config.floatX)
+
+    dataset = IndexableDataset(OrderedDict([('features', features)]))
+
+    data_stream = DataStream(dataset,
+                             iteration_scheme=SequentialScheme(num_examples,
+                                                               batch_size))
+
+    x = tensor.matrix('features')
+    y = (x**2).sum(axis=0)
+    y.name = 'y'
+    z = y.min()
+    z.name = 'z'
+
+    y.tag.aggregation_scheme = Maximum(y)
+    z.tag.aggregation_scheme = Minimum(z)
+
+    assert_allclose(DatasetEvaluator([y]).evaluate(data_stream)['y'],
+                    numpy.array([29, 90], dtype=theano.config.floatX))
+    assert_allclose(DatasetEvaluator([z]).evaluate(data_stream)['z'],
+                    numpy.array([8], dtype=theano.config.floatX))
+
+    # Make sure accumulators are reset.
+    features = numpy.array([[2, 1],
+                            [1, 3],
+                            [1, -1],
+                            [2.5, 1]], dtype=theano.config.floatX)
+
+    dataset = IndexableDataset(OrderedDict([('features', features)]))
+
+    data_stream = DataStream(dataset,
+                             iteration_scheme=SequentialScheme(num_examples,
+                                                               batch_size))
+    assert_allclose(DatasetEvaluator([y]).evaluate(data_stream)['y'],
+                    numpy.array([7.25, 10], dtype=theano.config.floatX))
+    assert_allclose(DatasetEvaluator([z]).evaluate(data_stream)['z'],
+                    numpy.array([2], dtype=theano.config.floatX))
+
+
+def test_concatenate_aggregator():
+    num_examples = 4
+    batch_size = 2
+
+    features = numpy.array([[2, 3],
+                            [2, 9],
+                            [2, 4],
+                            [5, 1]], dtype=theano.config.floatX)
+
+    dataset = IndexableDataset(OrderedDict([('features', features)]))
+
+    data_stream = DataStream(dataset,
+                             iteration_scheme=SequentialScheme(num_examples,
+                                                               batch_size))
+    x = tensor.matrix('features')
+    y = x.sum(axis=0).copy('y')
+    z = y.sum(axis=0).copy('z')
+    y.tag.aggregation_scheme = Concatenate(y)
+    z.tag.aggregation_scheme = Concatenate(z)
+
+    assert_allclose(DatasetEvaluator([y]).evaluate(data_stream)['y'],
+                    numpy.array([[4, 12], [7, 5]],
+                                dtype=theano.config.floatX))
+    assert_allclose(DatasetEvaluator([z]).evaluate(data_stream)['z'],
+                    numpy.array([16, 12], dtype=theano.config.floatX))
+
+
 def test_aggregation_buffer_name_uniqueness():
     x1 = tensor.scalar('x')
     x2 = tensor.scalar('x')
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 3c255e15..de374ad8 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -4,8 +4,9 @@
 import theano
 from numpy.testing import assert_equal, assert_allclose, assert_raises

-from blocks.initialization import Constant, IsotropicGaussian, Sparse
+from blocks.initialization import Constant, IsotropicGaussian, Sparse, SparseND
 from blocks.initialization import Uniform, Orthogonal, Identity
+from blocks.utils import pack


 def test_constant():
@@ -97,6 +98,34 @@ def check_sparse(rng, num_init, weights_init, sparse_init, shape, total):
     yield check_sparse, rng, 0.3, Constant(0.), Constant(1.), (10, 10), 70


+def test_sparse_nd():
+    rng = numpy.random.RandomState(1)
+
+    def check_sparse(rng, axis, num_init, shape, weights_init=Constant(1.)):
+        weights = SparseND(axis=axis, num_init=num_init,
+                           weights_init=weights_init).generate(rng, shape)
+        assert weights.shape == shape
+        assert weights.dtype == theano.config.floatX
+        if isinstance(num_init, numbers.Integral):
+            nnz = numpy.prod([s for i, s in enumerate(shape)
+                              if i in pack(axis)]) * num_init
+            assert numpy.count_nonzero(weights) == nnz
+        else:
+            atom_size = numpy.prod([s for i, s in enumerate(shape)
+                                    if i not in pack(axis)])
+            nnz_atom = int(num_init * atom_size)
+            num_atoms = numpy.prod([s for i, s in enumerate(shape)
+                                    if i in pack(axis)])
+            nnz = nnz_atom * num_atoms
+            assert numpy.count_nonzero(weights) == nnz
+
+    yield check_sparse, rng, 1, 5, (10, 11)
+    yield check_sparse, rng, 2, 3, (7, 8, 9)
+    yield check_sparse, rng, (2, 3), 5. / 6., (2, 3, 5, 7)
+    yield check_sparse, rng, (0, 1), 3, (3, 5, 7, 11)
+    yield check_sparse, rng, (0, 2, 3), 0.5, (2, 3, 2, 6)
+
+
 def test_orthogonal():
     rng = numpy.random.RandomState(1)