Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Fixes QaServingTranslator output format #3500

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@
import com.google.gson.reflect.TypeToken;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
* A {@link Translator} that can handle generic question answering {@link Input} and {@link Output}.
*/
public class QaServingTranslator implements NoBatchifyTranslator<Input, Output> {

private static final Type LIST_TYPE = new TypeToken<List<String>>() {}.getType();
private static final Type LIST_TYPE = new TypeToken<List<QAInput>>() {}.getType();

private Translator<QAInput, String> translator;

Expand Down Expand Up @@ -116,13 +119,20 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (ctx.getAttachment("batch") != null) {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
List<String> answers = translator.batchProcessOutput(ctx, list);
List<Map<String, String>> ret = new ArrayList<>();
for (String answer : answers) {
ret.add(Collections.singletonMap("answer", answer));
}
output.add(BytesSupplier.wrapAsJson(ret));
} else {
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
output.add(translator.processOutput(ctx, list));
String answer = translator.processOutput(ctx, list);
Map<String, String> ret = Collections.singletonMap("answer", answer);
output.add(BytesSupplier.wrapAsJson(ret));
}
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
TextPrompt prompt = TextPrompt.parseInput(input);
if (prompt.isBatch()) {
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, prompt.getBatch());
}

NDList ret = translator.processInput(ctx, prompt.getText());
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/ndarray/BytesSupplierImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class BytesSupplierImpl implements BytesSupplier {
public byte[] getAsBytes() {
if (buf == null) {
if (value == null) {
value = JsonUtils.toJson(obj) + '\n';
value = JsonUtils.toJson(obj);
}
buf = value.getBytes(StandardCharsets.UTF_8);
}
Expand All @@ -52,7 +52,7 @@ public byte[] getAsBytes() {
public String getAsString() {
if (value == null) {
if (obj != null) {
value = JsonUtils.toJson(obj) + '\n';
value = JsonUtils.toJson(obj);
} else {
value = new String(buf, StandardCharsets.UTF_8);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
Expand Down Expand Up @@ -232,6 +233,9 @@ public void close() {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (text == null) {
throw new NullPointerException("text cannot be null");
}
if (doLowerCase != null) {
text = text.toLowerCase(doLowerCase);
}
Expand Down Expand Up @@ -261,6 +265,10 @@ public Encoding encode(String text) {
*/
public Encoding encode(
String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) {
if (text == null || textPair == null) {
throw new NullPointerException("text/text_pair cannot be null");
}

if (doLowerCase != null) {
text = text.toLowerCase(doLowerCase);
textPair = textPair.toLowerCase(doLowerCase);
Expand Down Expand Up @@ -322,6 +330,8 @@ public Encoding encode(
for (int i = 0; i < inputs.length; ++i) {
inputs[i] = inputs[i].toLowerCase(doLowerCase);
}
} else if (Arrays.stream(inputs).anyMatch(Objects::isNull)) {
throw new NullPointerException("input text cannot be null");
}
long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
return toEncoding(encoding, withOverflowingTokens);
Expand Down Expand Up @@ -377,6 +387,8 @@ public Encoding[] batchEncode(
for (int i = 0; i < inputs.length; ++i) {
inputs[i] = inputs[i].toLowerCase(doLowerCase);
}
} else if (Arrays.stream(inputs).anyMatch(Objects::isNull)) {
throw new NullPointerException("input text cannot be null");
}
long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
Expand Down Expand Up @@ -418,6 +430,13 @@ public Encoding[] batchEncode(
for (int i = 0; i < textPair.length; ++i) {
textPair[i] = textPair[i].toLowerCase(doLowerCase);
}
} else {
if (inputs.keys().stream().anyMatch(Objects::isNull)) {
throw new NullPointerException("text pair key cannot be null");
}
if (inputs.values().stream().anyMatch(Objects::isNull)) {
throw new NullPointerException("text pair value cannot be null");
}
}
long[] encodings =
TokenizersLibrary.LIB.batchEncodePair(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -102,6 +103,17 @@ public void testTokenizer() throws IOException {
Assert.assertEquals(charSpansExpected[i].getStart(), charSpansResult[i].getStart());
Assert.assertEquals(charSpansExpected[i].getEnd(), charSpansResult[i].getEnd());
}

Assert.assertThrows(() -> tokenizer.encode((String) null));
Assert.assertThrows(() -> tokenizer.encode(new String[] {null}));
Assert.assertThrows(() -> tokenizer.encode(null, null));
Assert.assertThrows(() -> tokenizer.encode("null", null));
Assert.assertThrows(() -> tokenizer.batchEncode(new String[] {null}));
List<String> empty = Collections.singletonList(null);
List<String> some = Collections.singletonList("null");

Assert.assertThrows(() -> tokenizer.batchEncode(new PairList<>(empty, some)));
Assert.assertThrows(() -> tokenizer.batchEncode(new PairList<>(some, empty)));
}

Map<String, String> options = new ConcurrentHashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public void testQATranslator() throws ModelException, IOException, TranslateExce
input.add("question", question);
input.add("paragraph", paragraph);
Output res = predictor.predict(input);
Assert.assertEquals(res.getAsString(0), "December 2004");
Assert.assertEquals(res.getAsString(0), "{\"answer\":\"December 2004\"}");

Assert.assertThrows(
"Input data is empty.",
Expand Down
Loading