Ensure residual connections are in PartialBTNN
ziatdinovmax committed Feb 5, 2025
1 parent 107ac2a · commit bced2ef
Showing 1 changed file with 21 additions and 11 deletions.
neurobayes/models/partial_btnn.py: 32 changes (21 additions & 11 deletions)
@@ -73,7 +73,6 @@ def prior(name, shape):
layer_name = config['layer_name']
layer_type = config['layer_type']

# Embeddings have special handling for inputs and outputs
if layer_type == "embedding":
layer = EmbedModule(
features=config['features'],
@@ -95,19 +94,16 @@ def prior(name, shape):
else: # PosEmbed
pos_embedding = embedding
current_input = token_embedding + pos_embedding

# Layer norms are always deterministic
elif layer_type == "layernorm":
layer = LayerNormModule(layer_name=layer_name)
params = {"params": {layer_name: pretrained_priors[layer_name]}}
current_input = layer.apply(params, current_input)

# Attention needs block_idx for naming
elif layer_type == "attention":
# Save input for residual
residual = current_input

block_idx = int(layer_name.split('_')[0][5:])
layer = TransformerAttentionModule(
num_heads=config['num_heads'],
qkv_features=config['qkv_features'],
dropout_rate=config.get('dropout_rate', 0.1),
layer_name="Attention",
block_idx=block_idx
)
@@ -119,8 +115,19 @@ def prior(name, shape):
params = {"params": {f"Block{block_idx}_Attention": pretrained_priors[layer_name]}}
current_input = layer.apply(params, current_input, enable_dropout=False)

# MLP/Dense layers
else:
# Add residual after attention
current_input = current_input + residual

elif layer_type == "layernorm":
layer = LayerNormModule(layer_name=layer_name)
params = {"params": {layer_name: pretrained_priors[layer_name]}}
current_input = layer.apply(params, current_input)

# Save residual after first layer norm in each block
if layer_name.endswith('LayerNorm1'):
residual = current_input

else: # fc layers
layer = MLPLayerModule(
features=config['features'],
activation=config.get('activation'),
@@ -137,8 +144,11 @@ def prior(name, shape):
else:
params = {"params": {layer_name: pretrained_priors[layer_name]}}
current_input = layer.apply(params, current_input, enable_dropout=False)

# Add residual after second dense layer in each block
if layer_name.endswith('dense2'):
current_input = current_input + residual

# Output processing
current_input = jnp.mean(current_input, axis=1)

if self.is_regression:
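
For readers tracing the control flow above, the following is a minimal stand-alone sketch (not code from this repository) of the residual wiring the commit appears to implement. It assumes each block's layers are visited in the order attention → LayerNorm1 → dense1 → dense2 → LayerNorm2, which is consistent with where the diff saves and re-adds residual; the callables, names, and shapes below are hypothetical stand-ins rather than NeuroBayes APIs.

import jax.numpy as jnp

def transformer_block(x, attn, norm1, dense1, dense2, norm2):
    # Illustrative sketch only; the layer order is an assumption, not taken
    # from partial_btnn.py.
    residual = x            # save input for residual (attention branch)
    x = attn(x)
    x = x + residual        # add residual after attention
    x = norm1(x)
    residual = x            # save residual after the first layer norm
    x = dense1(x)
    x = dense2(x)
    x = x + residual        # add residual after the second dense layer
    return norm2(x)

# Quick shape check with identity stand-ins for the sub-layers.
x = jnp.ones((2, 16, 32))   # (batch, tokens, features) -- hypothetical sizes
identity = lambda t: t
print(transformer_block(x, identity, identity, identity, identity, identity).shape)  # (2, 16, 32)

Two smaller points worth noting: the block_idx parsing, int(layer_name.split('_')[0][5:]), appears to assume layer names of the form "Block{i}_...", so a name beginning with "Block0" yields index 0, matching the f"Block{block_idx}_Attention" key used to look up the pretrained priors; and jnp.mean(current_input, axis=1) after the loop pools over what is presumably the token axis before the regression or classification head.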