Skip to content

Commit

Permalink
decreased original epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-awadalla committed Sep 22, 2020
1 parent 4addd97 commit 1104323
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
8 changes: 3 additions & 5 deletions libra/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,11 +891,10 @@ def text_classification_query(self, instruction, label_column=None, drop=None,
test_size=0.2,
random_state=49,
learning_rate=1e-2,
epochs=20,
epochs=5,
monitor="val_loss",
batch_size=32,
# max_text_length=200,
max_features=2,
max_text_length=20,
generate_plots=True,
save_model=False,
save_path=os.getcwd()):
Expand Down Expand Up @@ -927,8 +926,7 @@ def text_classification_query(self, instruction, label_column=None, drop=None,
monitor=monitor,
epochs=epochs,
batch_size=batch_size,
# max_text_length=max_text_length,
max_features=max_features,
max_text_length=max_text_length,
generate_plots=generate_plots,
save_model=save_model,
save_path=save_path)
Expand Down
5 changes: 4 additions & 1 deletion libra/query/nlp_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def text_classification_query(self, instruction, drop=None,
test_size=0.2,
random_state=49,
learning_rate=1e-2,
epochs=20,
epochs=5,
monitor="val_loss",
batch_size=32,
max_text_length=20,
Expand All @@ -114,6 +114,9 @@ def text_classification_query(self, instruction, drop=None,
if epochs < 1:
raise Exception("Epoch number is less than 1 (model will not be trained)")

if max_text_length <= 1:
raise Exception("Max text length should be larger than 1")

if batch_size < 1:
raise Exception("Batch size must be equal to or greater than 1")

Expand Down
2 changes: 1 addition & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_captioning(self):
# Tests whether text classification works without errors, and creates a key in models dictionary
@ordered
def test_text_classification(self):
x = client("/Users/anasawadalla/PycharmProjects/libra/tools/data/nlp_data/smallSentimentAnalysis.csv")
x = client("tools/data/nlp_data/smallSentimentAnalysis.csv")
x.text_classification_query("get captions", epochs=1)

# Tests whether name entity recognition query works without errors, and creates a key in models dictionary
Expand Down

0 comments on commit 1104323

Please sign in to comment.