RuntimeError when initializing StigmaBertModel and a quick fix #3
Thanks so much for bringing this error to my attention! The issue seems to be related to the installed versions of torch or transformers. I'll implement your suggested fix shortly, but I want to reproduce the issue first so I can confirm that versioning is the cause. Can you let me know which version of Python, …
Thank you so much for your attention.
# Initialize BERT Model (Note the alternative option for specifying model parameters)
bert_model = StigmaBertModel(model=settings.MODELS["mimic-iv-discharge_clinical-bert"]["tasks"][keyword_category],
I believe this is a minor and common issue that appears when the latest version of transformers is installed. Related post: openai/gpt-2-output-dataset#35
Thanks! With that added information I was able to reproduce the issue, and I have pushed the hotfix you suggested.
Thank you so much for the great work!
For anyone who encounters the following error when initializing StigmaBertModel:
Traceback (most recent call last):
  File "C:\Users\Louis\miniconda3\envs\SL-infer\lib\code.py", line 90, in runcode
    exec(code, self.locals)
  File "", line 2, in
  File "C:\Users\Louis\PycharmProjects\pythonProject5\ehr-stigma-main\stigma\api.py", line 381, in __init__
    _ = self._initialize_model(self._model)
  File "C:\Users\Louis\PycharmProjects\pythonProject5\ehr-stigma-main\stigma\api.py", line 485, in _initialize_model
    _ = self.model_dict["classifier"].load_state_dict(torch.load(f"{model}/model.pth", map_location=torch.device('cpu')))
  File "C:\Users\Louis\miniconda3\envs\SL-infer\lib\site-packages\torch\nn\modules\module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BERTMultitaskClassifier:
    Unexpected key(s) in state_dict: "_bert._bert.embeddings.position_ids".
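To see exactly which keys a strict load would reject, you can compare the checkpoint's keys against the classifier's own state_dict keys. The helper below is a hypothetical, pure-Python sketch: diff_state_dict_keys is not part of stigma or torch. The commented usage assumes it is called inside _initialize_model, where self.model_dict["classifier"] and model are in scope, as the traceback suggests. Based on the traceback, the unexpected list should contain only "_bert._bert.embeddings.position_ids".

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (unexpected, missing) keys, the two sets a strict load_state_dict checks."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    unexpected = sorted(ckpt - model)  # stored in the file but not expected by the module
    missing = sorted(model - ckpt)     # expected by the module but absent from the file
    return unexpected, missing


# Hypothetical usage just before line 485 of stigma/api.py (requires torch):
#   state = torch.load(f"{model}/model.pth", map_location=torch.device("cpu"))
#   unexpected, missing = diff_state_dict_keys(
#       state.keys(), self.model_dict["classifier"].state_dict().keys())
#   print(unexpected, missing)
```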
A quick fix that works for me is to pass strict=False to load_state_dict at line 485 of stigma/api.py, i.e.:
_ = self.model_dict["classifier"].load_state_dict(torch.load(f"{model}/model.pth", map_location=torch.device('cpu')), strict=False)
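Note that strict=False also silences missing keys and any other unexpected keys, so a mismatched checkpoint could load only partially without raising an error. A narrower alternative is to drop only the stale position_ids entry, and only when the installed model does not expect it, then keep strict loading. The stale entry likely appears because newer transformers releases register position_ids as a non-persistent buffer, though that is an assumption here. The sketch below assumes position_ids is the only stale key, as in the traceback. drop_stale_keys and STALE_SUFFIXES are hypothetical names, not part of stigma or torch.

```python
from typing import Dict, Iterable, Mapping, Tuple, TypeVar

V = TypeVar("V")

# Hypothetical default: the key suffix reported as unexpected in the traceback.
STALE_SUFFIXES: Tuple[str, ...] = ("embeddings.position_ids",)


def drop_stale_keys(
    state_dict: Mapping[str, V],
    model_keys: Iterable[str],
    suffixes: Tuple[str, ...] = STALE_SUFFIXES,
) -> Dict[str, V]:
    """Copy state_dict, dropping keys the model does not expect that end in a stale suffix."""
    expected = set(model_keys)
    # Keep every key the model expects; drop an unexpected key only if it matches a
    # stale suffix, so any other mismatch still fails a strict load.
    return {k: v for k, v in state_dict.items() if k in expected or not k.endswith(suffixes)}


# Hypothetical replacement for line 485 of stigma/api.py (requires torch):
#   state = torch.load(f"{model}/model.pth", map_location=torch.device("cpu"))
#   classifier = self.model_dict["classifier"]
#   _ = classifier.load_state_dict(drop_stale_keys(state, classifier.state_dict().keys()))
```

Because the filter consults the model's own keys, the same code should also work with older transformers versions that still expect position_ids in the state_dict.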
I hope this helps! Thanks!