Skip to content

Commit

Permalink
Made the docs numpydoc compliant.
Browse files Browse the repository at this point in the history
Modified inference procedure to have a method "is_rbm_compatible" that will raise a NotImplementedError if not appropriate for RBM.
Removed assert for RBM and UpDown inference.
  • Loading branch information
rdevon committed Nov 7, 2014
1 parent 35c4e2b commit 3f7eab9
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
56 changes: 38 additions & 18 deletions pylearn2/models/dbm/dbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,14 @@ def energy(self, V, hidden):

def mf(self, *args, **kwargs):
"""
.. todo::
Mean field inference of model.
WRITEME
Performs the inference procedure on the model.
Parameters
----------
*args: TODO
**kwargs: TODO
"""

self.setup_inference_procedure()
Expand Down Expand Up @@ -218,7 +223,6 @@ def setup_rng(self):
def setup_inference_procedure(self):
"""
Sets up the inference procedure for the DBM.
If the number of hidden layers is one, then use UpDown.
"""
if not hasattr(self, 'inference_procedure') or \
self.inference_procedure is None:
Expand All @@ -227,11 +231,15 @@ def setup_inference_procedure(self):
else:
self.inference_procedure = WeightDoubling()
self.inference_procedure.set_dbm(self)
elif len(self.hidden_layers) == 1:
assert isinstance(self.inference_procedure, UpDown),\
"A DBM with a single layer (a.k.a an RBM) should use %r"\
"as the inference_procedure, "\
"not %r" %(UpDown, type(self.inference_procedure))

if len(self.hidden_layers) == 1:
try:
self.inference_procedure.is_rbm_compatible()
except NotImplementedError:
warnings.warn("Inference procedure %r may have unexpected"
"behavior when used with one hidden layer (RBM)."
"See models/dbn/inference_procedure.py for"
"details." % type(self.inference_procedure))

def setup_sampling_procedure(self):
"""
Expand Down Expand Up @@ -294,9 +302,11 @@ def add_layers(self, layers):

def freeze(self, parameter_set):
"""
.. todo::
Freezes the set of parameters.
WRITEME
Parameters
----------
parameter_set: WRITEME
"""
# patch old pickle files
if not hasattr(self, 'freeze_set'):
Expand Down Expand Up @@ -338,6 +348,11 @@ def get_params(self):
def set_batch_size(self, batch_size):
"""
Sets the batch size of the DBM.
Parameters
----------
batch_size: int
The batch size
"""
self.batch_size = batch_size
self.force_batch_size = batch_size
Expand Down Expand Up @@ -387,9 +402,8 @@ def get_lr_scalers(self):
def get_weights(self):
"""
Returns the weights of the bottom hidden layer.
TODO: add visualization of higher levels.
"""

return self.hidden_layers[0].get_weights()

def get_weights_view_shape(self):
Expand Down Expand Up @@ -587,7 +601,7 @@ def get_monitoring_channels(self, data):
This is done through the visible and all of the hidden layers of DBM.
Parameters
-----------
----------
data: tensor-like
Data from which to evaluate model.
"""
Expand Down Expand Up @@ -666,7 +680,7 @@ def reconstruct(self, V):
Reconstructs an input using inpainting method.
Parameters
-----------
----------
V: tensor-like
Input sample.
Expand All @@ -689,20 +703,25 @@ def reconstruct(self, V):

def do_inpainting(self, *args, **kwargs):
"""
.. todo::
Perform inpainting on model.
WRITEME
Inpainting is defined by the inference procedure.
Parameters
----------
*args: WRITEME
**kwargs: WRITEME
"""
self.setup_inference_procedure()
return self.inference_procedure.do_inpainting(*args, **kwargs)

def initialize_chains(self, X, Y, theano_rng):
"""
Function to initialize chains for model when performing the neg phase.
TODO: implement when actually getting gradients.
TODO: implement in cost functions.
Parameters
-----------
----------
X: tensor-like
The data. If none, then persistent (TODO)
Y: tensor-like
Expand Down Expand Up @@ -745,6 +764,7 @@ def initialize_chains(self, X, Y, theano_rng):
num_steps=1)
return layer_to_chains


class RBM(DBM):
"""
A restricted Boltzmann machine.
Expand Down
4 changes: 4 additions & 0 deletions pylearn2/models/dbm/inference_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,4 +1539,8 @@ def is_rbm_compatible(self):
"""
Is implemented as UpDown is RBM compatible.
"""
<<<<<<< HEAD
return True
=======
return
>>>>>>> 20fa672... Made the docs numpydoc compliant.

0 comments on commit 3f7eab9

Please sign in to comment.