Add use_pretrained attribute for AutoTransformers #3498
Conversation
Will add a test

Okay this is probably wrong, going to close and discuss this first.
```diff
@@ -3092,6 +3097,10 @@ def module_name():
         description=ENCODER_METADATA["AutoTransformer"]["type"].long_description,
     )

+    # Always set this to True since we always want to use the pretrained weights
+    # We don't currently support training from scratch for AutoTransformers
+    use_pretrained: bool = True
```
Let's make this a property so the user could never modify it.
```python
@property
def use_pretrained(self) -> bool:
    return True
```
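To illustrate the suggestion, here is a minimal sketch (the class name is hypothetical, not Ludwig's actual schema class) showing why exposing `use_pretrained` as a getter-only property prevents users from modifying it:

```python
class AutoTransformerConfig:
    """Hypothetical minimal config class illustrating a read-only attribute."""

    @property
    def use_pretrained(self) -> bool:
        # Always True: training AutoTransformers from scratch is unsupported.
        return True


config = AutoTransformerConfig()
print(config.use_pretrained)  # True

try:
    config.use_pretrained = False  # property has no setter
except AttributeError:
    print("use_pretrained is read-only")
```

Because the property defines no setter, any assignment raises `AttributeError`, so the invariant cannot be broken from user code.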
```diff
@@ -292,3 +289,42 @@ def test_tfidf_encoder(vocab_size: int):
     inputs = torch.randint(2, (batch_size, sequence_length)).to(DEVICE)
     outputs = text_encoder(inputs)
     assert outputs[ENCODER_OUTPUT].shape[1:] == text_encoder.output_shape
+
+
+def test_hf_auto_transformer_use_pretrained(tmpdir, csv_filename):
```
nit: The `True` case would be tested implicitly elsewhere, correct? If not, maybe we could parametrize the test with both cases.
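A sketch of what the parametrized version could look like (the `encoder_config` helper is hypothetical, standing in for the real `text_feature(...)` harness used in the test suite):

```python
import pytest


def encoder_config(use_pretrained: bool) -> dict:
    # Hypothetical helper: the real test would build a text feature and
    # run training/encoding end to end.
    return {"type": "auto_transformer", "use_pretrained": use_pretrained}


@pytest.mark.parametrize("use_pretrained", [True, False])
def test_hf_auto_transformer_use_pretrained(use_pretrained: bool):
    config = encoder_config(use_pretrained)
    assert config["use_pretrained"] is use_pretrained
```

With `pytest.mark.parametrize`, pytest runs the test once per value, so both the pretrained and from-scratch paths get exercised by a single test function.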
```python
text_feature(
    encoder={
        "type": "auto_transformer",
        "use_pretrained": False,
```
Ideally this should be an error if we were more strict with our config validation rules. We should instead just leave this out of the config.
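A stricter validation rule could reject the key outright. A minimal sketch of that idea (not Ludwig's actual validation code; the function and allowed-key set are hypothetical):

```python
def validate_encoder_config(config: dict, allowed_keys: set) -> None:
    # Hypothetical strict check: fail on keys the schema does not declare,
    # instead of silently accepting and ignoring them.
    extra = set(config) - allowed_keys
    if extra:
        raise ValueError(f"Unknown encoder keys: {sorted(extra)}")


ALLOWED = {"type", "pretrained_model_name_or_path"}

validate_encoder_config({"type": "auto_transformer"}, ALLOWED)  # passes

try:
    validate_encoder_config(
        {"type": "auto_transformer", "use_pretrained": False}, ALLOWED
    )
except ValueError as err:
    print(err)  # Unknown encoder keys: ['use_pretrained']
```

This mirrors `additionalProperties: false` in JSON-schema-style validation: unknown fields become hard errors rather than silent no-ops.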
Fixes the following error when trying to train a custom transformer model from HF using a config that looks like this: