From ee2ecb3a5e8c2af2c4f721df9029e79e720dd188 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Sun, 6 Oct 2024 13:04:51 -0700
Subject: [PATCH] [api] Fixes QaServingTranslator output format and TokenClassification crash

---
 .../nlp/translator/QaServingTranslator.java        | 16 +++++++++++++---
 .../TokenClassificationServingTranslator.java      |  1 +
 .../ai/djl/ndarray/BytesSupplierImpl.java          |  4 ++--
 .../tokenizers/HuggingFaceTokenizer.java           | 19 +++++++++++++++++++
 .../tokenizers/HuggingFaceTokenizerTest.java       | 12 ++++++++++++
 .../QuestionAnsweringTranslatorTest.java           |  2 +-
 6 files changed, 48 insertions(+), 6 deletions(-)

diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java
index 45e3f59afba..7224a6d4051 100644
--- a/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java
+++ b/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java
@@ -30,14 +30,17 @@
 import com.google.gson.reflect.TypeToken;
 
 import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 
 /**
  * A {@link Translator} that can handle generic question answering {@link Input} and {@link
  * Output}.
  */
 public class QaServingTranslator implements NoBatchifyTranslator<Input, Output> {
 
-    private static final Type LIST_TYPE = new TypeToken>() {}.getType();
+    private static final Type LIST_TYPE = new TypeToken>() {}.getType();
 
     private Translator<QAInput, String> translator;
@@ -116,13 +119,20 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
         Output output = new Output();
         output.addProperty("Content-Type", "application/json");
         if (ctx.getAttachment("batch") != null) {
-            output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
+            List<String> answers = translator.batchProcessOutput(ctx, list);
+            List<Map<String, String>> ret = new ArrayList<>();
+            for (String answer : answers) {
+                ret.add(Collections.singletonMap("answer", answer));
+            }
+            output.add(BytesSupplier.wrapAsJson(ret));
         } else {
             Batchifier batchifier = translator.getBatchifier();
             if (batchifier != null) {
                 list = batchifier.unbatchify(list)[0];
             }
-            output.add(translator.processOutput(ctx, list));
+            String answer = translator.processOutput(ctx, list);
+            Map<String, String> ret = Collections.singletonMap("answer", answer);
+            output.add(BytesSupplier.wrapAsJson(ret));
         }
         return output;
     }
diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java
index a6ab6e8af20..f6b1fd38663 100644
--- a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java
+++ b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java
@@ -56,6 +56,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
         TextPrompt prompt = TextPrompt.parseInput(input);
         if (prompt.isBatch()) {
             ctx.setAttachment("batch", Boolean.TRUE);
+            return translator.batchProcessInput(ctx, prompt.getBatch());
         }
 
         NDList ret = translator.processInput(ctx, prompt.getText());
diff --git a/api/src/main/java/ai/djl/ndarray/BytesSupplierImpl.java b/api/src/main/java/ai/djl/ndarray/BytesSupplierImpl.java
index d3f7198c103..e6af90ebb13 100644
--- a/api/src/main/java/ai/djl/ndarray/BytesSupplierImpl.java
+++ b/api/src/main/java/ai/djl/ndarray/BytesSupplierImpl.java
@@ -40,7 +40,7 @@ class BytesSupplierImpl implements BytesSupplier {
     public byte[] getAsBytes() {
         if (buf == null) {
             if (value == null) {
-                value = JsonUtils.toJson(obj) + '\n';
+                value = JsonUtils.toJson(obj);
             }
             buf = value.getBytes(StandardCharsets.UTF_8);
         }
@@ -52,7 +52,7 @@ public byte[] getAsBytes() {
     public String getAsString() {
         if (value == null) {
             if (obj != null) {
-                value = JsonUtils.toJson(obj) + '\n';
+                value = JsonUtils.toJson(obj);
             } else {
                 value = new String(buf, StandardCharsets.UTF_8);
             }
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
index c10e7c3e6ec..cd6eaf289ca 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
@@ -36,6 +36,7 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 
 /**
@@ -232,6 +233,9 @@ public void close() {
      * @return the {@code Encoding} of the input sentence
      */
     public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) {
+        if (text == null) {
+            throw new NullPointerException("text cannot be null");
+        }
         if (doLowerCase != null) {
             text = text.toLowerCase(doLowerCase);
         }
@@ -261,6 +265,10 @@ public Encoding encode(String text) {
      */
     public Encoding encode(
             String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) {
+        if (text == null || textPair == null) {
+            throw new NullPointerException("text/text_pair cannot be null");
+        }
+
         if (doLowerCase != null) {
             text = text.toLowerCase(doLowerCase);
             textPair = textPair.toLowerCase(doLowerCase);
@@ -322,6 +330,8 @@ public Encoding encode(
             for (int i = 0; i < inputs.length; ++i) {
                 inputs[i] = inputs[i].toLowerCase(doLowerCase);
             }
+        } else if (Arrays.stream(inputs).anyMatch(Objects::isNull)) {
+            throw new NullPointerException("input text cannot be null");
         }
         long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
         return toEncoding(encoding, withOverflowingTokens);
@@ -377,6 +387,8 @@ public Encoding[] batchEncode(
             for (int i = 0; i < inputs.length; ++i) {
                 inputs[i] = inputs[i].toLowerCase(doLowerCase);
             }
+        } else if (Arrays.stream(inputs).anyMatch(Objects::isNull)) {
+            throw new NullPointerException("input text cannot be null");
         }
         long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
         Encoding[] ret = new Encoding[encodings.length];
@@ -418,6 +430,13 @@ public Encoding[] batchEncode(
             for (int i = 0; i < textPair.length; ++i) {
                 textPair[i] = textPair[i].toLowerCase(doLowerCase);
             }
+        } else {
+            if (inputs.keys().stream().anyMatch(Objects::isNull)) {
+                throw new NullPointerException("text pair key cannot be null");
+            }
+            if (inputs.values().stream().anyMatch(Objects::isNull)) {
+                throw new NullPointerException("text pair value cannot be null");
+            }
         }
         long[] encodings =
                 TokenizersLibrary.LIB.batchEncodePair(
diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java
index 5c1e6387313..27fb1056b7c 100644
--- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java
+++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java
@@ -27,6 +27,7 @@
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -102,6 +103,17 @@ public void testTokenizer() throws IOException {
                 Assert.assertEquals(charSpansExpected[i].getStart(), charSpansResult[i].getStart());
                 Assert.assertEquals(charSpansExpected[i].getEnd(), charSpansResult[i].getEnd());
             }
+
+            Assert.assertThrows(() -> tokenizer.encode((String) null));
+            Assert.assertThrows(() -> tokenizer.encode(new String[] {null}));
+            Assert.assertThrows(() -> tokenizer.encode(null, null));
+            Assert.assertThrows(() -> tokenizer.encode("null", null));
+            Assert.assertThrows(() -> tokenizer.batchEncode(new String[] {null}));
+            List<String> empty = Collections.singletonList(null);
+            List<String> some = Collections.singletonList("null");
+
+            Assert.assertThrows(() -> tokenizer.batchEncode(new PairList<>(empty, some)));
+            Assert.assertThrows(() -> tokenizer.batchEncode(new PairList<>(some, empty)));
         }
 
         Map<String, String> options = new ConcurrentHashMap<>();
diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/QuestionAnsweringTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/QuestionAnsweringTranslatorTest.java
index acb283734d1..b67d2a7bcd3 100644
--- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/QuestionAnsweringTranslatorTest.java
+++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/QuestionAnsweringTranslatorTest.java
@@ -102,7 +102,7 @@ public void testQATranslator() throws ModelException, IOException, TranslateExce
             input.add("question", question);
             input.add("paragraph", paragraph);
             Output res = predictor.predict(input);
-            Assert.assertEquals(res.getAsString(0), "December 2004");
+            Assert.assertEquals(res.getAsString(0), "{\"answer\":\"December 2004\"}");
             Assert.assertThrows(
                     "Input data is empty.",