diff --git a/NEWS.md b/NEWS.md index e2a60e0..217e8f5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# NERDA 1.0.0 + +* NERDA model class is now equipped with functions for saving (loading) weights for a fine-tuned NERDA Network to (from) file. See functions model.save_network() and model.load_network_from_file() + # NERDA 0.9.7 * return confidence scores for predictions of all tokens, e.g. model.predict(x, return_confidence=True). diff --git a/setup.py b/setup.py index 8dbac5f..78ab795 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="NERDA", - version="0.9.7", + version="1.0.0", author="Lars Kjeldgaard, Lukas Christian Nielsen", author_email="lars.kjeldgaard@eb.dk", description="A Framework for Finetuning Transformers for Named-Entity Recognition", diff --git a/src/NERDA/models.py b/src/NERDA/models.py index bbd8537..00dc3c4 100644 --- a/src/NERDA/models.py +++ b/src/NERDA/models.py @@ -230,8 +230,24 @@ def load_network_from_file(self, model_path = "model.bin") -> str: # TODO: change assert to Raise. assert os.path.exists(model_path), "File does not exist. You can download network with download_network()" self.network.load_state_dict(torch.load(model_path, map_location = torch.device(self.device))) + self.network.device = self.device return f'Weights for network loaded from {model_path}' + def save_network(self, model_path:str = "model.bin") -> None: + """Save Weights of NERDA Network + + Saves weights for a fine-tuned NERDA Network to file. + + Args: + model_path (str, optional): Path for model file. + Defaults to "model.bin". + + Returns: + Nothing. Saves model to file as a side-effect. + """ + torch.save(self.network.state_dict(), model_path) + print(f"Network written to file {model_path}") + def quantize(self): """Apply dynamic quantization to increase performance.