diff --git a/.gitignore b/.gitignore index 552c6d13..5de52498 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ venv.bak/ dmypy.json # Pyre type checker -.pyre/ \ No newline at end of file +.pyre/ +.DS_Store \ No newline at end of file diff --git a/libra/plotting/generate_plots.py b/libra/plotting/generate_plots.py index 8ddc99fb..2871691d 100644 --- a/libra/plotting/generate_plots.py +++ b/libra/plotting/generate_plots.py @@ -140,7 +140,7 @@ def generate_regression_plots(history, data, label): return plots, plot_names -def generate_classification_plots(history, data, label, model, X_test, y_test): +def generate_classification_plots(history): ''' plotting function that generates classification plots diff --git a/libra/queries.py b/libra/queries.py index 20a89e43..e6cb346c 100644 --- a/libra/queries.py +++ b/libra/queries.py @@ -760,7 +760,8 @@ def convolutional_query(self, width=None, show_feature_map=False, save_as_tfjs=None, - save_as_tflite=None): + save_as_tflite=None, + generate_plots=None): ''' Calls the body of the convolutional neural network query which is located in the feedforward.py file :param instruction: The objective that you want to model (str). @@ -798,7 +799,8 @@ def convolutional_query(self, height=height, width=width, save_as_tfjs=save_as_tfjs, - save_as_tflite=save_as_tflite) + save_as_tflite=save_as_tflite, + generate_plots=generate_plots) if show_feature_map: model = self.models["convolutional_NN"]["model"] diff --git a/libra/query/feedforward_nn.py b/libra/query/feedforward_nn.py index fd093bea..d71e2b92 100644 --- a/libra/query/feedforward_nn.py +++ b/libra/query/feedforward_nn.py @@ -458,7 +458,7 @@ def classification_ann(instruction, plots = {} if generate_plots: plots = generate_classification_plots( - models[len(models) - 1], data, y, model, X_test, y_test) + models[len(models) - 1]) if save_model: save(final_model, save_model, save_path) @@ -508,7 +508,8 @@ def convolutional(instruction=None, height=None, width=None, save_as_tfjs=None, - save_as_tflite=None): + save_as_tflite=None, + generate_plots=True): ''' Body of the convolutional function used that is called in the neural network query if the data is presented in images. @@ -787,6 +788,26 @@ def convolutional(instruction=None, epochs=epochs, verbose=verbose) + models = [] + losses = [] + accuracies = [] + model_data = [] + + model_data.append(model) + models.append(history) + + losses.append(history.history["val_loss"] + [len(history.history["val_loss"]) - 1]) + accuracies.append(history.history['val_accuracy'] + [len(history.history['val_accuracy']) - 1]) + + # final_model = model_data[accuracies.index(max(accuracies))] + # final_hist = models[accuracies.index(max(accuracies))] + + plots = {} + if generate_plots: + plots = generate_classification_plots(models[len(models) - 1]) + logger('->', 'Final training accuracy: {}'.format(history.history['accuracy'][len(history.history['accuracy']) - 1])) logger('->', 'Final validation accuracy: {}'.format( @@ -816,6 +837,7 @@ def convolutional(instruction=None, 'data': {'train': X_train, 'test': X_test}, 'shape': input_shape, "model": model, + "plots": plots, 'losses': { 'training_loss': history.history['loss'], 'val_loss': history.history['val_loss']}, diff --git a/libra/query/nlp_queries.py b/libra/query/nlp_queries.py index 2d6f0388..640b7546 100644 --- a/libra/query/nlp_queries.py +++ b/libra/query/nlp_queries.py @@ -207,8 +207,7 @@ def text_classification_query(self, instruction, drop=None, # generates appropriate classification plots by feeding all # information logger("Generating plots") - plots = generate_classification_plots( - history, X, Y, model, X_test, y_test) + plots = generate_classification_plots(history) if save_model: save(model, save_model, save_path=save_path) diff --git a/libra/query/supplementaries.py b/libra/query/supplementaries.py index fe38cac4..03a44fb7 100644 --- a/libra/query/supplementaries.py +++ b/libra/query/supplementaries.py @@ -177,8 +177,7 @@ def tune_helper( logger("->", 'Best Hyperparameters Found: {}'.format(returned_pms.values)) if generate_plots: logger("Generating updated plots") - plots = generate_classification_plots( - history, data, target_column, returned_model, X_test, y_test) + plots = generate_classification_plots(history) logger("Re-stored model under 'classification_ANN' key")