diff --git a/README.md b/README.md index 0402221..2ff9bbb 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,9 @@ from transformers_keras import Bert # Used to predict directly model = Bert.from_pretrained('/path/to/pretrained/bert/model') # segment_ids and mask inputs are optional -model.predict(inputs=(input_ids, segment_ids, mask)) +model.predict((input_ids, segment_ids, mask)) +# or +model(inputs=(input_ids, segment_ids, mask)) # Used to fine-tuning def build_bert_classify_model(trainable=True): @@ -70,7 +72,9 @@ from transformers_keras import Albert # Used to predict directly model = Bert.from_pretrained('/path/to/pretrained/albert/model') # segment_ids and mask inputs are optional -model(model.dummy_inputs()) +model.predict((input_ids, segment_ids, mask)) +# or +model(inputs=(input_ids, segment_ids, mask)) # Used to fine-tuning def build_albert_classify_model(trainable=True):