From 0513a2fd7141a7de281b9d84ae0d43c62159cd53 Mon Sep 17 00:00:00 2001 From: emailweixu Date: Fri, 27 Aug 2021 11:33:29 -0700 Subject: [PATCH 1/2] Refactor critic networks using the new encoding network Also fixed a copy of EncodingNetwork.make_parallel(), where the copy is not correctly implemented. --- alf/layers.py | 62 ++++++---- alf/networks/critic_networks.py | 160 +++++-------------------- alf/networks/critic_networks_test.py | 14 +-- alf/networks/encoding_networks.py | 15 ++- alf/networks/encoding_networks_test.py | 5 + alf/networks/network_test.py | 12 ++ 6 files changed, 102 insertions(+), 166 deletions(-) diff --git a/alf/layers.py b/alf/layers.py index 54b9009dd..93288fbf1 100644 --- a/alf/layers.py +++ b/alf/layers.py @@ -1352,8 +1352,13 @@ def __init__(self, use_bias = not use_bn self._activation = activation self._n = n + self._use_bias = use_bias self._in_channels = in_channels self._out_channels = out_channels + self._kernel_initializer = kernel_initializer + self._kernel_init_gain = kernel_init_gain + self._bias_init_value = bias_init_value + self._kernel_size = common.tuplify2d(kernel_size) self._conv2d = nn.Conv2d( in_channels * n, @@ -1364,34 +1369,30 @@ def __init__(self, padding=padding, bias=use_bias) - for i in range(n): - if kernel_initializer is None: + if use_bn: + self._bn = nn.BatchNorm2d(n * out_channels) + else: + self._bn = None + self.reset_parameters() + + def reset_parameters(self): + for i in range(self._n): + if self._kernel_initializer is None: variance_scaling_init( - self._conv2d.weight.data[i * out_channels:(i + 1) * - out_channels], - gain=kernel_init_gain, + self._conv2d.weight.data[i * self._out_channels:(i + 1) * + self._out_channels], + gain=self._kernel_init_gain, nonlinearity=self._activation) else: - kernel_initializer( - self._conv2d.weight.data[i * out_channels:(i + 1) * - out_channels]) - - # [n*C', C, kernel_size, kernel_size]->[n, C', C, kernel_size, kernel_size] - self._weight = self._conv2d.weight.view( - self._n, self._out_channels, self._in_channels, - self._kernel_size[0], self._kernel_size[1]) + self._kernel_initializer( + self._conv2d.weight.data[i * self._out_channels:(i + 1) * + self._out_channels]) - if use_bias: - nn.init.constant_(self._conv2d.bias.data, bias_init_value) - # [n*C']->[n, C'] - self._bias = self._conv2d.bias.view(self._n, self._out_channels) - else: - self._bias = None + if self._use_bias: + nn.init.constant_(self._conv2d.bias.data, self._bias_init_value) - if use_bn: - self._bn = nn.BatchNorm2d(n * out_channels) - else: - self._bn = None + if self._bn: + self._bn.reset_parameters() def forward(self, img): """Forward @@ -1454,11 +1455,22 @@ def forward(self, img): @property def weight(self): - return self._weight + # The reason that weight cannot pre-computed at __init__ is deepcopy will + # fail. deepcopy is needed to implement the copy for the container networks. + # [n*C', C, kernel_size, kernel_size]->[n, C', C, kernel_size, kernel_size] + return self._conv2d.weight.view( + self._n, self._out_channels, self._in_channels, + self._kernel_size[0], self._kernel_size[1]) @property def bias(self): - return self._bias + if self._use_bias: + # The reason that weight cannot pre-computed at __init__ is deepcopy will + # fail. deepcopy is needed to implement the copy for the container networks. + # [n*C']->[n, C'] + return self._conv2d.bias.view(self._n, self._out_channels) + else: + return None @alf.configurable diff --git a/alf/networks/critic_networks.py b/alf/networks/critic_networks.py index ec8b152d6..dc7f697a2 100644 --- a/alf/networks/critic_networks.py +++ b/alf/networks/critic_networks.py @@ -17,7 +17,6 @@ import math import torch -import torch.nn as nn import alf import alf.utils.math_ops as math_ops @@ -25,7 +24,6 @@ from alf.initializers import variance_scaling_init from alf.tensor_specs import TensorSpec -from .network import Network from .encoding_networks import EncodingNetwork, LSTMEncodingNetwork, ParallelEncodingNetwork @@ -54,12 +52,15 @@ def _check_individual(spec, proc): @alf.configurable -class CriticNetwork(Network): +class CriticNetwork(EncodingNetwork): """Creates an instance of ``CriticNetwork`` for estimating action-value of continuous or discrete actions. The action-value is defined as the expected return starting from the given input observation and taking the given action. This module takes observation as input and action as input and outputs an action-value tensor with the shape of ``[batch_size]``. + + The network take a tuple of (observation, action) as input to computes the + action-value given an observation. """ def __init__(self, @@ -119,8 +120,6 @@ def __init__(self, situation. name (str): """ - super().__init__(input_tensor_spec, name=name) - if kernel_initializer is None: kernel_initializer = functools.partial( variance_scaling_init, @@ -130,7 +129,7 @@ def __init__(self, observation_spec, action_spec = input_tensor_spec - self._obs_encoder = EncodingNetwork( + obs_encoder = EncodingNetwork( observation_spec, input_preprocessors=observation_input_processors, preprocessing_combiner=observation_preprocessing_combiner, @@ -139,12 +138,12 @@ def __init__(self, activation=activation, kernel_initializer=kernel_initializer, use_fc_bn=use_fc_bn, - name=self.name + ".obs_encoder") + name=name + ".obs_encoder") _check_action_specs_for_critic_networks(action_spec, action_input_processors, action_preprocessing_combiner) - self._action_encoder = EncodingNetwork( + action_encoder = EncodingNetwork( action_spec, input_preprocessors=action_input_processors, preprocessing_combiner=action_preprocessing_combiner, @@ -152,14 +151,16 @@ def __init__(self, activation=activation, kernel_initializer=kernel_initializer, use_fc_bn=use_fc_bn, - name=self.name + ".action_encoder") + name=name + ".action_encoder") last_kernel_initializer = functools.partial( torch.nn.init.uniform_, a=-0.003, b=0.003) - self._joint_encoder = EncodingNetwork( - TensorSpec((self._obs_encoder.output_spec.shape[0] + - self._action_encoder.output_spec.shape[0], )), + super().__init__( + input_tensor_spec=input_tensor_spec, + output_tensor_spec=output_tensor_spec, + input_preprocessors=(obs_encoder, action_encoder), + preprocessing_combiner=alf.layers.NestConcat(dim=-1), fc_layer_params=joint_fc_layer_params, activation=activation, kernel_initializer=kernel_initializer, @@ -167,35 +168,11 @@ def __init__(self, last_activation=math_ops.identity, use_fc_bn=use_fc_bn, last_kernel_initializer=last_kernel_initializer, - name=self.name + ".joint_encoder") - + name=name) self._use_naive_parallel_network = use_naive_parallel_network - self._output_spec = output_tensor_spec - - def forward(self, inputs, state=()): - """Computes action-value given an observation. - - Args: - inputs: A tuple of Tensors consistent with ``input_tensor_spec`` - state: empty for API consistent with ``CriticRNNNetwork`` - - Returns: - tuple: - - action_value (torch.Tensor): a tensor of the size ``[batch_size]`` - - state: empty - """ - observations, actions = inputs - - encoded_obs, _ = self._obs_encoder(observations) - encoded_action, _ = self._action_encoder(actions) - joint = torch.cat([encoded_obs, encoded_action], -1) - action_value, _ = self._joint_encoder(joint) - action_value = action_value.reshape(action_value.shape[0], - *self._output_spec.shape) - return action_value, state def make_parallel(self, n): - """Create a ``ParallelCriticNetwork`` using ``n`` replicas of ``self``. + """Create a parallel critic network using ``n`` replicas of ``self``. The initialized network parameters will be different. If ``use_naive_parallel_network`` is True, use ``NaiveParallelNetwork`` to create the parallel network. @@ -203,60 +180,11 @@ def make_parallel(self, n): if self._use_naive_parallel_network: return alf.networks.NaiveParallelNetwork(self, n) else: - return ParallelCriticNetwork(self, n, "parallel_" + self._name) - - -class ParallelCriticNetwork(Network): - """Perform ``n`` critic computations in parallel.""" - - def __init__(self, - critic_network: CriticNetwork, - n: int, - name="ParallelCriticNetwork"): - """ - It create a parallelized version of ``critic_network``. - - Args: - critic_network (CriticNetwork): non-parallelized critic network - n (int): make ``n`` replicas from ``critic_network`` with different - initialization. - name (str): - """ - super().__init__( - input_tensor_spec=critic_network.input_tensor_spec, name=name) - self._obs_encoder = critic_network._obs_encoder.make_parallel(n, True) - self._action_encoder = critic_network._action_encoder.make_parallel( - n, True) - self._joint_encoder = critic_network._joint_encoder.make_parallel(n) - self._output_spec = TensorSpec((n, ) + - critic_network.output_spec.shape) - - def forward(self, inputs, state=()): - """Computes action-value given an observation. - - Args: - inputs (tuple): A tuple of Tensors consistent with `input_tensor_spec``. - state (tuple): Empty for API consistent with ``CriticRNNNetwork``. - - Returns: - tuple: - - action_value (torch.Tensor): a tensor of shape :math:`[B,n]`, where - :math:`B` is the batch size. - - state: empty - """ - observations, actions = inputs - - encoded_obs, _ = self._obs_encoder(observations) - encoded_action, _ = self._action_encoder(actions) - joint = torch.cat([encoded_obs, encoded_action], -1) - action_value, _ = self._joint_encoder(joint) - action_value = action_value.reshape(action_value.shape[0], - *self._output_spec.shape) - return action_value, state + return super().make_parallel(n, True) @alf.configurable -class CriticRNNNetwork(Network): +class CriticRNNNetwork(LSTMEncodingNetwork): """Creates an instance of ``CriticRNNNetwork`` for estimating action-value of continuous or discrete actions. The action-value is defined as the expected return starting from the given inputs (observation and state) and @@ -318,8 +246,6 @@ def __init__(self, with uniform distribution will be used. name (str): """ - super().__init__(input_tensor_spec, name=name) - if kernel_initializer is None: kernel_initializer = functools.partial( variance_scaling_init, @@ -329,7 +255,7 @@ def __init__(self, observation_spec, action_spec = input_tensor_spec - self._obs_encoder = EncodingNetwork( + obs_encoder = EncodingNetwork( observation_spec, input_preprocessors=observation_input_processors, preprocessing_combiner=observation_preprocessing_combiner, @@ -341,7 +267,7 @@ def __init__(self, _check_action_specs_for_critic_networks(action_spec, action_input_processors, action_preprocessing_combiner) - self._action_encoder = EncodingNetwork( + action_encoder = EncodingNetwork( action_spec, input_preprocessors=action_input_processors, preprocessing_combiner=action_preprocessing_combiner, @@ -349,18 +275,15 @@ def __init__(self, activation=activation, kernel_initializer=kernel_initializer) - self._joint_encoder = EncodingNetwork( - TensorSpec((self._obs_encoder.output_spec.shape[0] + - self._action_encoder.output_spec.shape[0], )), - fc_layer_params=joint_fc_layer_params, - activation=activation, - kernel_initializer=kernel_initializer) - last_kernel_initializer = functools.partial( torch.nn.init.uniform_, a=-0.003, b=0.003) - self._lstm_encoding_net = LSTMEncodingNetwork( - input_tensor_spec=self._joint_encoder.output_spec, + super().__init__( + input_tensor_spec=input_tensor_spec, + output_tensor_spec=output_tensor_spec, + input_preprocessors=(obs_encoder, action_encoder), + preprocessing_combiner=alf.layers.NestConcat(dim=-1), + pre_fc_layer_params=joint_fc_layer_params, hidden_size=lstm_hidden_size, post_fc_layer_params=critic_fc_layer_params, activation=activation, @@ -369,31 +292,10 @@ def __init__(self, last_activation=math_ops.identity, last_kernel_initializer=last_kernel_initializer) - self._output_spec = output_tensor_spec - - def forward(self, inputs, state): - """Computes action-value given an observation. - - Args: - inputs: A tuple of Tensors consistent with ``input_tensor_spec`` - state (nest[tuple]): a nest structure of state tuples ``(h, c)`` - - Returns: - tuple: - - action_value (torch.Tensor): a tensor of the size ``[batch_size]`` - - new_state (nest[tuple]): the updated states + def make_parallel(self, n): + """Create a parallel critic RNN network using ``n`` replicas of ``self``. + The initialized network parameters will be different. + If ``use_naive_parallel_network`` is True, use ``NaiveParallelNetwork`` + to create the parallel network. """ - observations, actions = inputs - - encoded_obs, _ = self._obs_encoder(observations) - encoded_action, _ = self._action_encoder(actions) - joint = torch.cat([encoded_obs, encoded_action], -1) - encoded_joint, _ = self._joint_encoder(joint) - action_value, state = self._lstm_encoding_net(encoded_joint, state) - action_value = action_value.reshape(action_value.shape[0], - *self._output_spec.shape) - return action_value, state - - @property - def state_spec(self): - return self._lstm_encoding_net.state_spec + return super().make_parallel(n, True) diff --git a/alf/networks/critic_networks_test.py b/alf/networks/critic_networks_test.py index cd487975a..cd5854cad 100644 --- a/alf/networks/critic_networks_test.py +++ b/alf/networks/critic_networks_test.py @@ -21,8 +21,9 @@ import alf from alf.tensor_specs import TensorSpec, BoundedTensorSpec -from alf.networks import CriticNetwork, CriticRNNNetwork, ParallelCriticNetwork +from alf.networks import CriticNetwork, CriticRNNNetwork from alf.networks.network import NaiveParallelNetwork +from alf.networks.network_test import test_net_copy from alf.networks.preprocessors import EmbeddingPreprocessor from alf.nest.utils import NestConcat @@ -49,7 +50,7 @@ def _init(self, lstm_hidden_size): return network_ctor, state @parameterized.parameters((100, ), (None, ), ((200, 100), )) - def test_critic(self, lstm_hidden_size): + def test_critic(self, lstm_hidden_size=(200, 100)): obs_spec = TensorSpec((3, 20, 20), torch.float32) action_spec = TensorSpec((5, ), torch.float32) input_spec = (obs_spec, action_spec) @@ -70,6 +71,7 @@ def test_critic(self, lstm_hidden_size): observation_conv_layer_params=observation_conv_layer_params, action_fc_layer_params=action_fc_layer_params, joint_fc_layer_params=joint_fc_layer_params) + test_net_copy(critic_net) value, state = critic_net._test_forward() self.assertEqual(value.shape, (1, )) @@ -84,6 +86,7 @@ def test_critic(self, lstm_hidden_size): # test make_parallel pnet = critic_net.make_parallel(6) + test_net_copy(pnet) if lstm_hidden_size is not None: # shape of state should be [B, n, ...] @@ -92,11 +95,6 @@ def test_critic(self, lstm_hidden_size): state = alf.nest.map_structure( lambda x: x.unsqueeze(1).expand(x.shape[0], 6, x.shape[1]), state) - if lstm_hidden_size is None: - self.assertTrue(isinstance(pnet, ParallelCriticNetwork)) - else: - self.assertTrue(isinstance(pnet, NaiveParallelNetwork)) - value, state = pnet(network_input, state) self.assertEqual(pnet.output_spec, TensorSpec((6, ))) self.assertEqual(value.shape, (1, 6)) @@ -135,7 +133,7 @@ def _train(pnet, name): pnet = critic_net.make_parallel(replicas) _train(pnet, "ParallelCriticNetwork") - pnet = alf.networks.network.NaiveParallelNetwork(critic_net, replicas) + pnet = NaiveParallelNetwork(critic_net, replicas) _train(pnet, "NaiveParallelNetwork") @parameterized.parameters((CriticNetwork, ), (CriticRNNNetwork, )) diff --git a/alf/networks/encoding_networks.py b/alf/networks/encoding_networks.py index 6fde05ce2..f511ba3d6 100644 --- a/alf/networks/encoding_networks.py +++ b/alf/networks/encoding_networks.py @@ -419,15 +419,16 @@ def make_parallel(self, n: int, allow_non_parallel_input=False): """ pnet = super().make_parallel(n) if allow_non_parallel_input: - return _ReplicateInputForParallel(self.input_tensor_spec, n, pnet) + return _ReplicateInputForParallel( + self.input_tensor_spec, n, pnet, name=pnet.name) else: return pnet class _ReplicateInputForParallel(Network): - def __init__(self, input_tensor_spec, n, pnet): + def __init__(self, input_tensor_spec, n, pnet, name): super().__init__( - input_tensor_spec, state_spec=pnet.state_spec, name=pnet.name) + input_tensor_spec, state_spec=pnet.state_spec, name=name) self._input_tensor_spec = input_tensor_spec self._n = n self._pnet = pnet @@ -438,6 +439,11 @@ def forward(self, inputs, state=()): inputs = alf.layers.make_parallel_input(inputs, self._n) return self._pnet(inputs, state) + def copy(self, name=None): + pnet = self._pnet.copy(name) + return _ReplicateInputForParallel(self.input_tensor_spec, self._n, + pnet, pnet.name) + @alf.configurable def ParallelEncodingNetwork(input_tensor_spec, @@ -695,6 +701,7 @@ def make_parallel(self, n: int, allow_non_parallel_input=False): """ pnet = super().make_parallel(n) if allow_non_parallel_input: - return _ReplicateInputForParallel(self.input_tensor_spec, n, pnet) + return _ReplicateInputForParallel( + self.input_tensor_spec, n, pnet, name=pnet.name) else: return pnet diff --git a/alf/networks/encoding_networks_test.py b/alf/networks/encoding_networks_test.py index dcea37dd5..d4ab0cf64 100644 --- a/alf/networks/encoding_networks_test.py +++ b/alf/networks/encoding_networks_test.py @@ -26,6 +26,7 @@ from alf.networks.encoding_networks import EncodingNetwork from alf.networks.encoding_networks import ParallelEncodingNetwork from alf.networks.encoding_networks import LSTMEncodingNetwork +from alf.networks.network_test import test_net_copy from alf.networks.preprocessors import EmbeddingPreprocessor from alf.tensor_specs import TensorSpec, BoundedTensorSpec from alf.utils import common, math_ops @@ -213,6 +214,8 @@ def test_encoding_network_nested_input(self, lstm): input_tensor_spec=input_spec, input_preprocessors=input_preprocessors, preprocessing_combiner=NestConcat()) + test_net_copy(network) + output, _ = network(imgs, state=[(), (torch.zeros((1, 100)), ) * 2]) if lstm: @@ -266,6 +269,8 @@ def _benchmark(pnet, name): (replicas, *output_spec.shape)) pnet = network.make_parallel(replicas, True) + test_net_copy(pnet) + self.assertEqual(len(list(pnet.parameters())), num_layers * 2) _benchmark(pnet, "ParallelEncodingNetwork") self.assertEqual(pnet.name, "parallel_" + network.name) diff --git a/alf/networks/network_test.py b/alf/networks/network_test.py index c83996c39..c4771b83a 100644 --- a/alf/networks/network_test.py +++ b/alf/networks/network_test.py @@ -26,6 +26,18 @@ from alf.networks.network import NaiveParallelNetwork +def test_net_copy(net): + """Test whether net.copy() is correctly implemented""" + new_net = net.copy() + params = dict(net.named_parameters()) + new_params = dict(new_net.named_parameters()) + for n, p in new_params.items(): + assert p.shape == params[n].shape + assert id(p) != id( + params[n]), ("The parameter of the copied parameter " + "is the same parameter of the original network") + + class BaseNetwork(alf.networks.Network): def __init__(self, v1, **kwargs): super().__init__(v1, **kwargs) From 84492921cecf7cf3d35772e7fd5c19b632f6bb3c Mon Sep 17 00:00:00 2001 From: emailweixu Date: Mon, 30 Aug 2021 13:31:51 -0700 Subject: [PATCH 2/2] Address comment --- alf/networks/critic_networks_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alf/networks/critic_networks_test.py b/alf/networks/critic_networks_test.py index cd5854cad..52cd53083 100644 --- a/alf/networks/critic_networks_test.py +++ b/alf/networks/critic_networks_test.py @@ -50,7 +50,7 @@ def _init(self, lstm_hidden_size): return network_ctor, state @parameterized.parameters((100, ), (None, ), ((200, 100), )) - def test_critic(self, lstm_hidden_size=(200, 100)): + def test_critic(self, lstm_hidden_size): obs_spec = TensorSpec((3, 20, 20), torch.float32) action_spec = TensorSpec((5, ), torch.float32) input_spec = (obs_spec, action_spec)