including functions for saving and loading weights for network
EC2 Default User committed Aug 25, 2021
1 parent 7624e45 commit 29a0ab5
Showing 3 changed files with 21 additions and 1 deletion.
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
# NERDA 1.0.0

* The NERDA model class is now equipped with functions for saving and loading the weights of a fine-tuned NERDA network to and from file. See model.save_network() and model.load_network_from_file().

# NERDA 0.9.7

* return confidence scores for predictions of all tokens, e.g. model.predict(x, return_confidence=True).
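As a point of reference for the changelog entry above, here is a minimal usage sketch of the new save/load API. The dataset and transformer setup follows the style of the NERDA README and is an illustrative assumption, not part of this commit; only model.save_network() and model.load_network_from_file() come from the change itself.

from NERDA.datasets import download_dane_data, get_dane_data
from NERDA.models import NERDA

# Illustrative setup (assumption): a NERDA model fine-tuned on the DaNE data.
# Adjust datasets and transformer to your own use case.
download_dane_data()
model = NERDA(dataset_training=get_dane_data('train'),
              dataset_validation=get_dane_data('dev'),
              transformer='bert-base-multilingual-uncased')
model.train()

# New in 1.0.0: persist the fine-tuned network weights to file ...
model.save_network(model_path="model.bin")

# ... and restore them later into an identically configured model
# instead of training again.
model.load_network_from_file(model_path="model.bin")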
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
    name="NERDA",
    version="0.9.7",
    version="1.0.0",
    author="Lars Kjeldgaard, Lukas Christian Nielsen",
    author_email="[email protected]",
    description="A Framework for Finetuning Transformers for Named-Entity Recognition",
16 changes: 16 additions & 0 deletions src/NERDA/models.py
@@ -230,8 +230,24 @@ def load_network_from_file(self, model_path = "model.bin") -> str:
        # TODO: change assert to Raise.
        assert os.path.exists(model_path), "File does not exist. You can download network with download_network()"
        self.network.load_state_dict(torch.load(model_path, map_location = torch.device(self.device)))
        self.network.device = self.device
        return f'Weights for network loaded from {model_path}'

    def save_network(self, model_path:str = "model.bin") -> None:
        """Save Weights of NERDA Network

        Saves weights for a fine-tuned NERDA Network to file.

        Args:
            model_path (str, optional): Path for model file.
                Defaults to "model.bin".

        Returns:
            Nothing. Saves model to file as a side-effect.
        """
        torch.save(self.network.state_dict(), model_path)
        print(f"Network written to file {model_path}")

    def quantize(self):
        """Apply dynamic quantization to increase performance.
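As the added lines above show, save_network() and load_network_from_file() are thin wrappers around PyTorch's state-dict serialization. Below is a minimal standalone sketch of that pattern, using a toy nn.Linear as a stand-in for NERDA's network (an assumption for illustration only).

import os

import torch
import torch.nn as nn

model_path = "model.bin"

# Stand-in network; NERDA's real network is a transformer-based tagger.
net = nn.Linear(10, 2)

# What save_network() does: serialize only the weights (the state dict),
# not the whole Python object.
torch.save(net.state_dict(), model_path)

# What load_network_from_file() does: check that the file exists, then load
# the weights onto the target device via map_location.
device = "cuda" if torch.cuda.is_available() else "cpu"
assert os.path.exists(model_path), "File does not exist."
net.load_state_dict(torch.load(model_path, map_location=torch.device(device)))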
