From 849027bf2769e5d05cfd9bee65d15ca2bbacf799 Mon Sep 17 00:00:00 2001 From: Shinsuke Sugaya Date: Sat, 22 Jun 2024 09:44:13 +0900 Subject: [PATCH] add query support --- .../fess/multimodal/MultiModalConstants.java | 11 + .../fess/multimodal/client/CasClient.java | 21 ++ .../crawler/extractor/CasExtractor.java | 8 +- .../index/query/KNNQueryBuilder.java | 190 +++++++++++++++ .../multimodal/ingest/EmbeddingIngester.java | 24 +- .../query/MultiModalPhraseQueryCommand.java | 60 +++++ .../query/MultiModalQueryBuilder.java | 64 +++++ .../query/MultiModalTermQueryCommand.java | 59 +++++ .../rank/fusion/MultiModalSearcher.java | 224 ++++++++++++++++++ .../fess_query+phraseQueryCommand.xml | 9 + .../resources/fess_query+termQueryCommand.xml | 9 + src/main/resources/fess_rankfusion++.xml | 8 + .../fess/multimodal/client/CasClientTest.java | 11 + .../ingest/EmbeddingIngesterTest.java | 12 +- 14 files changed, 688 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/codelibs/fess/multimodal/index/query/KNNQueryBuilder.java create mode 100644 src/main/java/org/codelibs/fess/multimodal/query/MultiModalPhraseQueryCommand.java create mode 100644 src/main/java/org/codelibs/fess/multimodal/query/MultiModalQueryBuilder.java create mode 100644 src/main/java/org/codelibs/fess/multimodal/query/MultiModalTermQueryCommand.java create mode 100644 src/main/java/org/codelibs/fess/multimodal/rank/fusion/MultiModalSearcher.java create mode 100644 src/main/resources/fess_query+phraseQueryCommand.xml create mode 100644 src/main/resources/fess_query+termQueryCommand.xml create mode 100644 src/main/resources/fess_rankfusion++.xml diff --git a/src/main/java/org/codelibs/fess/multimodal/MultiModalConstants.java b/src/main/java/org/codelibs/fess/multimodal/MultiModalConstants.java index 60893ff..7c94335 100644 --- a/src/main/java/org/codelibs/fess/multimodal/MultiModalConstants.java +++ b/src/main/java/org/codelibs/fess/multimodal/MultiModalConstants.java @@ -16,8 +16,19 @@ package 
org.codelibs.fess.multimodal; public class MultiModalConstants { + + private static final String PREFIX = "fess.multimodal."; + + public static final String MIN_SCORE = PREFIX + "min_score"; + + public static final String CONTENT_VECTOR_FIELD = System.getProperty(PREFIX + "content.field", "content_vector"); + public static final String X_FESS_EMBEDDING = "X-FESS-Embedding"; + public static final String SEARCHER = "multiModalSearcher"; + + public static final String CAS_CLIENT = "casClient"; + private MultiModalConstants() { // nothing } diff --git a/src/main/java/org/codelibs/fess/multimodal/client/CasClient.java b/src/main/java/org/codelibs/fess/multimodal/client/CasClient.java index bdfa7f2..209098b 100644 --- a/src/main/java/org/codelibs/fess/multimodal/client/CasClient.java +++ b/src/main/java/org/codelibs/fess/multimodal/client/CasClient.java @@ -153,4 +153,25 @@ protected String encodeImage(final InputStream in) { throw new CasAccessException("Failed to read an image.", e); } } + + public float[] getTextEmbedding(final String query) { + final String body = "{\"data\":[{\"text\":\"" + StringEscapeUtils.escapeJson(query) + "\"}],\"execEndpoint\":\"/\"}"; + logger.debug("request body: {}", body); + try (CurlResponse response = Curl.post(clipEndpoint + "/post").header("Content-Type", "application/json").body(body).execute()) { + final Map contentMap = response.getContent(PARSER); + if (((contentMap.get("data") instanceof final List dataList) + && (!dataList.isEmpty() && dataList.get(0) instanceof final Map data)) + && (data.get("embedding") instanceof final List embeddingList)) { + logger.debug("embedding: {}", embeddingList); + final float[] embedding = new float[embeddingList.size()]; + for (int i = 0; i < embedding.length; i++) { + embedding[i] = ((Number) embeddingList.get(i)).floatValue(); + } + return embedding; + } + } catch (final IOException e) { + throw new CasAccessException("Clip server failed to generate an embedding.", e); + } + throw new 
CasAccessException("Clip server cannot generate an embedding"); + } } diff --git a/src/main/java/org/codelibs/fess/multimodal/crawler/extractor/CasExtractor.java b/src/main/java/org/codelibs/fess/multimodal/crawler/extractor/CasExtractor.java index d0dcb50..259cce7 100644 --- a/src/main/java/org/codelibs/fess/multimodal/crawler/extractor/CasExtractor.java +++ b/src/main/java/org/codelibs/fess/multimodal/crawler/extractor/CasExtractor.java @@ -15,6 +15,9 @@ */ package org.codelibs.fess.multimodal.crawler.extractor; +import static org.codelibs.fess.multimodal.MultiModalConstants.CAS_CLIENT; +import static org.codelibs.fess.multimodal.MultiModalConstants.X_FESS_EMBEDDING; + import java.io.InputStream; import java.util.Map; @@ -24,7 +27,6 @@ import org.apache.logging.log4j.Logger; import org.codelibs.fess.crawler.entity.ExtractData; import org.codelibs.fess.crawler.extractor.impl.TikaExtractor; -import org.codelibs.fess.multimodal.MultiModalConstants; import org.codelibs.fess.multimodal.client.CasClient; import org.codelibs.fess.multimodal.ingest.EmbeddingIngester; import org.codelibs.fess.multimodal.util.EmbeddingUtil; @@ -45,14 +47,14 @@ public int getWeight() { public void init() { super.init(); - client = crawlerContainer.getComponent("casClient"); + client = crawlerContainer.getComponent(CAS_CLIENT); } @Override public ExtractData getText(final InputStream inputStream, final Map params) { return getText(inputStream, params, (data, in) -> { try { - data.putValue(MultiModalConstants.X_FESS_EMBEDDING, EmbeddingUtil.encodeFloatArray(client.getImageEmbedding(in))); + data.putValue(X_FESS_EMBEDDING, EmbeddingUtil.encodeFloatArray(client.getImageEmbedding(in))); } catch (final Exception e) { logger.warn("Failed to convert an image to a vector.", e); } diff --git a/src/main/java/org/codelibs/fess/multimodal/index/query/KNNQueryBuilder.java b/src/main/java/org/codelibs/fess/multimodal/index/query/KNNQueryBuilder.java new file mode 100644 index 0000000..94f1d25 --- 
/dev/null +++ b/src/main/java/org/codelibs/fess/multimodal/index/query/KNNQueryBuilder.java @@ -0,0 +1,190 @@ +/* + * Copyright 2012-2024 CodeLibs Project and the Others. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language + * governing permissions and limitations under the License. + */ +package org.codelibs.fess.multimodal.index.query; + +import java.io.IOException; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.commons.lang3.builder.HashCodeBuilder; +import org.apache.lucene.search.Query; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; + +public class KNNQueryBuilder extends AbstractQueryBuilder { + + private static final String NAME = "knn"; + + private static final ParseField VECTOR_FIELD = new ParseField("vector"); + private static final ParseField K_FIELD = new ParseField("k"); + private static final ParseField FILTER_FIELD = new ParseField("filter"); + private static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); + private static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance"); + private static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); + + private static final int DEFAULT_K = 10; + + protected 
String fieldName; + + protected float[] vector; + protected int k; + protected QueryBuilder filter; + protected boolean ignoreUnmapped; + protected Float maxDistance; + protected Float minScore; + + public KNNQueryBuilder(final StreamInput in) throws IOException { + super(in); + this.fieldName = in.readString(); + this.vector = in.readFloatArray(); + this.k = in.readInt(); + this.filter = in.readOptionalNamedWriteable(QueryBuilder.class); + this.ignoreUnmapped = in.readBoolean(); + this.maxDistance = in.readOptionalFloat(); + this.minScore = in.readOptionalFloat(); + } + + private KNNQueryBuilder() { + } + + public static class Builder { + private String fieldName; + private float[] vector; + private int k = DEFAULT_K; + private QueryBuilder filter; + private boolean ignoreUnmapped = false; + private Float maxDistance = null; + private Float minScore = null; + + public Builder field(final String fieldName) { + this.fieldName = fieldName; + return this; + } + + public Builder vector(final float[] vector) { + this.vector = vector; + return this; + } + + public Builder k(final int k) { + this.k = k; + return this; + } + + public Builder filter(final QueryBuilder filter) { + this.filter = filter; + return this; + } + + public Builder ignoreUnmapped(final boolean ignoreUnmapped) { + this.ignoreUnmapped = ignoreUnmapped; + return this; + } + + public Builder maxDistance(final Float maxDistance) { + this.maxDistance = maxDistance; + return this; + } + + public Builder minScore(final Float minScore) { + this.minScore = minScore; + return this; + } + + public KNNQueryBuilder build() { + final KNNQueryBuilder query = new KNNQueryBuilder(); + query.fieldName = fieldName; + query.vector = vector; + query.k = k; + query.filter = filter; + query.ignoreUnmapped = ignoreUnmapped; + query.maxDistance = maxDistance; + query.minScore = minScore; + return query; + } + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + protected void doWriteTo(final 
StreamOutput out) throws IOException { + out.writeString(this.fieldName); + out.writeFloatArray(this.vector); + out.writeInt(this.k); + out.writeOptionalNamedWriteable(this.filter); + out.writeBoolean(this.ignoreUnmapped); + out.writeOptionalFloat(this.maxDistance); + out.writeOptionalFloat(this.minScore); + } + + @Override + protected void doXContent(final XContentBuilder xContentBuilder, final Params params) throws IOException { + xContentBuilder.startObject(NAME); + xContentBuilder.startObject(fieldName); + xContentBuilder.field(VECTOR_FIELD.getPreferredName(), vector); + xContentBuilder.field(K_FIELD.getPreferredName(), k); + if (filter != null) { + xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter); + } + xContentBuilder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); + if (maxDistance != null) { + xContentBuilder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance); + } + if (minScore != null) { + xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); + } + printBoostAndQueryName(xContentBuilder); + xContentBuilder.endObject(); + xContentBuilder.endObject(); + } + + @Override + protected Query doToQuery(final QueryShardContext context) throws IOException { + throw new UnsupportedOperationException("doToQuery is not supported."); + } + + @Override + protected boolean doEquals(final KNNQueryBuilder obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + final EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(fieldName, obj.fieldName); + equalsBuilder.append(vector, obj.vector); + equalsBuilder.append(k, obj.k); + equalsBuilder.append(filter, obj.filter); + equalsBuilder.append(ignoreUnmapped, obj.ignoreUnmapped); + equalsBuilder.append(maxDistance, obj.maxDistance); + equalsBuilder.append(minScore, obj.minScore); + return equalsBuilder.isEquals(); + } + + @Override + protected int doHashCode() { + return new 
HashCodeBuilder().append(fieldName).append(vector).append(k).append(filter).append(ignoreUnmapped).append(maxDistance) + .append(minScore).toHashCode(); + } +} diff --git a/src/main/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngester.java b/src/main/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngester.java index 75c12cb..6ef3cd7 100644 --- a/src/main/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngester.java +++ b/src/main/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngester.java @@ -15,42 +15,40 @@ */ package org.codelibs.fess.multimodal.ingest; +import static org.codelibs.core.lang.StringUtil.EMPTY; +import static org.codelibs.fess.Constants.MAPPING_TYPE_ARRAY; +import static org.codelibs.fess.multimodal.MultiModalConstants.CONTENT_VECTOR_FIELD; +import static org.codelibs.fess.multimodal.MultiModalConstants.X_FESS_EMBEDDING; + import java.util.Map; import javax.annotation.PostConstruct; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.codelibs.core.lang.StringUtil; -import org.codelibs.fess.Constants; import org.codelibs.fess.ingest.Ingester; -import org.codelibs.fess.multimodal.MultiModalConstants; import org.codelibs.fess.multimodal.util.EmbeddingUtil; import org.codelibs.fess.util.ComponentUtil; public class EmbeddingIngester extends Ingester { private static final Logger logger = LogManager.getLogger(EmbeddingIngester.class); - protected String embeddingField; - @PostConstruct public void init() { - embeddingField = System.getProperty("clip.index.embedding_field", "content_vector"); - ComponentUtil.getFessConfig().addCrawlerMetadataNameMapping(MultiModalConstants.X_FESS_EMBEDDING, embeddingField, - Constants.MAPPING_TYPE_ARRAY, StringUtil.EMPTY); + ComponentUtil.getFessConfig().addCrawlerMetadataNameMapping(X_FESS_EMBEDDING, CONTENT_VECTOR_FIELD, MAPPING_TYPE_ARRAY, EMPTY); } @Override protected Map process(final Map target) { - if (target.containsKey(embeddingField)) { - 
logger.debug("[{}] : {}", embeddingField, target); - if (target.get(embeddingField) instanceof final String[] encodedEmbeddings) { + if (target.containsKey(CONTENT_VECTOR_FIELD)) { + logger.debug("[{}] : {}", CONTENT_VECTOR_FIELD, target); + if (target.get(CONTENT_VECTOR_FIELD) instanceof final String[] encodedEmbeddings) { final float[] embedding = EmbeddingUtil.decodeFloatArray(encodedEmbeddings[0]); logger.debug("embedding:{}", embedding); - target.put(embeddingField, embedding); + target.put(CONTENT_VECTOR_FIELD, embedding); } else { - logger.warn("{} is not an array.", embeddingField); + logger.warn("{} is not an array.", CONTENT_VECTOR_FIELD); } } return target; diff --git a/src/main/java/org/codelibs/fess/multimodal/query/MultiModalPhraseQueryCommand.java b/src/main/java/org/codelibs/fess/multimodal/query/MultiModalPhraseQueryCommand.java new file mode 100644 index 0000000..6d64b8c --- /dev/null +++ b/src/main/java/org/codelibs/fess/multimodal/query/MultiModalPhraseQueryCommand.java @@ -0,0 +1,60 @@ +/* + * Copyright 2012-2024 CodeLibs Project and the Others. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language + * governing permissions and limitations under the License. 
+ */ +package org.codelibs.fess.multimodal.query; + +import static org.codelibs.fess.Constants.DEFAULT_FIELD; +import static org.codelibs.fess.multimodal.MultiModalConstants.SEARCHER; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.PhraseQuery; +import org.codelibs.fess.entity.QueryContext; +import org.codelibs.fess.multimodal.rank.fusion.MultiModalSearcher; +import org.codelibs.fess.multimodal.rank.fusion.MultiModalSearcher.SearchContext; +import org.codelibs.fess.mylasta.direction.FessConfig; +import org.codelibs.fess.query.PhraseQueryCommand; +import org.codelibs.fess.util.ComponentUtil; +import org.opensearch.index.query.QueryBuilder; + +public class MultiModalPhraseQueryCommand extends PhraseQueryCommand { + + private static final Logger logger = LogManager.getLogger(MultiModalPhraseQueryCommand.class); + + @Override + protected QueryBuilder convertPhraseQuery(final FessConfig fessConfig, final QueryContext context, final PhraseQuery phraseQuery, + final float boost, final String field, final String[] texts) { + final SearchContext searchContext = getSearchContext(); + + if (!DEFAULT_FIELD.equals(field) || searchContext == null) { + return super.convertPhraseQuery(fessConfig, context, phraseQuery, boost, field, texts); + } + + final String text = String.join(" ", texts); + final QueryBuilder queryBuilder = + new MultiModalQueryBuilder.Builder().query(text).minScore(searchContext.getParams().getMinScore()).build().toQueryBuilder(); + context.addFieldLog(field, text); + context.addHighlightedQuery(text); + if (logger.isDebugEnabled()) { + logger.debug("KNNQueryBuilder: {}", queryBuilder); + } + return queryBuilder; + } + + protected SearchContext getSearchContext() { + final MultiModalSearcher searcher = ComponentUtil.getComponent(SEARCHER); + return searcher.getContext(); + } +} diff --git a/src/main/java/org/codelibs/fess/multimodal/query/MultiModalQueryBuilder.java 
b/src/main/java/org/codelibs/fess/multimodal/query/MultiModalQueryBuilder.java new file mode 100644 index 0000000..f65d226 --- /dev/null +++ b/src/main/java/org/codelibs/fess/multimodal/query/MultiModalQueryBuilder.java @@ -0,0 +1,64 @@ +/* + * Copyright 2012-2024 CodeLibs Project and the Others. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language + * governing permissions and limitations under the License. + */ +package org.codelibs.fess.multimodal.query; + +import static org.codelibs.fess.multimodal.MultiModalConstants.CAS_CLIENT; +import static org.codelibs.fess.multimodal.MultiModalConstants.CONTENT_VECTOR_FIELD; + +import org.codelibs.fess.multimodal.client.CasClient; +import org.codelibs.fess.multimodal.index.query.KNNQueryBuilder; +import org.codelibs.fess.util.ComponentUtil; +import org.opensearch.index.query.QueryBuilder; + +public class MultiModalQueryBuilder { + + protected String query; + protected Float minScore; + + private MultiModalQueryBuilder() { + // nothing + } + + public static class Builder { + + private String query; + private Float minScore; + + public Builder query(final String query) { + this.query = query; + return this; + } + + public Builder minScore(final Float minScore) { + this.minScore = minScore; + return this; + } + + public MultiModalQueryBuilder build() { + final MultiModalQueryBuilder builder = new MultiModalQueryBuilder(); + builder.query = query; + builder.minScore = minScore; + return builder; + } + } + + public QueryBuilder toQueryBuilder() { + final CasClient client = 
ComponentUtil.getComponent(CAS_CLIENT); + final float[] embedding = client.getTextEmbedding(query); + return new KNNQueryBuilder.Builder().field(CONTENT_VECTOR_FIELD).vector(embedding).minScore(minScore).build(); + } + +} diff --git a/src/main/java/org/codelibs/fess/multimodal/query/MultiModalTermQueryCommand.java b/src/main/java/org/codelibs/fess/multimodal/query/MultiModalTermQueryCommand.java new file mode 100644 index 0000000..882ee29 --- /dev/null +++ b/src/main/java/org/codelibs/fess/multimodal/query/MultiModalTermQueryCommand.java @@ -0,0 +1,59 @@ +/* + * Copyright 2012-2024 CodeLibs Project and the Others. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language + * governing permissions and limitations under the License. 
+ */ +package org.codelibs.fess.multimodal.query; + +import static org.codelibs.fess.Constants.DEFAULT_FIELD; +import static org.codelibs.fess.multimodal.MultiModalConstants.SEARCHER; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.TermQuery; +import org.codelibs.fess.entity.QueryContext; +import org.codelibs.fess.multimodal.rank.fusion.MultiModalSearcher; +import org.codelibs.fess.multimodal.rank.fusion.MultiModalSearcher.SearchContext; +import org.codelibs.fess.mylasta.direction.FessConfig; +import org.codelibs.fess.query.TermQueryCommand; +import org.codelibs.fess.util.ComponentUtil; +import org.opensearch.index.query.QueryBuilder; + +public class MultiModalTermQueryCommand extends TermQueryCommand { + + private static final Logger logger = LogManager.getLogger(MultiModalTermQueryCommand.class); + + @Override + protected QueryBuilder convertDefaultTermQuery(final FessConfig fessConfig, final QueryContext context, final TermQuery termQuery, + final float boost, final String field, final String text) { + final SearchContext searchContext = getSearchContext(); + + if (!DEFAULT_FIELD.equals(field) || searchContext == null) { + return super.convertDefaultTermQuery(fessConfig, context, termQuery, boost, field, text); + } + + final QueryBuilder queryBuilder = + new MultiModalQueryBuilder.Builder().query(text).minScore(searchContext.getParams().getMinScore()).build().toQueryBuilder(); + context.addFieldLog(field, text); + context.addHighlightedQuery(text); + if (logger.isDebugEnabled()) { + logger.debug("KNNQueryBuilder: {}", queryBuilder); + } + return queryBuilder; + } + + protected SearchContext getSearchContext() { + final MultiModalSearcher searcher = ComponentUtil.getComponent(SEARCHER); + return searcher.getContext(); + } +} diff --git a/src/main/java/org/codelibs/fess/multimodal/rank/fusion/MultiModalSearcher.java 
b/src/main/java/org/codelibs/fess/multimodal/rank/fusion/MultiModalSearcher.java new file mode 100644 index 0000000..420ab0e --- /dev/null +++ b/src/main/java/org/codelibs/fess/multimodal/rank/fusion/MultiModalSearcher.java @@ -0,0 +1,224 @@ +/* + * Copyright 2012-2024 CodeLibs Project and the Others. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language + * governing permissions and limitations under the License. + */ +package org.codelibs.fess.multimodal.rank.fusion; + +import static org.codelibs.fess.multimodal.MultiModalConstants.MIN_SCORE; + +import java.util.Locale; +import java.util.Map; + +import javax.annotation.PostConstruct; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.codelibs.core.lang.StringUtil; +import org.codelibs.fess.entity.FacetInfo; +import org.codelibs.fess.entity.GeoInfo; +import org.codelibs.fess.entity.HighlightInfo; +import org.codelibs.fess.entity.SearchRequestParams; +import org.codelibs.fess.mylasta.action.FessUserBean; +import org.codelibs.fess.rank.fusion.DefaultSearcher; +import org.codelibs.fess.rank.fusion.SearchResult; +import org.codelibs.fess.util.ComponentUtil; +import org.dbflute.optional.OptionalThing; + +public class MultiModalSearcher extends DefaultSearcher { + private static final Logger logger = LogManager.getLogger(MultiModalSearcher.class); + + protected ThreadLocal contextLocal = new ThreadLocal<>(); + + protected Float minScore; + + @PostConstruct + public void register() { + if (logger.isInfoEnabled()) { + 
logger.info("Load {}", this.getClass().getSimpleName()); + } + + final String minScoreValue = System.getProperty(MIN_SCORE); + if (StringUtil.isNotBlank(minScoreValue)) { + try { + minScore = Float.valueOf(minScoreValue); + } catch (final NumberFormatException e) { + logger.warn("Failed to parse {}.", minScoreValue, e); + minScore = null; + } + } else { + minScore = null; + } + + ComponentUtil.getRankFusionProcessor().register(this); + } + + @Override + protected SearchResult search(final String query, final SearchRequestParams params, final OptionalThing userBean) { + try { + final SearchRequestParams reqParams = new SearchRequestParamsWrapper(params, minScore); + createContext(query, reqParams, userBean); + return super.search(query, reqParams, userBean); + } finally { + closeContext(); + } + } + + public SearchContext createContext(final String query, final SearchRequestParams params, final OptionalThing userBean) { + if (contextLocal.get() != null) { + logger.warn("The context exists: {}", contextLocal.get()); + contextLocal.remove(); + } + final SearchContext context = new SearchContext(query, params, userBean); + contextLocal.set(context); + return context; + } + + public void closeContext() { + if (contextLocal.get() == null) { + logger.warn("The context does not exist."); + } else { + contextLocal.remove(); + } + } + + public SearchContext getContext() { + return contextLocal.get(); + } + + public static class SearchContext { + + private final String query; + private final SearchRequestParams params; + private final OptionalThing userBean; + + public SearchContext(final String query, final SearchRequestParams params, final OptionalThing userBean) { + this.query = query; + this.params = params; + this.userBean = userBean; + } + + public String getQuery() { + return query; + } + + public SearchRequestParams getParams() { + return params; + } + + public OptionalThing getUserBean() { + return userBean; + } + + @Override + public String toString() { + return 
"SearchContext [query=" + query + ", params=" + params + ", userBean=" + userBean.orElse(null) + "]"; + } + + } + + protected static class SearchRequestParamsWrapper extends SearchRequestParams { + private final SearchRequestParams parent; + private final Float minScore; + + protected SearchRequestParamsWrapper(final SearchRequestParams params, final Float minScore) { + this.parent = params; + this.minScore = minScore; + } + + @Override + public String getQuery() { + return parent.getQuery(); + } + + @Override + public Map getFields() { + return parent.getFields(); + } + + @Override + public Map getConditions() { + return parent.getConditions(); + } + + @Override + public String[] getLanguages() { + return parent.getLanguages(); + } + + @Override + public GeoInfo getGeoInfo() { + return null; + } + + @Override + public FacetInfo getFacetInfo() { + return null; + } + + @Override + public HighlightInfo getHighlightInfo() { + return null; + } + + @Override + public String getSort() { + return parent.getSort(); + } + + @Override + public int getStartPosition() { + return parent.getStartPosition(); + } + + @Override + public int getPageSize() { + return parent.getPageSize(); + } + + @Override + public int getOffset() { + return parent.getOffset(); + } + + @Override + public String[] getExtraQueries() { + return parent.getExtraQueries(); + } + + @Override + public Object getAttribute(final String name) { + return parent.getAttribute(name); + } + + @Override + public Locale getLocale() { + return parent.getLocale(); + } + + @Override + public SearchRequestType getType() { + return parent.getType(); + } + + @Override + public String getSimilarDocHash() { + return parent.getSimilarDocHash(); + } + + @Override + public Float getMinScore() { + return minScore; + } + } +} diff --git a/src/main/resources/fess_query+phraseQueryCommand.xml b/src/main/resources/fess_query+phraseQueryCommand.xml new file mode 100644 index 0000000..23e0d9e --- /dev/null +++ 
b/src/main/resources/fess_query+phraseQueryCommand.xml @@ -0,0 +1,9 @@ + + + + + + + diff --git a/src/main/resources/fess_query+termQueryCommand.xml b/src/main/resources/fess_query+termQueryCommand.xml new file mode 100644 index 0000000..4e00468 --- /dev/null +++ b/src/main/resources/fess_query+termQueryCommand.xml @@ -0,0 +1,9 @@ + + + + + + + diff --git a/src/main/resources/fess_rankfusion++.xml b/src/main/resources/fess_rankfusion++.xml new file mode 100644 index 0000000..2a32e14 --- /dev/null +++ b/src/main/resources/fess_rankfusion++.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/src/test/java/org/codelibs/fess/multimodal/client/CasClientTest.java b/src/test/java/org/codelibs/fess/multimodal/client/CasClientTest.java index ed8bacd..ece1a3e 100644 --- a/src/test/java/org/codelibs/fess/multimodal/client/CasClientTest.java +++ b/src/test/java/org/codelibs/fess/multimodal/client/CasClientTest.java @@ -46,4 +46,15 @@ public void test_getImageEmbedding() throws Exception { logger.warning(e.getMessage()); } } + + public void test_getTextEmbedding() throws Exception { + final CasClient client = new CasClient(); + client.init(); + try { + final float[] embedding = client.getTextEmbedding("running dogs"); + assertEquals(512, embedding.length); + } catch (final CurlException e) { + logger.warning(e.getMessage()); + } + } } \ No newline at end of file diff --git a/src/test/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngesterTest.java b/src/test/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngesterTest.java index 71a78fc..3c694e2 100644 --- a/src/test/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngesterTest.java +++ b/src/test/java/org/codelibs/fess/multimodal/ingest/EmbeddingIngesterTest.java @@ -15,36 +15,36 @@ */ package org.codelibs.fess.multimodal.ingest; +import static org.codelibs.fess.multimodal.MultiModalConstants.CONTENT_VECTOR_FIELD; + import java.util.HashMap; import java.util.Map; import org.dbflute.utflute.core.PlainTestCase; public class 
EmbeddingIngesterTest extends PlainTestCase { - private static final String VECTOR_FIELD = "vector_field"; public void test_process() { final EmbeddingIngester ingester = new EmbeddingIngester(); - ingester.embeddingField = VECTOR_FIELD; final Map target = new HashMap<>(); Map result = ingester.process(target); assertEquals(0, result.size()); target.clear(); - target.put(VECTOR_FIELD, new String[] { "P4AAAEAAAABAQAAA" }); + target.put(CONTENT_VECTOR_FIELD, new String[] { "P4AAAEAAAABAQAAA" }); result = ingester.process(target); assertEquals(1, result.size()); - final float[] array = (float[]) result.get(VECTOR_FIELD); + final float[] array = (float[]) result.get(CONTENT_VECTOR_FIELD); assertEquals(3, array.length); assertEquals(1.0f, array[0]); assertEquals(2.0f, array[1]); assertEquals(3.0f, array[2]); target.clear(); - target.put(VECTOR_FIELD, "P4AAAEAAAABAQAAA"); + target.put(CONTENT_VECTOR_FIELD, "P4AAAEAAAABAQAAA"); result = ingester.process(target); assertEquals(1, result.size()); - assertEquals("P4AAAEAAAABAQAAA", result.get(VECTOR_FIELD)); + assertEquals("P4AAAEAAAABAQAAA", result.get(CONTENT_VECTOR_FIELD)); } }