From bced2efe4ab3c1d8f4a5785fd57a5a684b4347ab Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov
Date: Wed, 5 Feb 2025 15:35:40 -0800
Subject: [PATCH] Ensure residual connections are in PartialBTNN

---
 neurobayes/models/partial_btnn.py | 32 +++++++++++++++++++++-----------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/neurobayes/models/partial_btnn.py b/neurobayes/models/partial_btnn.py
index 65d1187..628475c 100644
--- a/neurobayes/models/partial_btnn.py
+++ b/neurobayes/models/partial_btnn.py
@@ -73,7 +73,6 @@ def prior(name, shape):
             layer_name = config['layer_name']
             layer_type = config['layer_type']
 
-            # Embeddings have special handling for inputs and outputs
             if layer_type == "embedding":
                 layer = EmbedModule(
                     features=config['features'],
@@ -95,19 +94,16 @@ def prior(name, shape):
                 else:  # PosEmbed
                     pos_embedding = embedding
                     current_input = token_embedding + pos_embedding
-
-            # Layer norms are always deterministic
-            elif layer_type == "layernorm":
-                layer = LayerNormModule(layer_name=layer_name)
-                params = {"params": {layer_name: pretrained_priors[layer_name]}}
-                current_input = layer.apply(params, current_input)
 
-            # Attention needs block_idx for naming
             elif layer_type == "attention":
+                # Save input for residual
+                residual = current_input
+
                 block_idx = int(layer_name.split('_')[0][5:])
                 layer = TransformerAttentionModule(
                     num_heads=config['num_heads'],
                     qkv_features=config['qkv_features'],
+                    dropout_rate=config.get('dropout_rate', 0.1),
                     layer_name="Attention",
                     block_idx=block_idx
                 )
@@ -119,8 +115,19 @@ def prior(name, shape):
                     params = {"params": {f"Block{block_idx}_Attention": pretrained_priors[layer_name]}}
                     current_input = layer.apply(params, current_input, enable_dropout=False)
 
-            # MLP/Dense layers
-            else:
+                # Add residual after attention
+                current_input = current_input + residual
+
+            elif layer_type == "layernorm":
+                layer = LayerNormModule(layer_name=layer_name)
+                params = {"params": {layer_name: pretrained_priors[layer_name]}}
+                current_input = layer.apply(params, current_input)
+
+                # Save residual after first layer norm in each block
+                if layer_name.endswith('LayerNorm1'):
+                    residual = current_input
+
+            else:  # fc layers
                 layer = MLPLayerModule(
                     features=config['features'],
                     activation=config.get('activation'),
@@ -137,8 +144,11 @@ def prior(name, shape):
                 else:
                     params = {"params": {layer_name: pretrained_priors[layer_name]}}
                     current_input = layer.apply(params, current_input, enable_dropout=False)
+
+                # Add residual after second dense layer in each block
+                if layer_name.endswith('dense2'):
+                    current_input = current_input + residual
 
-        # Output processing
         current_input = jnp.mean(current_input, axis=1)
 
         if self.is_regression:
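
Note for reviewers (not part of the patch): the loop above wires a post-layer-norm
residual pattern into each transformer block: the block input is saved on entry to
the attention branch and added back after attention, and the output of the first
layer norm is saved and added back after the second dense layer. The following is
a minimal stand-alone sketch of that wiring only; the function name and the
callables passed in (attention, layer_norm1, dense1, dense2) are hypothetical
stand-ins for the module applications performed in the loop, not NeuroBayes API.

import jax.numpy as jnp

def block_with_residuals(x, attention, layer_norm1, dense1, dense2):
    # Save input for residual (attention branch)
    residual = x
    x = attention(x)
    # Add residual after attention
    x = x + residual
    x = layer_norm1(x)
    # Save residual after first layer norm in the block
    residual = x
    x = dense1(x)
    x = dense2(x)
    # Add residual after second dense layer in the block
    return x + residual

# Exercise the wiring with identity stand-ins:
x = jnp.ones((2, 8, 16))
y = block_with_residuals(x, lambda t: t, lambda t: t, lambda t: t, lambda t: t)
assert y.shape == (2, 8, 16)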