Update SWA protocols
ziatdinovmax committed Jan 1, 2025
1 parent b2457e8 commit fff9371
Showing 2 changed files with 69 additions and 52 deletions.
117 changes: 67 additions & 50 deletions neurobayes/flax_nets/deterministic_nn.py
@@ -15,6 +15,16 @@ class TrainState(train_state.TrainState):
     collected_weights: List
 
 class DeterministicNN:
+    """
+    Args:
+        architecture: a Flax model
+        input_shape: (n_samples, n_features) or (n_samples, *dims, n_channels)
+        loss: type of loss, 'homoskedastic' (default) or 'heteroskedastic'
+        learning_rate: Initial learning rate
+        map: Uses maximum a posteriori approximation
+        sigma: Standard deviation for Gaussian prior
+        swa_config: SWA configuration dictionary
+    """
     def __init__(self,
                  architecture: Type[flax.linen.Module],
                  input_shape: Union[int, Tuple[int]],
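The docstring above is new in this commit. A minimal instantiation sketch; the MLP class, data shapes, and keyword usage below are illustrative assumptions, not code from the repository:

import flax.linen as nn

# Hypothetical toy network; any flax.linen.Module can serve as `architecture`
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(1)(x)

# Assumes DeterministicNN is importable from neurobayes
net = DeterministicNN(
    architecture=MLP,
    input_shape=(8,),   # feature dims only, per the jnp.ones((1, *input_shape)) init call
    learning_rate=1e-3,
    swa_config={'schedule': 'constant', 'start_pct': 0.9},
)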
@@ -35,35 +45,52 @@ def __init__(self,
         key = jax.random.PRNGKey(0)
         params = self.model.init(key, jnp.ones((1, *input_shape)))['params']
 
-        # Default SWA configuration
+        # Default SWA configuration with all required parameters
         self.default_swa_config = {
             'schedule': 'constant',
             'start_pct': 0.95,
+            'swa_lr': learning_rate,  # Same as initial for constant schedule
+            'cycle_length': None
         }
 
         # Update with user config if provided
         self.swa_config = {**self.default_swa_config, **(swa_config or {})}
 
         # Create initial state
+        self.current_lr = learning_rate
+        self.optimizer = optax.adam(learning_rate)
         self.state = TrainState.create(
             apply_fn=self.model.apply,
             params=params,
-            tx=optax.adam(learning_rate),
+            tx=self.optimizer,
             batch_stats=None,
             collected_weights=[]
         )
 
-        self.learning_rate = learning_rate
         self.map = map
         self.sigma = sigma
+        self.params_history = []
+
+    @partial(jax.jit, static_argnums=(0,))
+    def train_step(self, state, inputs, targets):
+        """JIT-compiled training step"""
+        loss, grads = jax.value_and_grad(self.total_loss)(state.params, inputs, targets)
+        state = state.apply_gradients(grads=grads)
+        return state, loss
+
+    def update_learning_rate(self, learning_rate: float):
+        """Update the optimizer with a new learning rate"""
+        if learning_rate != self.current_lr:
+            self.current_lr = learning_rate
+            self.state = self.state.replace(tx=optax.adam(learning_rate))
 
-    def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray,
-              epochs: int, batch_size: int = None) -> None:
+    def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray, epochs: int, batch_size: int = None) -> None:
         X_train, y_train = self.set_data(X_train, y_train)
 
         if batch_size is None or batch_size >= len(X_train):
             batch_size = len(X_train)
 
         # Calculate SWA start epoch
         start_epoch = int(epochs * self.swa_config['start_pct'])
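Since the config merge is plain dict unpacking, user-supplied keys override the defaults key by key while unspecified keys keep their default values. A standalone sketch of the same pattern, using the defaults from this diff:

default_swa_config = {
    'schedule': 'constant',
    'start_pct': 0.95,
    'swa_lr': 1e-3,          # stands in for the initial learning rate
    'cycle_length': None,
}
user_config = {'schedule': 'cyclic', 'start_pct': 0.75, 'cycle_length': 5}

# Later unpacking wins: user values override defaults, the rest stay default
swa_config = {**default_swa_config, **(user_config or {})}
assert swa_config['schedule'] == 'cyclic' and swa_config['swa_lr'] == 1e-3

# SWA starts after start_pct of the run, e.g. epoch 150 of 200 here
start_epoch = int(200 * swa_config['start_pct'])   # -> 150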

@@ -81,61 +108,51 @@ def train(self, X_train: jnp.ndarray, y_train: jnp.ndarray,
         y_batches = split_in_batches(y_train, batch_size)
         num_batches = len(X_batches)
 
-        collected_weights = []
-
         with tqdm(total=epochs, desc="Training Progress", leave=True) as pbar:
-            for epoch in range(1, epochs + 1):
-                # Get learning rate and collection decision for current epoch
-                current_lr, should_collect = lr_schedule(epoch)
+            for epoch in range(epochs):
+                # Get learning rate and collection decision from schedule
+                learning_rate, should_collect = lr_schedule(epoch)
+                # Update learning rate if needed
+                self.update_learning_rate(learning_rate)
 
                 epoch_loss = 0.0
-                for X_batch, y_batch in zip(X_batches, y_batches):
-                    self.state, batch_loss = self.train_step(
-                        self.state, X_batch, y_batch, current_lr
-                    )
+                for i, (X_batch, y_batch) in enumerate(zip(X_batches, y_batches)):
+                    self.state, batch_loss = self.train_step(self.state, X_batch, y_batch)
                     epoch_loss += batch_loss
 
-                # Collect weights if schedule indicates
+                # Collect weights if scheduled
                 if should_collect:
-                    collected_weights.append(self.state.params)
+                    self._store_params(self.state.params)
 
                 # Update progress bar
                 avg_epoch_loss = epoch_loss / num_batches
-                status = f"Epoch {epoch}/{epochs}, Loss: {avg_epoch_loss:.4f}, LR: {current_lr:.6f}"
-                if should_collect:
-                    status += " (collected)"
-                pbar.set_postfix_str(status)
+                pbar.set_postfix_str(
+                    f"Epoch {epoch+1}/{epochs}, "
+                    f"LR: {learning_rate:.6f}, "
+                    f"Loss: {avg_epoch_loss:.4f} "
+                )
                 pbar.update(1)
 
-        # Calculate final averaged weights if we collected any
-        if collected_weights:
-            self.state = self.state.replace(
-                params=self.average_params(collected_weights),
-                collected_weights=collected_weights
-            )
-
-    @partial(jax.jit, static_argnums=(0,))
-    def train_step(self, state, inputs, targets, learning_rate):
-        """Single training step with configurable learning rate"""
-        loss, grads = jax.value_and_grad(self.total_loss)(state.params, inputs, targets)
-
-        # Update optimizer learning rate
-        new_tx = optax.chain(
-            optax.scale_by_adam(),
-            optax.scale_by_schedule(lambda _: learning_rate)
-        )
-        state = state.replace(tx=new_tx)
-
-        # Apply gradients
-        state = state.apply_gradients(grads=grads)
-        return state, loss
-
-    def average_params(self, params_list: List[Dict]) -> Dict:
-        """Average a list of parameter dictionaries"""
-        return jax.tree_util.tree_map(
-            lambda *params: jnp.mean(jnp.stack(params), axis=0),
-            *params_list
-        )
+        # Average collected weights if any were collected
+        if self.params_history:
+            self.state = self.state.replace(params=self.average_params())
+
+    def _store_params(self, params: Dict) -> None:
+        self.params_history.append(params)
+
+    def average_params(self) -> Dict:
+        if not self.params_history:
+            return self.state.params
+
+        # Compute the element-wise average of all stored parameters
+        avg_params = jax.tree_util.tree_map(
+            lambda *param_trees: jnp.mean(jnp.stack(param_trees), axis=0),
+            *self.params_history
+        )
+        return avg_params
+
+    def reset_swa(self):
+        """Reset SWA collections"""
+        self.params_history = []
 
     def mse_loss(self, params: Dict, inputs: jnp.ndarray,
                  targets: jnp.ndarray) -> jnp.ndarray:
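One design note: flax's TrainState stores tx as a static (non-pytree) field, so each update_learning_rate call that builds a fresh optax.adam hands the jitted train_step a new static value and generally retriggers compilation. Under the default constant schedule the learning rate never changes, so no swap occurs; with a linear or cyclic schedule each new learning-rate value costs a recompile during the decay window. A hypothetical end-to-end usage sketch, reusing the `net` object from the earlier sketch (data shapes and epoch count are illustrative):

import numpy as np

X = np.random.randn(256, 8)                  # toy regression inputs
y = np.random.randn(256, 1)                  # toy targets

net.train(X, y, epochs=100, batch_size=32)   # SWA averaging is applied at the end
preds = net.state.apply_fn({'params': net.state.params}, X[:5])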
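The rewritten average_params leans on jax.tree_util.tree_map accepting several pytrees of identical structure and mapping over corresponding leaves. A self-contained sketch of the same averaging on toy snapshots:

import jax
import jax.numpy as jnp

snap1 = {'dense': {'kernel': jnp.array([[1., 2.]]), 'bias': jnp.array([0.])}}
snap2 = {'dense': {'kernel': jnp.array([[3., 4.]]), 'bias': jnp.array([2.])}}
params_history = [snap1, snap2]

# Stack corresponding leaves along a new axis, then average over snapshots
avg = jax.tree_util.tree_map(
    lambda *leaves: jnp.mean(jnp.stack(leaves), axis=0),
    *params_history,
)
# avg['dense']['kernel'] -> [[2., 3.]]; avg['dense']['bias'] -> [1.]

Storing whole parameter snapshots is simple but grows memory linearly with the number of collection epochs; a running (count, mean) pair is the usual constant-memory alternative for SWA.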
4 changes: 2 additions & 2 deletions neurobayes/flax_nets/train_utils.py
@@ -34,7 +34,7 @@ def cyclic_schedule(epoch: int) -> Tuple[float, bool]:
         # Before 75%: constant high learning rate
         return initial_lr, False
 
-    decay_epochs = int(0.1 * total_epochs)  # 10% for decay
+    decay_epochs = int(0.05 * total_epochs)  # 5% for decay
     decay_end = start_epoch + decay_epochs
 
     if epoch < decay_end:
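Only the decay fraction changes in this hunk. With 200 epochs and the 75% start referenced in the comment above, the phase boundaries now work out as in this quick arithmetic sketch:

total_epochs = 200
start_epoch = int(0.75 * total_epochs)    # 150: constant phase ends
decay_epochs = int(0.05 * total_epochs)   # 10 decay epochs (previously 20 at 10%)
decay_end = start_epoch + decay_epochs    # 160: cyclic SWA phase begins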
@@ -68,7 +68,7 @@ def linear_schedule(epoch: int) -> Tuple[float, bool]:
         # Before SWA: high learning rate, no collection
         return initial_lr, False
 
-    decay_epochs = int(0.1 * total_epochs)  # 10% of epochs for decay
+    decay_epochs = int(0.05 * total_epochs)  # 5% of epochs for decay
     decay_end = start_epoch + decay_epochs
 
     if epoch < decay_end:
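For orientation, a schedule of this shape returns an (lr, should_collect) pair per epoch: a constant rate before start_epoch, a short linear decay down to the SWA learning rate, then a collection phase. The interpolation formula and the per-epoch collection in the final phase are assumptions reconstructed from the fragments above, not the file's exact code:

from typing import Callable, Tuple

def make_linear_schedule(initial_lr: float, swa_lr: float, total_epochs: int,
                         start_pct: float = 0.95) -> Callable[[int], Tuple[float, bool]]:
    start_epoch = int(start_pct * total_epochs)
    decay_epochs = int(0.05 * total_epochs)   # 5% of epochs for decay
    decay_end = start_epoch + decay_epochs

    def linear_schedule(epoch: int) -> Tuple[float, bool]:
        if epoch < start_epoch:
            # Before SWA: high learning rate, no collection
            return initial_lr, False
        if epoch < decay_end:
            # Assumed: linear interpolation from initial_lr down to swa_lr
            t = (epoch - start_epoch) / max(decay_epochs, 1)
            return initial_lr + t * (swa_lr - initial_lr), False
        # Assumed: constant swa_lr with collection every epoch
        return swa_lr, True

    return linear_schedule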
