diff --git a/libra/queries.py b/libra/queries.py index 0f5d0034..055441b5 100644 --- a/libra/queries.py +++ b/libra/queries.py @@ -891,11 +891,10 @@ def text_classification_query(self, instruction, label_column=None, drop=None, test_size=0.2, random_state=49, learning_rate=1e-2, - epochs=20, + epochs=5, monitor="val_loss", batch_size=32, - # max_text_length=200, - max_features=2, + max_text_length=20, generate_plots=True, save_model=False, save_path=os.getcwd()): @@ -927,8 +926,7 @@ def text_classification_query(self, instruction, label_column=None, drop=None, monitor=monitor, epochs=epochs, batch_size=batch_size, - # max_text_length=max_text_length, - max_features=max_features, + max_text_length=max_text_length, generate_plots=generate_plots, save_model=save_model, save_path=save_path) diff --git a/libra/query/nlp_queries.py b/libra/query/nlp_queries.py index 8aad5db9..2d6f0388 100644 --- a/libra/query/nlp_queries.py +++ b/libra/query/nlp_queries.py @@ -90,7 +90,7 @@ def text_classification_query(self, instruction, drop=None, test_size=0.2, random_state=49, learning_rate=1e-2, - epochs=20, + epochs=5, monitor="val_loss", batch_size=32, max_text_length=20, @@ -114,6 +114,9 @@ def text_classification_query(self, instruction, drop=None, if epochs < 1: raise Exception("Epoch number is less than 1 (model will not be trained)") + if max_text_length <= 1: + raise Exception("Max text length should be larger than 1") + if batch_size < 1: raise Exception("Batch size must be equal to or greater than 1") diff --git a/tests/tests.py b/tests/tests.py index 6a42c76c..6c91fe9b 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -144,7 +144,7 @@ def test_captioning(self): # Tests whether text classification works without errors, and creates a key in models dictionary @ordered def test_text_classification(self): - x = client("/Users/anasawadalla/PycharmProjects/libra/tools/data/nlp_data/smallSentimentAnalysis.csv") + x = client("tools/data/nlp_data/smallSentimentAnalysis.csv") x.text_classification_query("get captions", epochs=1) # Tests whether name entity recognition query works without errors, and creates a key in models dictionary