Commit de941ce: Merge remote-tracking branch 'upstream/master'
SwordYork committed Jan 3, 2017 (2 parents: ec4192e + 3db89c0)
Showing 16 changed files with 391 additions and 67 deletions.
7 changes: 6 additions & 1 deletion .travis.yml
@@ -27,7 +27,12 @@ matrix:
       env: TESTS=documentation FLOATX=float64
 before_install:
   - # Setup Python environment with BLAS libraries
-  - wget -q http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
+  - |
+    if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then
+      wget -q http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
+    else
+      wget -q http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
+    fi
   - chmod +x miniconda.sh
   - ./miniconda.sh -b -p $HOME/miniconda
   - export PATH=$HOME/miniconda/bin:$PATH
24 changes: 13 additions & 11 deletions blocks/algorithms/__init__.py
@@ -696,19 +696,21 @@ class StepClipping(StepRule):
"""
def __init__(self, threshold=None):
if threshold:
self.threshold = shared_floatx(threshold, "threshold")
add_role(self.threshold, ALGORITHM_HYPERPARAMETER)
if threshold is not None:
threshold = shared_floatx(threshold, "threshold")
add_role(threshold, ALGORITHM_HYPERPARAMETER)
self.threshold = threshold

def compute_steps(self, previous_steps):
if not hasattr(self, 'threshold'):
return previous_steps
norm = l2_norm(previous_steps.values())
multiplier = tensor.switch(norm < self.threshold,
1, self.threshold / norm)
steps = OrderedDict(
(parameter, step * multiplier)
for parameter, step in previous_steps.items())
if self.threshold is None:
steps = previous_steps
else:
norm = l2_norm(previous_steps.values())
multiplier = tensor.switch(norm < self.threshold,
1, self.threshold / norm)
steps = OrderedDict(
(parameter, step * multiplier)
for parameter, step in previous_steps.items())
return steps, []


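The practical effect: `StepClipping()` constructed without a threshold now stores `self.threshold = None` and passes steps through unchanged, rather than relying on a missing attribute. A minimal usage sketch (toy cost and parameter; `GradientDescent`, `CompositeRule`, and `Scale` are the standard `blocks.algorithms` API, not part of this diff):

```python
from theano import tensor
from blocks.algorithms import (CompositeRule, GradientDescent, Scale,
                               StepClipping)
from blocks.utils import shared_floatx

# Toy setup: one shared parameter and a scalar cost that depends on it.
W = shared_floatx([1.0, 2.0], name='W')
x = tensor.vector('x')
cost = (x * W).sum() ** 2

# Clip the update step's L2 norm at 1.0, then apply a learning rate.
step_rule = CompositeRule([StepClipping(threshold=1.0), Scale(0.01)])
algorithm = GradientDescent(cost=cost, parameters=[W], step_rule=step_rule)

# After this commit, a threshold-less StepClipping() is an explicit no-op
# (self.threshold is None) rather than an attribute error waiting to happen.
identity_rule = StepClipping()
```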
11 changes: 6 additions & 5 deletions blocks/bricks/base.py
@@ -1,4 +1,5 @@
 import inspect
+import warnings
 from abc import ABCMeta
 from collections import OrderedDict
 from six import wraps
@@ -287,11 +288,11 @@ def apply(self, bound_application, *args, **kwargs):
         last_brick = self.call_stack[-1] if self.call_stack else None
         if (last_brick and brick is not last_brick and
                 brick not in last_brick.children):
-            raise ValueError('Brick ' + str(self.call_stack[-1]) + ' tries '
-                             'to call brick ' + str(self.brick) + ' which '
-                             'is not in the list of its children. This could '
-                             'be caused because an @application decorator is '
-                             'missing.')
+            warnings.warn('Brick ' + str(self.call_stack[-1]) + ' tries '
+                          'to call brick ' + str(self.brick) + ' which '
+                          'is not in the list of its children. This could '
+                          'be caused because an @application decorator is '
+                          'missing.')
         self.call_stack.append(brick)
         try:
             outputs = self.application_function(brick, *args, **kwargs)
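A sketch of the case this now covers, using a hypothetical `Parent` brick that forgets to register its child; before this commit the nested `apply` call raised `ValueError`, now it only warns:

```python
import theano.tensor as tensor
from blocks.bricks import Linear
from blocks.bricks.base import Brick, application

class Parent(Brick):
    """Hypothetical brick that forgets to register its child."""
    def __init__(self, **kwargs):
        super(Parent, self).__init__(**kwargs)
        self.linear = Linear(input_dim=2, output_dim=2)
        # Deliberately NOT added to self.children.

    @application
    def apply(self, x):
        # Calling the unregistered child emits a warning instead of raising.
        return self.linear.apply(x)

y = Parent().apply(tensor.matrix('x'))  # warns, but the graph is still built
```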
11 changes: 5 additions & 6 deletions blocks/bricks/conv.py
@@ -441,9 +441,9 @@ class ConvolutionalSequence(Sequence, Initializable, Feedforward):
     ----------
     layers : list
         List of convolutional bricks (i.e. :class:`Convolutional`,
-        :class:`ConvolutionalActivation`, or :class:`Pooling` bricks).
-        :class:`Activation` bricks that operate elementwise can also
-        be included.
+        :class:`ConvolutionalActivation`, or :class:`Pooling` bricks),
+        or application methods from such bricks. :class:`Activation`
+        bricks that operate elementwise can also be included.
     num_channels : int
         Number of input channels in the image. For the first layer this is
         normally 1 for grayscale images and 3 for color (RGB) images. For
@@ -494,16 +494,15 @@ class ConvolutionalSequence(Sequence, Initializable, Feedforward):
     def __init__(self, layers, num_channels, batch_size=None,
                  image_size=(None, None), border_mode=None, tied_biases=None,
                  **kwargs):
-        self.layers = layers
+        self.layers = [a if isinstance(a, Brick) else a.brick for a in layers]
         self.image_size = image_size
         self.num_channels = num_channels
         self.batch_size = batch_size
         self.border_mode = border_mode
         self.tied_biases = tied_biases
 
-        application_methods = [brick.apply for brick in layers]
         super(ConvolutionalSequence, self).__init__(
-            application_methods=application_methods, **kwargs)
+            application_methods=layers, **kwargs)
 
     def get_dim(self, name):
         if name == 'input_':
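With this change, `layers` may freely mix bricks and bound application methods. A hedged sketch (brick constructor arguments as in `blocks.bricks.conv` at this commit):

```python
from blocks.bricks import Rectifier
from blocks.bricks.conv import (Convolutional, ConvolutionalSequence,
                                MaxPooling)

conv = Convolutional(filter_size=(3, 3), num_filters=16, name='conv')
pool = MaxPooling(pooling_size=(2, 2), name='pool')

# Bricks (Rectifier(), pool) and application methods (conv.apply) may now
# be mixed; the constructor normalizes bricks to their .apply methods.
seq = ConvolutionalSequence([conv.apply, Rectifier(), pool],
                            num_channels=3, image_size=(32, 32))
```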
29 changes: 11 additions & 18 deletions blocks/bricks/sequences.py
@@ -1,6 +1,6 @@
"""Bricks that compose together other bricks in linear sequences."""
import copy
from toolz import interleave
from toolz import interleave, unique
from picklable_itertools.extras import equizip

from ..utils import pack
@@ -18,16 +18,15 @@ class Sequence(Brick):
     Parameters
     ----------
     application_methods : list
-        List of :class:`.BoundApplication` to apply
+        List of :class:`.BoundApplication` or :class:`.Brick` to apply.
+        For :class:`.Brick`s, the ``.apply`` method is used.
     """
     def __init__(self, application_methods, **kwargs):
-        self.application_methods = application_methods
-
-        seen = set()
-        children = [app.brick for app in application_methods
-                    if not (app.brick in seen or seen.add(app.brick))]
-        kwargs.setdefault('children', []).extend(children)
+        pairs = ((a.apply, a) if isinstance(a, Brick) else (a, a.brick)
+                 for a in application_methods)
+        self.application_methods, bricks = zip(*pairs)
+        kwargs.setdefault('children', []).extend(unique(bricks))
         super(Sequence, self).__init__(**kwargs)
 
     @application
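A minimal sketch of the new calling convention, assuming the stock `Linear` and `Tanh` bricks:

```python
from blocks.bricks import Linear, Sequence, Tanh

linear = Linear(input_dim=10, output_dim=5, name='linear')
# A Brick is resolved to its .apply method, and its brick is registered
# as a child exactly once (deduplicated via toolz.unique).
seq = Sequence([linear.apply, Tanh()])
```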
@@ -123,19 +122,13 @@ def __init__(self, activations, dims, prototype=None, **kwargs):
             name = self.prototype.__class__.__name__.lower()
             linear.name = '{}_{}'.format(name, i)
             self.linear_transformations.append(linear)
-        # Interleave the transformations and activations
-        application_methods = []
-        for entity in interleave([self.linear_transformations, activations]):
-            if entity is None:
-                continue
-            if isinstance(entity, Brick):
-                application_methods.append(entity.apply)
-            else:
-                application_methods.append(entity)
         if not dims:
             dims = [None] * (len(activations) + 1)
         self.dims = dims
-        super(MLP, self).__init__(application_methods, **kwargs)
+        # Interleave the transformations and activations
+        applications = [a for a in interleave([self.linear_transformations,
+                                               activations]) if a is not None]
+        super(MLP, self).__init__(applications, **kwargs)
 
     @property
     def input_dim(self):
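The refactor does not change MLP's behavior; for reference, `None` entries in `activations` are still skipped by the interleave step:

```python
from blocks.bricks import MLP, Rectifier

# None means "no activation" after the second linear layer; the
# list comprehension above simply filters it out.
mlp = MLP(activations=[Rectifier(), None], dims=[784, 256, 10])
```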
5 changes: 3 additions & 2 deletions blocks/config.py
@@ -60,7 +60,7 @@
    The maximum size of an object to store in an SQLite database in bytes.
    Objects beyond this size will trigger a warning. Defaults to 4 kilobyte.
 
-.. option:: temp_dir
+.. option:: temp_dir, BLOCKS_TEMPDIR
 
    The directory in which Blocks will create temporary files. If
    unspecified, the platform-dependent default chosen by the Python
@@ -183,5 +183,6 @@ def str_or_none(val):
                   default=os.path.expanduser('~/blocks_log.sqlite'),
                   env_var='BLOCKS_SQLITEDB')
 config.add_config('max_blob_size', type_=int, default=4096)
-config.add_config('temp_dir', type_=str_or_none, default=None)
+config.add_config('temp_dir', type_=str_or_none, default=None,
+                  env_var='BLOCKS_TEMPDIR')
 config.load_yaml()
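The temporary directory can now be set from the environment as well as from the YAML config. A sketch (the path is hypothetical; the variable must be set before `blocks.config` is first imported):

```python
import os

# Hypothetical scratch location; set before importing blocks.config.
os.environ['BLOCKS_TEMPDIR'] = '/scratch/tmp'

from blocks.config import config
assert config.temp_dir == '/scratch/tmp'
```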
44 changes: 43 additions & 1 deletion blocks/extensions/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import print_function
 
+import datetime
 import logging
 from abc import ABCMeta, abstractmethod
 
@@ -218,7 +219,7 @@ class SimpleExtension(TrainingExtension):
"""
BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
"before_epoch", "before_batch",
"on_resumption", "on_interrupt",
"on_resumption", "on_interrupt", "on_error",
"after_epoch", "after_batch",
"after_training"])

@@ -655,3 +656,44 @@ def do(self, which_callback, *args):
         total_time = self.prefix + 'time_{}_total'
         current_row[total_time.format(action)] = \
             self.current[level][action]
+
+
+class Timestamp(SimpleExtension):
+    """Adds a human readable (ISO 8601) timestamp to the log.
+
+    Parameters
+    ----------
+    log_record : str, optional
+        The record name to use. Defaults to 'timestamp'.
+    separator : str, optional
+        Separator between the date and time. ISO 8601 specifies 'T'.
+        Here, we default to ' ' (blank space) for human readability.
+
+    Notes
+    -----
+    By default, triggers after every epoch as well as before training
+    starts, after training finishes, when an error occurs or when training
+    is interrupted or resumed, as these are all generally useful
+    circumstances for which to have a timestamp. These can be disabled
+    by passing `False` as the appropriate keyword argument; see
+    :class:`SimpleExtension`.
+
+    """
+    DEFAULT_LOG_RECORD = 'timestamp'
+
+    def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ',
+                 **kwargs):
+        self.log_record = log_record
+        self.separator = separator
+        default_callbacks = ['before_training', 'after_epoch', 'on_error',
+                             'on_interrupt', 'on_resumption', 'after_training']
+        for callback in default_callbacks:
+            kwargs.setdefault(callback, True)
+        super(Timestamp, self).__init__(**kwargs)
+
+    def do(self, *args):
+        self.main_loop.log.current_row[self.log_record] = self.get_timestamp()
+
+    def get_timestamp(self):
+        # Separated into a method to override for ease of testing.
+        return datetime.datetime.now().isoformat(self.separator)
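A usage sketch for the new extension (the printed value is illustrative):

```python
from blocks.extensions import Timestamp

ts = Timestamp()            # record name defaults to 'timestamp'
print(ts.get_timestamp())   # e.g. '2017-01-03 16:20:57.123456'

# Strict ISO 8601 form, with 'T' between date and time:
ts_strict = Timestamp(separator='T', log_record='iso_time')
```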
38 changes: 37 additions & 1 deletion blocks/initialization.py
@@ -6,7 +6,7 @@
 import theano
 from six import add_metaclass
 
-from blocks.utils import repr_attrs
+from blocks.utils import repr_attrs, pack
 
 
 @add_metaclass(ABCMeta)
@@ -255,3 +255,39 @@ def generate(self, rng, shape):
                                       replace=False)
             weights[i, random_indices] = values[i]
         return weights
+
+
+class SparseND(Sparse):
+    """Initialize only a fraction of the weights with configurable axes.
+
+    Parameters
+    ----------
+    axis : int or sequence
+        Which axis or axes are to be treated as a "unit" for the purpose
+        of the number of elements initialized. For example, an axis of
+        (0, 1) when initializing a 4D tensor `W` will treat the first two
+        axes of the weight tensor as a grid and initialize `num_init`
+        elements of `W[0, 0, :, :]`, another `num_init` elements of
+        `W[0, 1, :, :]`, and so on.
+
+    Notes
+    -----
+    See :class:`Sparse` for documentation of other arguments.
+
+    """
+    def __init__(self, axis, **kwargs):
+        self.axis = axis
+        super(SparseND, self).__init__(**kwargs)
+
+    def generate(self, rng, shape):
+        axis_ind = pack(self.axis)
+        other_ind = [i for i in range(len(shape)) if i not in axis_ind]
+        axis_shapes = [shape[i] for i in axis_ind]
+        other_shapes = [shape[i] for i in other_ind]
+        matrix = super(SparseND, self).generate(rng,
+                                                (numpy.prod(axis_shapes),
+                                                 numpy.prod(other_shapes)))
+        unflattened = matrix.reshape(tuple(axis_shapes) + tuple(other_shapes))
+        wrong_ind = axis_ind + other_ind
+        transp_ind = [wrong_ind.index(i) for i in range(len(shape))]
+        return unflattened.transpose(transp_ind)
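A hedged sketch of the semantics described in the docstring; `num_init` and `weights_init` are the arguments inherited from `Sparse`:

```python
import numpy
from blocks.initialization import IsotropicGaussian, SparseND

# Treat axes (0, 1) of a 4D kernel as the "unit" grid: each
# W[i, j, :, :] slice gets num_init entries drawn from weights_init,
# the rest zero.
init = SparseND(axis=(0, 1), num_init=2,
                weights_init=IsotropicGaussian(0.01))
W = init.generate(numpy.random.RandomState(1), (4, 4, 3, 3))
assert W.shape == (4, 4, 3, 3)
```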
75 changes: 67 additions & 8 deletions blocks/monitoring/aggregation.py
@@ -1,4 +1,5 @@
"""Evaluate Theano variables on auxiliary data and during training."""
from functools import partial
import logging
from abc import ABCMeta, abstractmethod

@@ -29,6 +30,9 @@ class AggregationScheme(object):
         The variable that holds the desired value on a single batch.
     """
+    def __init__(self, variable):
+        self.variable = variable
+
     @abstractmethod
     def get_aggregator(self):
         """Return a new Aggregator for this variable."""
@@ -149,9 +153,6 @@ def mean(numerator, denominator=1.):

 class _DataIndependent(AggregationScheme):
     """Dummy aggregation scheme for values that don't depend on data."""
-    def __init__(self, variable):
-        self.variable = variable
-
     def get_aggregator(self):
         return Aggregator(aggregation_scheme=self,
                           initialization_updates=[],
@@ -161,9 +162,6 @@ def get_aggregator(self):

 class TakeLast(AggregationScheme):
     """Aggregation scheme which remembers only the last value."""
-    def __init__(self, variable):
-        self.variable = variable
-
     def get_aggregator(self):
         self.storage = shared_like(self.variable)
         return Aggregator(aggregation_scheme=self,
@@ -173,12 +171,73 @@ def get_aggregator(self):
                           readout_variable=self.storage)
 
 
-def take_last(variable):
+def _simple_aggregation(scheme, variable):
     variable = variable.copy(variable.name)
-    variable.tag.aggregation_scheme = TakeLast(variable)
+    variable.tag.aggregation_scheme = scheme(variable)
     return variable
 
 
+take_last = partial(_simple_aggregation, TakeLast)
+
+
+class Minimum(AggregationScheme):
+    """Aggregation scheme which remembers only the minimum value."""
+    def _build_aggregator(self, accumulate_update):
+        initialized = shared_like(0.)
+        accumulate = ifelse(initialized, accumulate_update, self.variable)
+        return Aggregator(aggregation_scheme=self,
+                          initialization_updates=[
+                              (self.storage, tensor.zeros_like(self.storage)),
+                              (initialized, tensor.zeros_like(initialized))
+                          ],
+                          accumulation_updates=[
+                              (self.storage, accumulate),
+                              (initialized, tensor.ones_like(initialized))
+                          ],
+                          readout_variable=self.storage)
+
+    def get_aggregator(self):
+        self.storage = shared_like(self.variable)
+        return self._build_aggregator(tensor.minimum(self.storage,
+                                                     self.variable))
+
+minimum = partial(_simple_aggregation, Minimum)
+
+
+class Maximum(Minimum):
+    """Aggregation scheme which remembers only the maximum value."""
+    def get_aggregator(self):
+        self.storage = shared_like(self.variable)
+        return self._build_aggregator(tensor.maximum(self.storage,
+                                                     self.variable))
+
+maximum = partial(_simple_aggregation, Maximum)
+
+
+class Concatenate(Minimum):
+    """Aggregation scheme which remembers values from all batches.
+
+    Parameters
+    ----------
+    variable: :class:`~tensor.TensorVariable`
+        The variable that holds the desired value on a single batch.
+
+    """
+    def __init__(self, variable):
+        # Add an extra axis to concatenate along. Must be non-broadcastable
+        # for concatenate to always work.
+        variable = (tensor.unbroadcast(tensor.shape_padleft(variable, 1), 0)
+                    .copy(variable.name))
+        super(Concatenate, self).__init__(variable)
+
+    def get_aggregator(self):
+        self.storage = shared_like(self.variable)
+        return self._build_aggregator(tensor.concatenate([self.storage,
+                                                          self.variable]))
+
+concatenate = partial(_simple_aggregation, Concatenate)
+
+
 @add_metaclass(ABCMeta)
 class MonitoredQuantity(object):
     """The base class for monitored-quantities.