Skip to content

Commit

Permalink
wire hex encoding for tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
arnej27959 committed Feb 1, 2025
1 parent 172a3ae commit 08d2fa6
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 80 deletions.
3 changes: 3 additions & 0 deletions container-search/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -5443,6 +5443,8 @@
"public java.util.Set getSummaryFields()",
"public void setSummaryFields(java.lang.String)",
"public boolean getTensorShortForm()",
"public boolean getTensorHexDense()",
"public java.lang.String getTensorFormat()",
"public void setTensorShortForm(java.lang.String)",
"public void setTensorFormat(java.lang.String)",
"public void setTensorShortForm(boolean)",
Expand Down Expand Up @@ -8067,6 +8069,7 @@
"public java.lang.String toJson()",
"public java.lang.String toJson(boolean)",
"public java.lang.String toJson(boolean, boolean)",
"public java.lang.String toJson(com.yahoo.tensor.serialization.JsonFormat$EncodeOptions)",
"public java.lang.StringBuilder writeJson(java.lang.StringBuilder)",
"public java.lang.Double getDouble(java.lang.String)",
"public com.yahoo.tensor.Tensor getTensor(java.lang.String)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ public class Presentation implements Cloneable {
/** Whether to renders tensors in short form */
private boolean tensorDirectValues = false; // TODO: Flip default on Vespa 9

/** Whether to dense (part of) tensors in hex string form */
private boolean tensorHexDense = false;

/** Set of explicitly requested summary fields, instead of summary classes */
private Set<String> summaryFields = LazySet.newHashSet();

Expand Down Expand Up @@ -186,6 +189,20 @@ public void setSummaryFields(String asString) {
*/
public boolean getTensorShortForm() { return tensorShortForm; }

/** whether dense part of tensors should be represented as a string of hex digits */
public boolean getTensorHexDense() { return tensorHexDense; }

/** the current tensor format, see setTensorFormat() */
public String getTensorFormat() {
String format = "long";
if (tensorShortForm) format = "short";
if (tensorHexDense) format = "hex";
if (tensorDirectValues) {
return (format + "-value");
}
return format;
}

/** @deprecated use setTensorFormat(). */
@Deprecated // TODO: Remove on Vespa 9
public void setTensorShortForm(String value) {
Expand All @@ -199,6 +216,16 @@ public void setTensorShortForm(String value) {
*/
public void setTensorFormat(String value) {
switch (value) {
case "hex" :
tensorHexDense = true;
tensorShortForm = true;
tensorDirectValues = false;
break;
case "hex-value" :
tensorHexDense = true;
tensorShortForm = true;
tensorDirectValues = true;
break;
case "short" :
tensorShortForm = true;
tensorDirectValues = false;
Expand Down Expand Up @@ -254,4 +281,3 @@ public int hashCode() {
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private static Map<CompoundName, GetterSetter> createPropertySetterMap() {
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.FORMAT), GetterSetter.of(query -> query.getPresentation().getFormat(), (query, value) -> query.getPresentation().setFormat(asString(value, ""))));
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.TIMING), GetterSetter.of(query -> query.getPresentation().getTiming(), (query, value) -> query.getPresentation().setTiming(asBoolean(value, true))));
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.SUMMARY_FIELDS), GetterSetter.of(query -> query.getPresentation().getSummaryFields(), (query, value) -> query.getPresentation().setSummaryFields(asString(value, ""))));
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.FORMAT, Presentation.TENSORS), GetterSetter.of(query -> query.getPresentation().getTensorShortForm(), (query, value) -> query.getPresentation().setTensorFormat(asString(value, "short")))); // TODO: Switch default to short-value on Vespa 9);
map.put(CompoundName.fromComponents(Presentation.PRESENTATION, Presentation.FORMAT, Presentation.TENSORS), GetterSetter.of(query -> query.getPresentation().getTensorFormat(), (query, value) -> query.getPresentation().setTensorFormat(asString(value, "short")))); // TODO: Switch default to short-value on Vespa 9);
map.put(Query.HITS, GetterSetter.of(Query::getHits, (query, value) -> query.setHits(asInteger(value,10))));
map.put(Query.OFFSET, GetterSetter.of(Query::getOffset, (query, value) -> query.setOffset(asInteger(value,0))));
map.put(Query.TIMEOUT, GetterSetter.of(Query::getTimeout, (query, value) -> query.setTimeout(value.toString())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,15 @@ static class FieldConsumerSettings {
volatile boolean jsonWsets = true;
volatile boolean jsonMapsAll = true;
volatile boolean jsonWsetsAll = false;
volatile boolean tensorShortForm = true;
volatile boolean tensorDirectValues = false;
volatile JsonFormat.EncodeOptions tensorOptions;
boolean convertDeep() { return (jsonDeepMaps || jsonWsets); }
void init() {
this.debugRendering = false;
this.jsonDeepMaps = true;
this.jsonWsets = true;
this.jsonMapsAll = true;
this.jsonWsetsAll = true;
this.tensorShortForm = true;
this.tensorDirectValues = false;
this.tensorOptions = new JsonFormat.EncodeOptions(true, false, false);
}
void getSettings(Query q) {
if (q == null) {
Expand All @@ -156,9 +154,11 @@ void getSettings(Query q) {
// we may need more fine tuning, but for now use the same query parameters here:
this.jsonMapsAll = props.getBoolean(WRAP_DEEP_MAPS, true);
this.jsonWsetsAll = props.getBoolean(WRAP_WSETS, true);
this.tensorShortForm = q.getPresentation().getTensorShortForm();
this.tensorDirectValues = q.getPresentation().getTensorDirectValues();
}
this.tensorOptions = new JsonFormat.EncodeOptions(
q.getPresentation().getTensorShortForm(),
q.getPresentation().getTensorDirectValues(),
q.getPresentation().getTensorHexDense());
}
}

private volatile FieldConsumerSettings fieldConsumerSettings;
Expand Down Expand Up @@ -560,14 +560,16 @@ public static class FieldConsumer implements Hit.RawUtf8Consumer, TraceRenderer.

/** Invoke this from your constructor when sub-classing {@link FieldConsumer} */
protected FieldConsumer(boolean debugRendering, boolean tensorShortForm, boolean jsonMaps) {
this(null, debugRendering, tensorShortForm, jsonMaps);
this(null, debugRendering, new JsonFormat.EncodeOptions(tensorShortForm, false, false), jsonMaps);
}

private FieldConsumer(JsonGenerator generator, boolean debugRendering, boolean tensorShortForm, boolean jsonMaps) {
private FieldConsumer(JsonGenerator generator, boolean debugRendering,
JsonFormat.EncodeOptions tensorOptions,
boolean jsonMaps) {
this.generator = generator;
this.settings = new FieldConsumerSettings();
this.settings.debugRendering = debugRendering;
this.settings.tensorShortForm = tensorShortForm;
this.settings.tensorOptions = tensorOptions;
this.settings.jsonDeepMaps = jsonMaps;
}

Expand Down Expand Up @@ -768,27 +770,27 @@ protected void renderFieldContents(Object field) throws IOException {
public void accept(Object field) throws IOException {
if (field == null) {
generator().writeNull();
} else if (field instanceof Boolean) {
generator().writeBoolean((Boolean)field);
} else if (field instanceof Number) {
renderNumberField((Number) field);
} else if (field instanceof TreeNode) {
generator().writeTree((TreeNode) field);
} else if (field instanceof Tensor) {
renderTensor(Optional.of((Tensor)field));
} else if (field instanceof FeatureData) {
generator().writeRawValue(((FeatureData)field).toJson(settings.tensorShortForm, settings.tensorDirectValues));
} else if (field instanceof Inspectable) {
renderInspectorDirect(((Inspectable)field).inspect());
} else if (field instanceof JsonProducer) {
generator().writeRawValue(((JsonProducer) field).toJson());
} else if (field instanceof StringFieldValue) {
generator().writeString(((StringFieldValue)field).getString());
} else if (field instanceof TensorFieldValue) {
renderTensor(((TensorFieldValue)field).getTensor());
} else if (field instanceof FieldValue) {
// the null below is the field which has already been written
((FieldValue) field).serialize(null, new JsonWriter(generator));
} else if (field instanceof Boolean bool) {
generator().writeBoolean(bool);
} else if (field instanceof Number num) {
renderNumberField(num);
} else if (field instanceof TreeNode treenode) {
generator().writeTree(treenode);
} else if (field instanceof Tensor t) {
renderTensor(Optional.of(t));
} else if (field instanceof FeatureData featureData) {
generator().writeRawValue(featureData.toJson(settings.tensorOptions));
} else if (field instanceof Inspectable i) {
renderInspectorDirect(i.inspect());
} else if (field instanceof JsonProducer jp) {
generator().writeRawValue(jp.toJson());
} else if (field instanceof StringFieldValue sfv) {
generator().writeString(sfv.getString());
} else if (field instanceof TensorFieldValue tfv) {
renderTensor(tfv.getTensor());
} else if (field instanceof FieldValue fv) {
// the null below is the field name which has already been written
fv.serialize(null, new JsonWriter(generator));
} else {
generator().writeString(field.toString());
}
Expand All @@ -797,27 +799,27 @@ public void accept(Object field) throws IOException {
private void renderNumberField(Number field) throws IOException {
if (field instanceof Integer) {
generator().writeNumber(field.intValue());
} else if (field instanceof Float) {
} else if (field instanceof Float) {
generator().writeNumber(field.floatValue());
} else if (field instanceof Double) {
} else if (field instanceof Double) {
generator().writeNumber(field.doubleValue());
} else if (field instanceof Long) {
generator().writeNumber(field.longValue());
} else if (field instanceof Byte || field instanceof Short) {
generator().writeNumber(field.intValue());
} else if (field instanceof BigInteger) {
generator().writeNumber((BigInteger) field);
} else if (field instanceof BigDecimal) {
generator().writeNumber((BigDecimal) field);
} else if (field instanceof BigInteger bigint) {
generator().writeNumber(bigint);
} else if (field instanceof BigDecimal bigdec) {
generator().writeNumber(bigdec);
} else {
generator().writeNumber(field.doubleValue());
}
}

private void renderTensor(Optional<Tensor> tensor) throws IOException {
generator().writeRawValue(new String(JsonFormat.encode(tensor.orElse(Tensor.Builder.of(TensorType.empty).build()),
settings.tensorShortForm, settings.tensorDirectValues),
StandardCharsets.UTF_8));
var t = tensor.orElse(Tensor.Builder.of(TensorType.empty).build());
byte[] json = JsonFormat.encode(t, settings.tensorOptions);
generator().writeRawValue(new String(json, StandardCharsets.UTF_8));
}

private JsonGenerator generator() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ public class FeatureData implements Inspectable, JsonProducer {
/** The lazily computed feature names of this */
private Set<String> featureNames = null;

/** The lazily computed json form of this */
private String jsonForm = null;

public FeatureData(Inspector encodedValues) {
this.encodedValues = Objects.requireNonNull(encodedValues);
}
Expand All @@ -71,40 +68,43 @@ public Inspector inspect() {

@Override
public String toJson() {
return toJson(false, false);
return toJson(new JsonFormat.EncodeOptions(false, false, false));
}

public String toJson(boolean tensorShortForm) {
return toJson(tensorShortForm, false);
return toJson(new JsonFormat.EncodeOptions(tensorShortForm, false, false));
}

public String toJson(boolean tensorShortForm, boolean tensorDirectValues) {
return writeJson(tensorShortForm, tensorDirectValues, new StringBuilder()).toString();
return toJson(new JsonFormat.EncodeOptions(tensorShortForm, tensorDirectValues, false));
}

public String toJson(JsonFormat.EncodeOptions tensorOptions) {
return writeJson(tensorOptions, new StringBuilder()).toString();
}

@Override
public StringBuilder writeJson(StringBuilder target) {
return JsonRender.render(encodedValues, new Encoder(target, true, false, false));
return writeJson(new JsonFormat.EncodeOptions(false, false, false), target);
}

private StringBuilder writeJson(boolean tensorShortForm, boolean tensorDirectValues, StringBuilder target) {
private StringBuilder writeJson(JsonFormat.EncodeOptions tensorOptions, StringBuilder target) {
if (this == empty) return target.append("{}");
if (jsonForm != null) return target.append(jsonForm);

if (encodedValues != null)
return JsonRender.render(encodedValues, new Encoder(target, true, tensorShortForm, tensorDirectValues));
return JsonRender.render(encodedValues, new Encoder(target, true, tensorOptions));
else
return writeJson(values, tensorShortForm, tensorDirectValues, target);
return writeJson(values, tensorOptions, target);
}

private StringBuilder writeJson(Map<String, Tensor> values, boolean tensorShortForm, boolean tensorDirectValues, StringBuilder target) {
private StringBuilder writeJson(Map<String, Tensor> values, JsonFormat.EncodeOptions tensorOptions, StringBuilder target) {
target.append("{");
for (Map.Entry<String, Tensor> entry : values.entrySet()) {
target.append("\"").append(entry.getKey()).append("\":");
if (entry.getValue().type().rank() == 0) {
target.append(entry.getValue().asDouble());
} else {
byte[] encodedTensor = JsonFormat.encode(entry.getValue(), tensorShortForm, tensorDirectValues);
byte[] encodedTensor = JsonFormat.encode(entry.getValue(), tensorOptions);
target.append(new String(encodedTensor, StandardCharsets.UTF_8));
}
target.append(",");
Expand Down Expand Up @@ -149,7 +149,7 @@ private Tensor decodeTensor(String featureName) {

return switch (featureValue.type()) {
case DOUBLE -> Tensor.from(featureValue.asDouble());
case DATA -> TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(featureValue.asData()));
case DATA -> tensorFromData(featureValue.asData());
default -> throw new IllegalStateException("Unexpected feature value type " + featureValue.type());
};
}
Expand Down Expand Up @@ -192,23 +192,24 @@ public boolean equals(Object other) {
/** A JSON encoder which encodes DATA as a tensor */
private static class Encoder extends JsonRender.StringEncoder {

private final boolean tensorShortForm;
private final boolean tensorDirectValues;
private final JsonFormat.EncodeOptions tensorOptions;

Encoder(StringBuilder out, boolean compact, boolean tensorShortForm, boolean tensorDirectValues) {
Encoder(StringBuilder out, boolean compact, JsonFormat.EncodeOptions tensorOptions) {
super(out, compact);
this.tensorShortForm = tensorShortForm;
this.tensorDirectValues = tensorDirectValues;
this.tensorOptions = tensorOptions;
}

@Override
public void encodeDATA(byte[] value) {
// This could be done more efficiently ...
Tensor tensor = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(value));
byte[] encodedTensor = JsonFormat.encode(tensor, tensorShortForm, tensorDirectValues);
Tensor tensor = tensorFromData(value);
byte[] encodedTensor = JsonFormat.encode(tensor, tensorOptions);
target().append(new String(encodedTensor, StandardCharsets.UTF_8));
}

}

private static Tensor tensorFromData(byte[] value) {
return TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(value));
}
}
Loading

0 comments on commit 08d2fa6

Please sign in to comment.