Skip to content

Commit

Permalink
fix ITests and update in AI extract structured response schema
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszsocha2 committed Feb 18, 2025
1 parent fd9d239 commit 5275138
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 46 deletions.
62 changes: 36 additions & 26 deletions src/intTest/java/com/box/sdk/BoxAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;

import com.eclipsesource.json.Json;
import com.eclipsesource.json.JsonObject;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -111,8 +111,8 @@ public void askAIMultipleItems() throws InterruptedException {
public void askAITextGenItemWithDialogueHistory() throws ParseException, InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
String fileName = "[askAITextGenItemWithDialogueHistory] Test File.txt";
Date date1 = BoxDateFormat.parse("2013-05-16T15:27:57-07:00");
Date date2 = BoxDateFormat.parse("2013-05-16T15:26:57-07:00");
Date date1 = BoxDateFormat.parse("2021-01-01T00:00:00Z");
Date date2 = BoxDateFormat.parse("2021-02-01T00:00:00Z");

BoxFile uploadedFile = uploadFileToUniqueFolder(api, fileName, "Test file");
try {
Expand Down Expand Up @@ -148,28 +148,25 @@ public void askAITextGenItemWithDialogueHistory() throws ParseException, Interru
@Test
public void getAIAgentDefaultConfiguration() {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK,
"en", "openai__gpt_3_5_turbo");
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK);
BoxAIAgentAsk askAgent = (BoxAIAgentAsk) agent;

assertThat(askAgent.getType(), is(equalTo(BoxAIAgentAsk.TYPE)));
assertThat(askAgent.getBasicText().getModel(), is(equalTo("openai__gpt_3_5_turbo")));
assertThat(askAgent.getBasicText().getModel(), is(notNullValue()));

BoxAIAgent agent2 = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.TEXT_GEN,
"en", "openai__gpt_3_5_turbo");
BoxAIAgent agent2 = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.TEXT_GEN);
BoxAIAgentTextGen textGenAgent = (BoxAIAgentTextGen) agent2;

assertThat(textGenAgent.getType(), is(equalTo(BoxAIAgentTextGen.TYPE)));
assertThat(textGenAgent.getBasicGen().getModel(), is(equalTo("openai__gpt_3_5_turbo")));
assertThat(textGenAgent.getBasicGen().getModel(), is(notNullValue()));
}

@Test
public void askAISingleItemWithAgent() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
String fileName = "[askAISingleItem] Test File.txt";
BoxFile uploadedFile = uploadFileToUniqueFolder(api, fileName, "Test file");
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK,
"en", "openai__gpt_3_5_turbo_16k");
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK);
BoxAIAgentAsk askAgent = (BoxAIAgentAsk) agent;

try {
Expand Down Expand Up @@ -199,8 +196,10 @@ public void askAISingleItemWithAgent() throws InterruptedException {
@Test
public void aiExtract() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT, "en-US", null);
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT);
BoxAIAgentExtract agentExtract = (BoxAIAgentExtract) agent;
// AI team is going to move away from supporting overriding embeddings model
agentExtract.getLongText().setEmbeddings(null);

BoxFile uploadedFile = uploadFileToUniqueFolder(api, "[aiExtract] Test File.txt",
"My name is John Doe. I live in San Francisco. I was born in 1990. I work at Box.");
Expand All @@ -224,8 +223,10 @@ public void aiExtract() throws InterruptedException {
@Test
public void aiExtractStructuredWithFields() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED, "en-US", null);
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED);
BoxAIAgentExtractStructured agentExtractStructured = (BoxAIAgentExtractStructured) agent;
// AI team is going to move away from supporting overriding embeddings model
agentExtractStructured.getLongText().setEmbeddings(null);

BoxFile uploadedFile = uploadFileToUniqueFolder(api, "[aiExtractStructuredWithFields] Test File.txt",
"My name is John Doe. I was born in 4th July 1990. I am 34 years old. My hobby is guitar.");
Expand Down Expand Up @@ -259,12 +260,16 @@ public void aiExtractStructuredWithFields() throws InterruptedException {
"What is your hobby?")
),
agentExtractStructured);
JsonObject sourceJson = response.getSourceJson();
assertThat(sourceJson.get("firstName").asString(), is(equalTo("John")));
assertThat(sourceJson.get("lastName").asString(), is(equalTo("Doe")));
assertThat(sourceJson.get("dateOfBirth").asString(), is(equalTo("1990-07-04")));
assertThat(sourceJson.get("age").asInt(), is(equalTo(34)));
assertThat(sourceJson.get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));
assertThat(response.getSourceJson().get("answer"), is(equalTo(response.getAnswer())));

assertThat(response.getAnswer().get("firstName").asString(), is(equalTo("John")));
assertThat(response.getAnswer().get("lastName").asString(), is(equalTo("Doe")));
assertThat(response.getAnswer().get("dateOfBirth").asString(), is(equalTo("1990-07-04")));
assertThat(response.getAnswer().get("age").asInt(), is(equalTo(34)));
assertThat(response.getAnswer().get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));

assertThat(response.getCompletionReason(), equalTo("done"));
assertThat(response.getCreatedAt(), is(notNullValue()));
}, 2, 2000);
} finally {
deleteFile(uploadedFile);
Expand All @@ -274,8 +279,10 @@ public void aiExtractStructuredWithFields() throws InterruptedException {
@Test
public void aiExtractStructuredWithMetadataTemplate() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED, "en-US", null);
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED);
BoxAIAgentExtractStructured agentExtractStructured = (BoxAIAgentExtractStructured) agent;
// AI team is going to move away from supporting overriding embeddings model
agentExtractStructured.getLongText().setEmbeddings(null);

BoxFile uploadedFile = uploadFileToUniqueFolder(api, "[aiExtractStructuredWithMetadataTemplate] Test File.txt",
"My name is John Doe. I was born in 4th July 1990. I am 34 years old. My hobby is guitar.");
Expand Down Expand Up @@ -312,12 +319,15 @@ public void aiExtractStructuredWithMetadataTemplate() throws InterruptedExceptio
Collections.singletonList(new BoxAIItem(uploadedFile.getID(), BoxAIItem.Type.FILE)),
new BoxAIExtractMetadataTemplate(templateKey, "enterprise"),
agentExtractStructured);
JsonObject sourceJson = response.getSourceJson();
assertThat(sourceJson.get("firstName").asString(), is(equalTo("John")));
assertThat(sourceJson.get("lastName").asString(), is(equalTo("Doe")));
assertThat(sourceJson.get("dateOfBirth").asString(), is(equalTo("1990-07-04T00:00:00Z")));
assertThat(sourceJson.get("age").asInt(), is(equalTo(34)));
assertThat(sourceJson.get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));
assertThat(response.getSourceJson().get("answer"), is(equalTo(response.getAnswer())));

assertThat(response.getAnswer().get("firstName").asString(), is(equalTo("John")));
assertThat(response.getAnswer().get("lastName").asString(), is(equalTo("Doe")));
assertThat(response.getAnswer().get("dateOfBirth").asString(), is(equalTo("1990-07-04T00:00:00Z")));
assertThat(response.getAnswer().get("age").asInt(), is(equalTo(34)));
assertThat(response.getAnswer().get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));
assertThat(response.getCompletionReason(), equalTo("done"));
assertThat(response.getCreatedAt(), is(notNullValue()));
}, 2, 2000);
} finally {
deleteFile(uploadedFile);
Expand Down
16 changes: 12 additions & 4 deletions src/main/java/com/box/sdk/BoxAIAgentAsk.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,18 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_text", this.basicText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "basic_text_multi", this.basicTextMulti.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text", this.longText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text_multi", this.longTextMulti.getJSONObject());
if (this.basicText != null) {
jsonObject.add("basic_text", this.basicText.getJSONObject());
}
if (this.basicTextMulti != null) {
jsonObject.add("basic_text_multi", this.basicTextMulti.getJSONObject());
}
if (this.longText != null) {
jsonObject.add("long_text", this.longText.getJSONObject());
}
if (this.longTextMulti != null) {
jsonObject.add("long_text_multi", this.longTextMulti.getJSONObject());
}
return jsonObject;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/box/sdk/BoxAIAgentAskBasicText.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ void parseJSONMember(JsonObject.Member member) {

public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "llm_endpoint_params", this.llmEndpointParams.getJSONObject());
if (this.llmEndpointParams != null) {
jsonObject.add("llm_endpoint_params", this.llmEndpointParams.getJSONObject());
}
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "num_tokens_for_completion", this.numTokensForCompletion);
JsonUtils.addIfNotNull(jsonObject, "prompt_template", this.promptTemplate);
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentAskLongText.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,12 @@ void parseJSONMember(JsonObject.Member member) {

public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "embeddings", this.embeddings.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "llm_endpoint_params", this.llmEndpointParams.getJSONObject());
if (this.embeddings != null) {
jsonObject.add("embeddings", this.embeddings.getJSONObject());
}
if (this.llmEndpointParams != null) {
jsonObject.add("llm_endpoint_params", this.llmEndpointParams.getJSONObject());
}
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "num_tokens_for_completion", this.numTokensForCompletion);
JsonUtils.addIfNotNull(jsonObject, "prompt_template", this.promptTemplate);
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/box/sdk/BoxAIAgentEmbeddings.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "strategy", this.strategy.getJSONObject());
if (this.strategy != null) {
jsonObject.add("strategy", this.strategy.getJSONObject());
}
return jsonObject;
}
}
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentExtract.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_text", this.basicText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text", this.longText.getJSONObject());
if (this.basicText != null) {
jsonObject.add("basic_text", this.basicText.getJSONObject());
}
if (this.longText != null) {
jsonObject.add("long_text", this.longText.getJSONObject());
}
return jsonObject;
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentExtractStructured.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_text", this.basicText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text", this.longText.getJSONObject());
if (this.basicText != null) {
jsonObject.add("basic_text", this.basicText.getJSONObject());
}
if (this.longText != null) {
jsonObject.add("long_text", this.longText.getJSONObject());
}
return jsonObject;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/box/sdk/BoxAIAgentTextGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_gen", this.basicGen.getJSONObject());
if (this.basicGen != null) {
jsonObject.add("basic_gen", this.basicGen.getJSONObject());
}
return jsonObject;
}
}
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentTextGenBasicGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,12 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "content_template", this.contentTemplate);
JsonUtils.addIfNotNull(jsonObject, "embeddings", this.embeddings.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "llm_endpoint_params", this.llmEndpointParams.getJSONObject());
if (this.embeddings != null) {
jsonObject.add("embeddings", this.embeddings.getJSONObject());
}
if (this.llmEndpointParams != null) {
jsonObject.add("llm_endpoint_params", this.llmEndpointParams.getJSONObject());
}
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "num_tokens_for_completion", this.numTokensForCompletion);
JsonUtils.addIfNotNull(jsonObject, "prompt_template", this.promptTemplate);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/box/sdk/BoxAIDialogueEntry.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public JsonObject getJSONObject() {
.add("answer", this.answer);

if (this.createdAt != null) {
itemJSON.add("created_at", this.createdAt.toString());
itemJSON.add("created_at", BoxDateFormat.format(this.createdAt));
}

return itemJSON;
Expand Down
54 changes: 54 additions & 0 deletions src/main/java/com/box/sdk/BoxAIExtractStructuredResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

import com.eclipsesource.json.Json;
import com.eclipsesource.json.JsonObject;
import com.eclipsesource.json.JsonValue;
import java.text.ParseException;
import java.util.Date;

/**
* AI response to a user request.
*/
public class BoxAIExtractStructuredResponse extends BoxJSONObject {
private final JsonObject sourceJson;
private JsonObject answer;
private String completionReason;
private Date createdAt;

/**
* Constructs a BoxAIResponse object.
Expand Down Expand Up @@ -35,4 +41,52 @@ public BoxAIExtractStructuredResponse(String json) {
public JsonObject getSourceJson() {
return sourceJson;
}

/**
* Gets the answer of the AI.
*
* @return the answer of the AI.
*/
public JsonObject getAnswer() {
return answer;
}

/**
* Gets reason the response finishes.
*
* @return the reason the response finishes.
*/
public String getCompletionReason() {
return completionReason;
}

/**
* Gets the ISO date formatted timestamp of when the answer to the prompt was created.
*
* @return The ISO date formatted timestamp of when the answer to the prompt was created.
*/
public Date getCreatedAt() {
return createdAt;
}

/**
* {@inheritDoc}
*/
@Override
void parseJSONMember(JsonObject.Member member) {
JsonValue value = member.getValue();
String memberName = member.getName();
try {
if (memberName.equals("answer")) {
this.answer = value.asObject();
} else if (memberName.equals("completion_reason")) {
this.completionReason = value.asString();
} else if (memberName.equals("created_at")) {
this.createdAt = BoxDateFormat.parse(value.asString());
}
} catch (ParseException e) {
assert false : "A ParseException indicates a bug in the SDK.";
}
}

}
12 changes: 8 additions & 4 deletions src/test/java/com/box/sdk/BoxAITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ public void testSendAITexGenRequestWithNoDialogueHistorySuccess() {
public void testSendAITexGenRequestWithDialogueHistorySuccess() throws ParseException {
final String fileId = "12345";
final String prompt = "What is the name of the file?";
Date date1 = BoxDateFormat.parse("2021-01-01T00:00:00Z");
Date date2 = BoxDateFormat.parse("2022-01-01T00:00:00Z");

List<BoxAIDialogueEntry> dialogueHistory = new ArrayList<>();
dialogueHistory.add(
new BoxAIDialogueEntry("What is the name of the file?", "Test file")
new BoxAIDialogueEntry("What is the name of the file?", "Test file", date1)
);
dialogueHistory.add(
new BoxAIDialogueEntry("What is the size of the file?", "10kb")
new BoxAIDialogueEntry("What is the size of the file?", "10kb", date2)
);

String expectedRequestBody = String.format(
Expand All @@ -109,8 +111,10 @@ public void testSendAITexGenRequestWithDialogueHistorySuccess() throws ParseExce
+ " {\"id\": \"%s\", \"type\": \"file\"}\n"
+ " ],\n"
+ " \"dialogue_history\": [\n"
+ " {\"prompt\": \"What is the name of the file?\", \"answer\": \"Test file\"},\n"
+ " {\"prompt\": \"What is the size of the file?\", \"answer\": \"10kb\"}\n"
+ " {\"prompt\": \"What is the name of the file?\", \"answer\": \"Test file\","
+ " \"created_at\" : \"2021-01-01T00:00:00Z\"},\n"
+ " {\"prompt\": \"What is the size of the file?\", \"answer\": \"10kb\","
+ " \"created_at\" : \"2022-01-01T00:00:00Z\"}\n"
+ " ]\n"
+ "}",
prompt, fileId
Expand Down

0 comments on commit 5275138

Please sign in to comment.