Skip to content

Commit

Permalink
#10 updating the EmbeddingContentListener to automatically embed cont…
Browse files Browse the repository at this point in the history
…ent based on config
  • Loading branch information
wezell committed Jan 5, 2024
1 parent 1ebdaa7 commit 3e4c752
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 26 deletions.
21 changes: 15 additions & 6 deletions src/main/java/com/dotcms/ai/api/CompletionsAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import com.dotcms.ai.rest.forms.CompletionsForm;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Lazy;

import java.io.OutputStream;

public interface CompletionsAPI {
Expand All @@ -19,8 +18,8 @@ static CompletionsAPI impl() {


/**
* this method takes the query/prompt, searches dotCMS content for matching
* embeddings and then returns an AI summary based on the matching content in dotCMS
* this method takes the query/prompt, searches dotCMS content for matching embeddings and then returns an AI
* summary based on the matching content in dotCMS
*
* @param searcher
* @return
Expand All @@ -29,8 +28,8 @@ static CompletionsAPI impl() {


/**
* this method takes the query/prompt, searches dotCMS content for matching
* embeddings and then streams the AI response based on the matching content in dotCMS
* this method takes the query/prompt, searches dotCMS content for matching embeddings and then streams the AI
* response based on the matching content in dotCMS
*
* @param searcher
* @return
Expand All @@ -55,7 +54,17 @@ static CompletionsAPI impl() {
*/
JSONObject raw(JSONObject promptJSON);


/**
* this method takes a prompt in the form of parameters and returns a json AI response based on the parameters
* passed in.
*
* @param systemPrompt
* @param userPrompt
* @param model
* @param temperature
* @param maxTokens
* @return
*/
JSONObject prompt(String systemPrompt, String userPrompt, String model, float temperature, int maxTokens);


Expand Down
81 changes: 77 additions & 4 deletions src/main/java/com/dotcms/ai/api/EmbeddingsAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import com.dotmarketing.portlets.contentlet.model.Contentlet;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.Tuple2;

import javax.validation.constraints.NotNull;
import java.util.List;
import java.util.Map;
import javax.validation.constraints.NotNull;

public interface EmbeddingsAPI {

Expand All @@ -24,37 +23,111 @@ static EmbeddingsAPI impl() {

void shutdown();

/**
* given a contentlet, a list of fields and an index, this method will do its best to turn that content into an
* index-able string that then gets vectorized and stored in postgres.
* <p>
* Important - if you send in an empty list of fields to index, the method will try to intelligently(tm) pick how to
* index the content. For example, if you send in a fileAsset or dotAsset, it will try to index the content of the
* file. If you send a htmlPageAsset, it will render the page and index the rendered page. If you send a contentlet
* with a Storyblock or wysiwyg field, it will render those and index the resultant content.
*
* @param contentlet
* @param fields
* @param index
* @return
*/
boolean generateEmbeddingsforContent(Contentlet contentlet, List<Field> fields, String index);

/**
* this method takes a contentlet and a velocity template, generates a velocity context that includes the
* $contentlet in it and indexes the rendered result.
*
* @param contentlet
* @param velocityTemplate
* @param indexName
* @return
*/
boolean generateEmbeddingsforContent(@NotNull Contentlet contentlet, String velocityTemplate, String indexName);

/**
* Takes a DTO object and based on its properties deletes from the embeddings index.
*
* @param dto
* @return
*/
int deleteEmbedding(EmbeddingsDTO dto);


/**
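* This method takes a comma- or line-separated string of content types (and, optionally, fields) and returns a
* map of {@code <contentTypeVar, List<FieldsToIndex>>}.
*
* @param typeAndFieldParam the comma- or line-separated string of content types and optional fields to parse
* @return a map of {@code <contentTypeVar, List<FieldsToIndex>>}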
*/
Map<String, List<Field>> parseTypesAndFields(String typeAndFieldParam);

/**
* This method takes a list of semantic search results, which are just fragments of content, and returns a json
* object of a list of the actual contentlets and fragments that matched the result. The idea is to provide the
* ability to show exactly which contentlets matched the semantic query and specifically, which fragments in that
* content matched.
*
* @param searcher
* @param searchResults
* @return
*/
JSONObject reduceChunksToContent(EmbeddingsDTO searcher, List<EmbeddingsDTO> searchResults);

/**
* Takes a searcher DTO and returns a JSON object that is a list of matching contentlets and the fragments that
* matched.
*
* @param searcher
* @return
*/
JSONObject searchForContent(EmbeddingsDTO searcher);

/**
* returns a list of matching content+embeddings from the dot_embeddings table based on the searcher dto
*
* @param searcher
* @return
*/
List<EmbeddingsDTO> getEmbeddingResults(EmbeddingsDTO searcher);

/**
* returns a count of matching content+embeddings based on the searcher dto
*
* @param searcher
* @return
*/
long countEmbeddings(EmbeddingsDTO searcher);

/**
* returns a map of all the available dot_embeddings 'indexes' plus the count of embeddings in them
*
* @return
*/
Map<String, Map<String, Object>> countEmbeddingsByIndex();

/**
* drops the dot_embeddings table
*/
void dropEmbeddingsTable();

/**
* inits pg_vector and builds the dot_embeddings table
*/
void initEmbeddingsTable();

/**
* Takes a string and returns the embeddings value for the string
*
* @param content
* @return
*/
Tuple2<Integer, List<Float>> pullOrGenerateEmbeddings(String content);



}
2 changes: 1 addition & 1 deletion src/main/java/com/dotcms/ai/api/EmbeddingsAPIImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public Map<String, List<Field>> parseTypesAndFields(final String typeAndFieldPa
}
List<Field> fields = typesAndFields.getOrDefault(type.get().variable(), new ArrayList<>());

Optional<Field> field = Try.of(() -> type.get().fieldMap().get(typeOptField[1])).toJavaOptional();
Optional<Field> field = Try.of(() -> type.get().fields().stream().filter(f->f.variable().equalsIgnoreCase(typeOptField[1])).findFirst()).getOrElse(Optional.empty());
if (field.isPresent()) {
fields.add(field.get());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import com.dotcms.content.elasticsearch.business.event.ContentletCheckinEvent;
import com.dotcms.content.elasticsearch.business.event.ContentletDeletedEvent;
import com.dotcms.content.elasticsearch.business.event.ContentletPublishEvent;
import com.dotcms.contenttype.model.field.Field;
import com.dotcms.system.event.local.model.Subscriber;
import com.dotmarketing.beans.Host;
import com.dotmarketing.business.APILocator;
import com.dotmarketing.portlets.contentlet.model.Contentlet;
import com.dotmarketing.portlets.contentlet.model.ContentletListener;
import com.dotmarketing.util.json.JSONObject;
import io.vavr.control.Try;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

Expand Down Expand Up @@ -87,11 +90,13 @@ JSONObject getConfigJson(Contentlet contentlet) {
.getOrElse(APILocator.systemHost());

return Try.of(() -> new JSONObject(ConfigService.INSTANCE.config(host).getConfig(AppKeys.LISTENER_INDEXER)))
.onFailure(e->Logger.warn(EmbeddingContentListener.class, "error in json config from app:" + e.getMessage()))
.getOrElse(new JSONObject());
}

/**
* Adds the content to the embeddings index based on the JSON configuration in the app
* Adds the content to the embeddings index based on the JSON configuration in the app. The JSON key is the
* indexName and the value is a comma- or line-break-delimited string of contentType.field entries to index
*
* @param contentlet
*/
Expand All @@ -105,8 +110,9 @@ void addToIndexesIfNeeded(Contentlet contentlet) {

for (Entry<String, Object> entry : (Set<Entry<String, Object>>) config.entrySet()) {
final String indexName = entry.getKey();
EmbeddingsAPI.impl()
.parseTypesAndFields((String) entry.getValue()).entrySet()
Map<String, List<Field>> typesAndFields = EmbeddingsAPI.impl()
.parseTypesAndFields((String) entry.getValue());
typesAndFields.entrySet()
.stream()
.filter(e -> contentType.equalsIgnoreCase(e.getKey()))
.forEach(e ->
Expand Down
12 changes: 7 additions & 5 deletions src/main/java/com/dotcms/ai/util/ContentToStringUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,7 @@ private Optional<String> parseText(@NotNull String val) {
? val.replaceAll("\\s+", " ")
: null;

if (UtilMethods.isEmpty(val) || val.length() < ConfigService.INSTANCE.config().getConfigInteger(AppKeys.EMBEDDINGS_MINIMUM_TEXT_LENGTH_TO_INDEX)) {
return Optional.empty();
}
return Optional.of(val);
return Optional.ofNullable(val);
}

private Optional<String> parseBlockEditor(@NotNull String val) {
Expand Down Expand Up @@ -235,7 +232,12 @@ public Optional<String> parseFields(@NotNull Contentlet contentlet, @NotNull Lis
parseField(contentlet, field)
.ifPresent(s -> builder.append(s).append(" "));
}
return (builder.length() > 0) ? Optional.of(builder.toString()) : Optional.empty();

if (builder.length() < ConfigService.INSTANCE.config().getConfigInteger(AppKeys.EMBEDDINGS_MINIMUM_TEXT_LENGTH_TO_INDEX)) {
return Optional.empty();
}

return Optional.of(builder.toString());


}
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/com/dotcms/ai/util/OpenAIModel.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
package com.dotcms.ai.util;

import com.dotmarketing.exception.DotRuntimeException;

import java.util.Arrays;
import java.util.stream.Collectors;

public enum OpenAIModel {




GPT_3_5_TURBO("gpt-3.5-turbo", 3000, 3500, 4096, true),
GPT_3_5_TURBO_16k("gpt-3.5-turbo-16k", 180000, 3500, 16384, true),
GPT_4("gpt-4", 10000, 200, 8191, true),
GPT_4_TURBO("gpt-4-1106-preview", 10000, 200, 128000 , true),
GPT_4_TURBO("gpt-4-1106-preview", 10000, 200, 128000, true),
TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", 1000000, 3000, 8191, false),
DALL_E_2("dall-e-2", 0, 50, 0, false),
DALL_E_3("dall-e-3", 0, 50, 0, false);
Expand All @@ -23,12 +20,13 @@ public enum OpenAIModel {
public final int maxTokens;
public final String modelName;
public final boolean completionModel;

OpenAIModel(String modelName, int tokensPerMinute, int apiPerMinute, int maxTokens, boolean completionModel) {
this.modelName = modelName;
this.tokensPerMinute = tokensPerMinute;
this.apiPerMinute = apiPerMinute;
this.maxTokens = maxTokens;
this.completionModel=completionModel;
this.completionModel = completionModel;
}

public static OpenAIModel resolveModel(String modelIn) {
Expand All @@ -38,11 +36,13 @@ public static OpenAIModel resolveModel(String modelIn) {
return openAiModel;
}
}
throw new DotRuntimeException("Unable to parse model:'" + modelIn + "'. Only " + supportedModels() + " are supported ");
throw new DotRuntimeException(
"Unable to parse model:'" + modelIn + "'. Only " + supportedModels() + " are supported ");
}

private static String supportedModels() {
return String.join(", ", Arrays.asList(OpenAIModel.values()).stream().map(o -> o.modelName).collect(Collectors.toList()));
return String.join(", ",
Arrays.asList(OpenAIModel.values()).stream().map(o -> o.modelName).collect(Collectors.toList()));
}

public long minIntervalBetweenCalls() {
Expand Down

0 comments on commit 3e4c752

Please sign in to comment.