Skip to content

Commit

Permalink
Use correct dtype for transformer trainer init
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 4, 2025
1 parent 88fbb08 commit 8024315
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion neurobayes/flax_nets/deterministic_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,21 @@ def __init__(self,

input_shape = (input_shape,) if isinstance(input_shape, int) else input_shape
self.model = architecture

is_transformer = any(base.__name__.lower().find('transformer') >= 0
for base in architecture.__mro__)
input_dtype = jnp.int32 if is_transformer else jnp.float32

if loss not in ['homoskedastic', 'heteroskedastic', 'classification']:
raise ValueError("Select between 'homoskedastic', 'heteroskedastic', or 'classification' loss")
self.loss = loss

# Initialize model
key = jax.random.PRNGKey(0)
params = self.model.init(key, jnp.ones((1, *input_shape)))['params']
params = self.model.init(
key,
jnp.ones((1, *input_shape), dtype=input_dtype)
)['params']

# Default SWA configuration with all required parameters
self.default_swa_config = {
Expand Down

0 comments on commit 8024315

Please sign in to comment.