Skip to content

Commit

Permalink
Refractor createParseField function in mappers for code reusability
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed May 31, 2024
1 parent 180216f commit b6ae1f9
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.function.Supplier;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
Expand Down Expand Up @@ -540,6 +541,36 @@ private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMet
return knnMethodContext.getMethodComponentContext();
}

/**
* Function to create a vector field of type float. If the KNN field type is float we create a float vector field.
* @param array array of floats
* @param fieldType {@link FieldType}
* @return {@link VectorField}
*/
protected List<Field> getFieldsToBeAddedForFloatVector(final float[] array, final FieldType fieldType) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
if (this.stored) {
fields.add(createStoredFieldForFloatVector(name(), array));
}
return fields;
}

/**
* Function to create a vector field of type byte. If the KNN field type is byte we create a float vector field.
* @param array array of bytes
* @param fieldType {@link FieldType}
* @return {@link VectorField}
*/
protected List<Field> getFieldsToBeAddedForByteVector(final byte[] array, final FieldType fieldType) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
if (this.stored) {
fields.add(createStoredFieldForByteVector(name(), array));
}
return fields;
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

Expand All @@ -554,12 +585,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
}
final byte[] array = bytesArrayOptional.get();
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}
context.doc().addAll(getFieldsToBeAddedForByteVector(array, fieldType));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

Expand All @@ -568,11 +594,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
}
final float[] array = floatsArrayOptional.get();
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);
context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}
context.doc().addAll(getFieldsToBeAddedForFloatVector(array, fieldType));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
package org.opensearch.knn.index.mapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.common.Explicit;
Expand Down Expand Up @@ -76,55 +80,46 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
}
}

/**
* Function to create a vector field of type float. If the KNN field type is float we create a float vector field.
* @param array array of floats
* @param fieldType {@link FieldType}
* @return {@link KnnFloatVectorField}
*/
@Override
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);
if (bytesArrayOptional.isEmpty()) {
return;
}
final byte[] array = bytesArrayOptional.get();
spaceType.validateVector(array);
KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);

context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
}
final float[] array = floatsArrayOptional.get();
spaceType.validateVector(array);
KnnVectorField point = new KnnVectorField(name(), array, fieldType);

context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
);
protected List<Field> getFieldsToBeAddedForFloatVector(final float[] array, final FieldType fieldType) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType));

if (hasDocValues && vectorFieldType != null) {
fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType));
}

if (this.stored) {
fieldsToBeAdded.add(createStoredFieldForFloatVector(name(), array));
}
return fieldsToBeAdded;
}

context.path().remove();
/**
* Function to create a vector field of type byte. If the KNN field type is float we create a byte vector field.
* @param array array of bytes
* @param fieldType {@link FieldType}
* @return {@link KnnByteVectorField}
*/
@Override
protected List<Field> getFieldsToBeAddedForByteVector(final byte[] array, final FieldType fieldType) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType));

if (hasDocValues && vectorFieldType != null) {
fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType));
}

if (this.stored) {
fieldsToBeAdded.add(createStoredFieldForByteVector(name(), array));
}
return fieldsToBeAdded;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.VectorEncoding;
Expand Down

0 comments on commit b6ae1f9

Please sign in to comment.