diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 7558a4b9ba..693d19f898 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -141,7 +141,6 @@ def __init__( self.use_cache = use_cache self.init_config = init_config self.fc_type = fc_type - self.bias = None if 'name' in kwargs: del kwargs['name'] if 'loss_fn' in kwargs: @@ -232,4 +231,4 @@ def _validate_config(self): if self.ffn_config['ffn_type'] == 'mptmlp': self.ffn_config['fc_type'] = self.fc_type elif self.ffn_config['ffn_type'] == 'te_ln_mlp': - self.bias = not self.no_bias + self.ffn_config['bias'] = not self.no_bias