Skip to content

Commit

Permalink
Merge pull request #3264 from jsonwan/github_feature/ai
Browse files Browse the repository at this point in the history
perf: AI小鲸相关功能优化 #3258
  • Loading branch information
jsonwan authored Oct 28, 2024
2 parents 0155169 + b8b5b42 commit f347971
Show file tree
Hide file tree
Showing 14 changed files with 58 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
public class JobOpenAiClientBuilderFactory implements OpenAiClientBuilderFactory {
@Override
public OpenAiClient.Builder<DefaultOpenAiClient, DefaultOpenAiClient.Builder> get() {
log.info("Creating a new instance of the DefaultOpenAiClient.Builder");
return DefaultOpenAiClient.builder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public Response<List<AIChatRecord>> getLatestChatHistoryList(String username,
Integer start,
Integer length) {
if (!aiChatHistoryService.existsChatHistory(username)) {
AIChatHistoryDTO greetingChatHistory = getGreetingChatHistory(username);
AIChatHistoryDTO greetingChatHistory = getGreetingChatHistory(username, appResourceScope.getAppId());
Long id = aiChatHistoryService.insertChatHistory(greetingChatHistory);
log.debug("Greeting chat history created, username={}, id={}", username, id);
}
Expand All @@ -115,9 +115,10 @@ public Response<List<AIChatRecord>> getLatestChatHistoryList(String username,
return Response.buildSuccessResp(aiChatRecordList);
}

private AIChatHistoryDTO getGreetingChatHistory(String username) {
private AIChatHistoryDTO getGreetingChatHistory(String username, Long appId) {
AIChatHistoryDTO greetingChatHistory = new AIChatHistoryDTO();
greetingChatHistory.setUsername(username);
greetingChatHistory.setAppId(appId);
greetingChatHistory.setUserInput("");
greetingChatHistory.setStartTime(System.currentTimeMillis());
greetingChatHistory.setPromptTemplateId(null);
Expand All @@ -138,7 +139,7 @@ public Response<AIChatRecord> generalChat(String username,
String scopeType,
String scopeId,
AIGeneralChatReq req) {
AIChatRecord aiChatRecord = chatService.chatWithAI(username, req.getContent());
AIChatRecord aiChatRecord = chatService.chatWithAI(username, appResourceScope.getAppId(), req.getContent());
return Response.buildSuccessResp(aiChatRecord);
}

Expand All @@ -148,7 +149,12 @@ public Response<AIChatRecord> checkScript(String username,
String scopeType,
String scopeId,
AICheckScriptReq req) {
AIChatRecord aiChatRecord = aiCheckScriptService.check(username, req.getType(), req.getContent());
AIChatRecord aiChatRecord = aiCheckScriptService.check(
username,
appResourceScope.getAppId(),
req.getType(),
req.getContent()
);
return Response.buildSuccessResp(aiChatRecord);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public AIChatHistoryDAOImpl(@Qualifier("job-analysis-dsl-context")
public Long insertAIChatHistory(AIChatHistoryDTO aiChatHistoryDTO) {
val query = dslContext.insertInto(defaultTable,
defaultTable.USERNAME,
defaultTable.APP_ID,
defaultTable.USER_INPUT,
defaultTable.PROMPT_TEMPLATE_ID,
defaultTable.AI_INPUT,
Expand All @@ -93,6 +94,7 @@ public Long insertAIChatHistory(AIChatHistoryDTO aiChatHistoryDTO) {
)
.values(
aiChatHistoryDTO.getUsername(),
aiChatHistoryDTO.getAppId(),
aiChatHistoryDTO.getUserInput(),
aiChatHistoryDTO.getPromptTemplateId(),
aiChatHistoryDTO.getAiInput(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public class AIChatHistoryDTO {
*/
private String username;

/**
* Job业务ID
*/
private Long appId;

/**
* 用户输入内容
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ public interface AIChatHistoryService {
* 构建AI聊天记录
*
* @param username 用户名
* @param appId Job业务ID
* @param startTime 开始时间
* @param aiPromptDTO AI提示符信息
* @param status 对话状态
* @param aiAnswer AI回答
* @return AI聊天记录
*/
AIChatHistoryDTO buildAIChatHistoryDTO(String username,
Long appId,
Long startTime,
AIPromptDTO aiPromptDTO,
Integer status,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ public interface AICheckScriptService {
* 检查脚本
*
* @param username 用户名
* @param appId Job业务ID
* @param type 脚本类型
* @param scriptContent 脚本内容
* @return AI对话记录
*/
AIChatRecord check(String username, Integer type, String scriptContent);
AIChatRecord check(String username, Long appId, Integer type, String scriptContent);
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ public interface ChatService {
* 与AI聊天并处理聊天记录保存等逻辑
*
* @param username 用户名
* @param appId Job业务ID
* @param userInput 用户输入
* @return AI对话记录
*/
AIChatRecord chatWithAI(String username, String userInput);
AIChatRecord chatWithAI(String username, Long appId, String userInput);

/**
* 获取最近的聊天记录列表
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public AIChatRecord analyze(String username, Long appId, AIAnalyzeErrorReq req)
if (!taskContext.isTaskFail()) {
return getDirectlyAIChatRecord(
username,
appId,
aiPromptDTO,
aiMessageI18nService.getNotFailTaskAIAnswerMessage()
);
Expand All @@ -94,6 +95,7 @@ public AIChatRecord analyze(String username, Long appId, AIAnalyzeErrorReq req)
if (!taskContext.isTaskFail()) {
return getDirectlyAIChatRecord(
username,
appId,
aiPromptDTO,
aiMessageI18nService.getNotFailTaskAIAnswerMessage()
);
Expand All @@ -102,6 +104,6 @@ public AIChatRecord analyze(String username, Long appId, AIAnalyzeErrorReq req)
throw new InvalidParamException(ErrorCode.AI_ANALYZE_ERROR_ONLY_SUPPORT_SCRIPT_OR_FILE_STEP);
}
AIAnalyzeErrorContextDTO analyzeErrorContext = AIAnalyzeErrorContextDTO.fromAIAnalyzeErrorReq(req);
return getAIChatRecord(username, aiPromptDTO, analyzeErrorContext);
return getAIChatRecord(username, appId, aiPromptDTO, analyzeErrorContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

import com.tencent.bk.job.analysis.model.web.resp.AIAnswer;
import com.tencent.bk.job.analysis.service.ai.AIChatHistoryService;
import com.tencent.bk.job.common.constant.ErrorCode;
import com.tencent.bk.job.common.util.I18nUtil;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.helpers.MessageFormatter;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -84,13 +86,13 @@ public void doHandleAIAnswer(Long recordId, String content, Throwable throwable)
);
} else {
// 3.对话异常
aiAnswer = AIAnswer.failAnswer(content, throwable.getMessage());
String errorContent = I18nUtil.getI18nMessage(String.valueOf(ErrorCode.BK_OPEN_AI_API_DATA_ERROR));
aiAnswer = AIAnswer.failAnswer(errorContent, throwable.getMessage());
int affectedRow = aiChatHistoryService.finishAIAnswer(recordId, aiAnswer);
String message = MessageFormatter.arrayFormat(
"AIAnswer finished(fail), recordId={}, length={}, affectedRow={}",
"AIAnswer finished(fail), recordId={}, affectedRow={}",
new Object[]{
recordId,
content == null ? 0 : content.length(),
affectedRow
}
).getMessage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
import com.tencent.bk.job.analysis.service.ai.context.model.MessagePartEvent;
import com.tencent.bk.job.analysis.util.ai.AIAnswerUtil;
import com.tencent.bk.job.common.constant.ErrorCode;
import com.tencent.bk.job.common.exception.ServiceException;
import com.tencent.bk.job.common.model.Response;
import com.tencent.bk.job.common.model.error.ErrorType;
import com.tencent.bk.job.common.util.TimeUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
Expand Down Expand Up @@ -72,13 +70,19 @@ public AsyncConsumerAndProducerPair buildAsyncConsumerAndProducerPair() {
try {
MessagePartEvent event = messageQueue.poll(90, TimeUnit.SECONDS);
if (event == null) {
throw new ServiceException(ErrorType.TIMEOUT, ErrorCode.BK_OPEN_AI_API_DATA_TIMEOUT);
Response<AIAnswer> respBody =
Response.buildCommonFailResp(ErrorCode.BK_OPEN_AI_API_DATA_TIMEOUT);
respBody.setData(AIAnswer.failAnswer(respBody.getErrorMsg(), respBody.getErrorMsg()));
AIAnswerUtil.setRequestIdAndWriteResp(outputStream, respBody);
break;
}
if (event.isEnd()) {
Throwable throwable = event.getThrowable();
if (throwable != null) {
log.warn("Receive end event with throwable", throwable);
Response<AIAnswer> respBody =
Response.buildCommonFailResp(ErrorCode.BK_OPEN_AI_API_DATA_ERROR);
respBody.setData(AIAnswer.failAnswer(respBody.getErrorMsg(), throwable.getMessage()));
AIAnswerUtil.setRequestIdAndWriteResp(outputStream, respBody);
}
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,34 @@ public AIBaseService(AIChatHistoryService aiChatHistoryService) {
* 使用AI提示符调用AI接口生成AI回答
*
* @param username 用户名
* @param appId Job业务ID
* @param aiPromptDTO AI提示符
* @return AI对话记录
*/
public AIChatRecord getAIChatRecord(String username,
Long appId,
AIPromptDTO aiPromptDTO) {
return getAIChatRecord(username, aiPromptDTO, null);
return getAIChatRecord(username, appId, aiPromptDTO, null);
}

/**
* 使用AI提示符调用AI接口生成AI回答(支持报错分析上下文)
*
* @param username 用户名
* @param appId Job业务ID
* @param aiPromptDTO AI提示符
* @param analyzeErrorContext 报错分析上下文信息
* @return AI对话记录
*/
public AIChatRecord getAIChatRecord(String username,
Long appId,
AIPromptDTO aiPromptDTO,
AIAnalyzeErrorContextDTO analyzeErrorContext) {
long startTime = System.currentTimeMillis();
// 1.插入初始聊天记录
AIChatHistoryDTO aiChatHistoryDTO = aiChatHistoryService.buildAIChatHistoryDTO(
username,
appId,
startTime,
aiPromptDTO,
AIChatStatusEnum.INIT.getStatus(),
Expand All @@ -88,16 +93,18 @@ public AIChatRecord getAIChatRecord(String username,
* 使用指定内容直接生成AI回答
*
* @param username 用户名
* @param appId Job业务ID
* @param aiPromptDTO AI提示符
* @param content 指定内容
* @return AI对话记录
*/
public AIChatRecord getDirectlyAIChatRecord(String username, AIPromptDTO aiPromptDTO, String content) {
public AIChatRecord getDirectlyAIChatRecord(String username, Long appId, AIPromptDTO aiPromptDTO, String content) {
long startTime = System.currentTimeMillis();
aiPromptDTO.setRenderedPrompt(buildAIDirectlyAnswerInput(content));
AIAnswer aiAnswer = new AIAnswer("0", "", content, System.currentTimeMillis());
AIChatHistoryDTO aiChatHistoryDTO = aiChatHistoryService.buildAIChatHistoryDTO(
username,
appId,
startTime,
aiPromptDTO,
AIChatStatusEnum.FINISHED.getStatus(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ public AIChatHistoryServiceImpl(AIChatHistoryDAO aiChatHistoryDAO) {

@Override
public AIChatHistoryDTO buildAIChatHistoryDTO(String username,
Long appId,
Long startTime,
AIPromptDTO aiPromptDTO,
Integer status,
AIAnswer aiAnswer) {
AIChatHistoryDTO aiChatHistoryDTO = new AIChatHistoryDTO();
aiChatHistoryDTO.setUsername(username);
aiChatHistoryDTO.setAppId(appId);
aiChatHistoryDTO.setUserInput(aiPromptDTO.getRawPrompt());
aiChatHistoryDTO.setPromptTemplateId(aiPromptDTO.getPromptTemplateId());
aiChatHistoryDTO.setAiInput(aiPromptDTO.getRenderedPrompt());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ public AICheckScriptServiceImpl(CheckScriptAIPromptService checkScriptAIPromptSe
}

@Override
public AIChatRecord check(String username, Integer type, String scriptContent) {
public AIChatRecord check(String username, Long appId, Integer type, String scriptContent) {
AIPromptDTO aiPromptDTO = checkScriptAIPromptService.getPrompt(type, scriptContent);
return getAIChatRecord(username, aiPromptDTO);
return getAIChatRecord(username, appId, aiPromptDTO);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@
import io.micrometer.core.instrument.util.StringUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;

import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
Expand All @@ -72,12 +74,13 @@ public ChatServiceImpl(AIChatHistoryService aiChatHistoryService,
}

@Override
public AIChatRecord chatWithAI(String username, String userInput) {
public AIChatRecord chatWithAI(String username, Long appId, String userInput) {
Long startTime = System.currentTimeMillis();
// 1.保存初始聊天记录
AIPromptDTO aiPromptDTO = new AIPromptDTO(null, userInput, userInput);
AIChatHistoryDTO aiChatHistoryDTO = aiChatHistoryService.buildAIChatHistoryDTO(
username,
appId,
startTime,
aiPromptDTO,
AIChatStatusEnum.INIT.getStatus(),
Expand Down Expand Up @@ -123,8 +126,11 @@ public StreamingResponseBody generateChatStream(String username, Long recordId)
currentChatHistoryDTO.getAiInput(),
consumerAndProducerPair.getConsumer()
);
Locale locale = LocaleContextHolder.getLocale();
log.debug("language={}", locale.getLanguage());
future.whenComplete((content, throwable) -> {
// 5.处理AI回复内容
LocaleContextHolder.setLocale(locale);
aiAnswerHandler.handleAIAnswer(recordId, content, throwable);
futureMap.remove(recordId);
aiAnswerStreamSynchronizer.triggerEndEvent(throwable);
Expand Down

0 comments on commit f347971

Please sign in to comment.