diff --git a/prediction.py b/prediction.py index f3b38de..7e0975c 100644 --- a/prediction.py +++ b/prediction.py @@ -52,13 +52,18 @@ def main(): # Do forward pass for sentences provided through args model_output = model(args.sent1, args.sent2) - # Get the predicted label + # Get the predicted label + print("Model Output:") + print(model_output) y_hat = get_predicted_labels(model_output).item() print(f"Predicted class label: {mapIntToStr(y_hat)}") - print(model_output.shape) - print(model_output) - - + else: + # Do forward pass for sentences provided through args + model_output = model(args.sent1, args.sent2) + sentence1_embedding, sentence2_embedding = model_output + print(f"Shape of one sentence embedding: {sentence1_embedding.shape}") + print("Sentence Embedding:") + print(sentence2_embedding) if __name__ == "__main__": main()