🔧 refactor: Improve Agent Context & Minor Fixes (#5349)
* refactor: Improve Context for Agents

* 🔧 fix: Safeguard against undefined properties in OpenAIClient response handling

* refactor: log error before re-throwing for original stack trace

* refactor: remove toolResource state from useFileHandling, allow svg files

* refactor: prevent verbose logs from axios errors when using actions

* refactor: add silent method recordTokenUsage in AgentClient

* refactor: streamline token count assignment in BaseClient

* refactor: enhance safety settings handling for Gemini 2.0 model

* fix: capabilities structure in MCPConnection

* refactor: simplify civic integrity threshold handling in GoogleClient and llm

* refactor: update token count retrieval method in BaseClient tests

* ci: fix test for svg
danny-avila authored Jan 17, 2025
1 parent e309c6a commit b35a8b7
Showing 19 changed files with 324 additions and 112 deletions.
26 changes: 23 additions & 3 deletions api/app/clients/BaseClient.js
@@ -4,13 +4,15 @@ const {
supportsBalanceCheck,
isAgentsEndpoint,
isParamEndpoint,
EModelEndpoint,
ErrorTypes,
Constants,
CacheKeys,
Time,
} = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const { truncateToolCallOutputs } = require('./prompts');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
const { getLogStores } = require('~/cache');
@@ -95,7 +97,7 @@ class BaseClient {
* @returns {number}
*/
getTokenCountForResponse(responseMessage) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', responseMessage);
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', responseMessage);
}

/**
@@ -106,7 +108,7 @@
* @returns {Promise<void>}
*/
async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', {
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
promptTokens,
completionTokens,
});
@@ -287,6 +289,9 @@
}

async handleTokenCountMap(tokenCountMap) {
if (this.clientName === EModelEndpoint.agents) {
return;
}
if (this.currentMessages.length === 0) {
return;
}
@@ -394,6 +399,21 @@
_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
let payload = this.addInstructions(formattedMessages, _instructions);
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
if (this.clientName === EModelEndpoint.agents) {
const { dbMessages, editedIndices } = truncateToolCallOutputs(
orderedWithInstructions,
this.maxContextTokens,
this.getTokenCountForMessage.bind(this),
);

if (editedIndices.length > 0) {
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
for (const index of editedIndices) {
payload[index].content = dbMessages[index].content;
}
orderedWithInstructions = dbMessages;
}
}

let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
await this.getMessagesWithinTokenLimit(orderedWithInstructions);
@@ -625,7 +645,7 @@
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts });
} else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = this.getTokenCount(completion);
completionTokens = responseMessage.tokenCount;
}

await this.recordTokenUsage({ promptTokens, completionTokens, usage });
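For orientation, a minimal sketch (not part of the diff) of how the new agents-only branch is meant to work: when the client is the agents endpoint, older tool-call outputs are replaced with a placeholder before the payload is finalized, and the reported indices keep the outgoing payload and the ordered messages in sync. The helper comes from the new api/app/clients/prompts/truncate.js; the wrapper function below is a hypothetical stand-in, not the actual BaseClient method.

// Illustrative sketch only — simplified from BaseClient.buildMessages
const { truncateToolCallOutputs } = require('./prompts');

function applyAgentTruncation(payload, orderedMessages, maxContextTokens, getTokenCountForMessage) {
  const { dbMessages, editedIndices } = truncateToolCallOutputs(
    orderedMessages,
    maxContextTokens,
    getTokenCountForMessage,
  );
  if (editedIndices.length === 0) {
    return { payload, orderedMessages };
  }
  // Copy the truncated content back into the outgoing payload so both views agree
  for (const index of editedIndices) {
    payload[index].content = dbMessages[index].content;
  }
  return { payload, orderedMessages: dbMessages };
}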
28 changes: 19 additions & 9 deletions api/app/clients/GoogleClient.js
@@ -886,32 +886,42 @@ class GoogleClient extends BaseClient {
}

getSafetySettings() {
const isGemini2 = this.modelOptions.model.includes('gemini-2.0');
const mapThreshold = (value) => {
if (isGemini2 && value === 'BLOCK_NONE') {
return 'OFF';
}
return value;
};

return [
{
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold:
threshold: mapThreshold(
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
threshold: mapThreshold(
process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
},
{
category: 'HARM_CATEGORY_HARASSMENT',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
threshold: mapThreshold(
process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold:
threshold: mapThreshold(
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
),
},
{
category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
/**
* Note: this was added since `gemini-2.0-flash-thinking-exp-1219` does not
* accept 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' for 'HARM_CATEGORY_CIVIC_INTEGRITY'
* */
threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE',
threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'),
},
];
}
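As a standalone illustration of the new mapping (a sketch, not the client code itself): for models whose name contains 'gemini-2.0', a 'BLOCK_NONE' threshold is rewritten to 'OFF', while every other value, including the defaults taken from the GOOGLE_SAFETY_* environment variables, passes through unchanged.

// Sketch of the threshold mapping applied inside getSafetySettings()
function mapThreshold(value, model) {
  const isGemini2 = model.includes('gemini-2.0');
  if (isGemini2 && value === 'BLOCK_NONE') {
    return 'OFF';
  }
  return value;
}

// With GOOGLE_SAFETY_CIVIC_INTEGRITY unset, the default 'BLOCK_NONE' becomes
// 'OFF' for Gemini 2.0 models and stays 'BLOCK_NONE' for older models.
console.log(mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE', 'gemini-2.0-flash-exp')); // 'OFF'
console.log(mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE', 'gemini-1.5-pro')); // 'BLOCK_NONE'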
2 changes: 1 addition & 1 deletion api/app/clients/OpenAIClient.js
@@ -1293,7 +1293,7 @@ ${convo}
});

for await (const chunk of stream) {
const token = chunk.choices[0]?.delta?.content || '';
const token = chunk?.choices?.[0]?.delta?.content || '';
intermediateReply.push(token);
onProgress(token);
if (abortController.signal.aborted) {
4 changes: 2 additions & 2 deletions api/app/clients/prompts/index.js
@@ -4,7 +4,7 @@ const summaryPrompts = require('./summaryPrompts');
const handleInputs = require('./handleInputs');
const instructions = require('./instructions');
const titlePrompts = require('./titlePrompts');
const truncateText = require('./truncateText');
const truncate = require('./truncate');
const createVisionPrompt = require('./createVisionPrompt');
const createContextHandlers = require('./createContextHandlers');

@@ -15,7 +15,7 @@ module.exports = {
...handleInputs,
...instructions,
...titlePrompts,
...truncateText,
...truncate,
createVisionPrompt,
createContextHandlers,
};
115 changes: 115 additions & 0 deletions api/app/clients/prompts/truncate.js
@@ -0,0 +1,115 @@
const MAX_CHAR = 255;

/**
* Truncates a given text to a specified maximum length, appending ellipsis and a notification
* if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text.
*/
function truncateText(text, maxLength = MAX_CHAR) {
if (text.length > maxLength) {
return `${text.slice(0, maxLength)}... [text truncated for brevity]`;
}
return text;
}

/**
* Truncates a given text to a specified maximum length by showing the first half and the last half of the text,
* separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition
* of ellipsis and notification if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
*/
function smartTruncateText(text, maxLength = MAX_CHAR) {
const ellipsis = '...';
const notification = ' [text truncated for brevity]';
const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2);

if (text.length > maxLength) {
const startLastHalf = text.length - halfMaxLength;
return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`;
}

return text;
}

/**
* @param {TMessage[]} _messages
* @param {number} maxContextTokens
* @param {function({role: string, content: TMessageContent[]}): number} getTokenCountForMessage
*
* @returns {{
* dbMessages: TMessage[],
* editedIndices: number[]
* }}
*/
function truncateToolCallOutputs(_messages, maxContextTokens, getTokenCountForMessage) {
const THRESHOLD_PERCENTAGE = 0.5;
const targetTokenLimit = maxContextTokens * THRESHOLD_PERCENTAGE;

let currentTokenCount = 3;
const messages = [..._messages];
const processedMessages = [];
let currentIndex = messages.length;
const editedIndices = new Set();
while (messages.length > 0) {
currentIndex--;
const message = messages.pop();
currentTokenCount += message.tokenCount;
if (currentTokenCount < targetTokenLimit) {
processedMessages.push(message);
continue;
}

if (!message.content || !Array.isArray(message.content)) {
processedMessages.push(message);
continue;
}

const toolCallIndices = message.content
.map((item, index) => (item.type === 'tool_call' ? index : -1))
.filter((index) => index !== -1)
.reverse();

if (toolCallIndices.length === 0) {
processedMessages.push(message);
continue;
}

const newContent = [...message.content];

// Truncate all tool outputs since we're over threshold
for (const index of toolCallIndices) {
const toolCall = newContent[index].tool_call;
if (!toolCall || !toolCall.output) {
continue;
}

editedIndices.add(currentIndex);

newContent[index] = {
...newContent[index],
tool_call: {
...toolCall,
output: '[OUTPUT_OMITTED_FOR_BREVITY]',
},
};
}

const truncatedMessage = {
...message,
content: newContent,
tokenCount: getTokenCountForMessage({ role: 'assistant', content: newContent }),
};

processedMessages.push(truncatedMessage);
}

return { dbMessages: processedMessages.reverse(), editedIndices: Array.from(editedIndices) };
}

module.exports = { truncateText, smartTruncateText, truncateToolCallOutputs };
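A brief usage sketch of the new helpers (illustrative values only; the token counter below is a hypothetical stand-in for the client's real getTokenCountForMessage):

const { truncateText, smartTruncateText, truncateToolCallOutputs } = require('./truncate');

// Plain and "smart" truncation of long text
const longText = 'x'.repeat(1000);
console.log(truncateText(longText, 50)); // first 50 chars + '... [text truncated for brevity]'
console.log(smartTruncateText(longText, 50)); // head + '...' + tail + ' [text truncated for brevity]'

// Truncating tool call outputs across a message list
const fakeTokenCounter = (message) => JSON.stringify(message.content).length / 4;
const messages = [
  {
    role: 'assistant',
    tokenCount: 5000,
    content: [{ type: 'tool_call', tool_call: { name: 'search', output: 'very long output...' } }],
  },
  { role: 'user', tokenCount: 20, content: [{ type: 'text', text: 'Summarize the results' }] },
];
const { dbMessages, editedIndices } = truncateToolCallOutputs(messages, 8000, fakeTokenCounter);
// With an 8000-token window the 50% threshold is 4000, so the assistant message's
// tool output is replaced with '[OUTPUT_OMITTED_FOR_BREVITY]' and index 0 is reported.
console.log(editedIndices, dbMessages[0].content[0].tool_call.output);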
40 changes: 0 additions & 40 deletions api/app/clients/prompts/truncateText.js

This file was deleted.

4 changes: 2 additions & 2 deletions api/app/clients/specs/BaseClient.test.js
@@ -615,9 +615,9 @@ describe('BaseClient', () => {
test('getTokenCount for response is called with the correct arguments', async () => {
const tokenCountMap = {}; // Mock tokenCountMap
TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap });
TestClient.getTokenCount = jest.fn();
TestClient.getTokenCountForResponse = jest.fn();
const response = await TestClient.sendMessage('Hello, world!', {});
expect(TestClient.getTokenCount).toHaveBeenCalledWith(response.text);
expect(TestClient.getTokenCountForResponse).toHaveBeenCalledWith(response);
});

test('returns an object with the correct shape', async () => {
2 changes: 2 additions & 0 deletions api/app/clients/tools/util/handleOpenAIErrors.js
@@ -23,6 +23,8 @@ async function handleOpenAIErrors(err, errorCallback, context = 'stream') {
logger.warn(`[OpenAIClient.chatCompletion][${context}] Unhandled error type`);
}

logger.error(err);

if (errorCallback) {
errorCallback(err);
}
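The added logger.error(err) call follows the pattern named in the commit message: log the error where it is first caught so the original stack trace lands in the logs before the error is re-thrown or handed to a callback. A generic sketch of the idea, independent of the OpenAI client:

// Generic sketch: log before re-throwing so the original stack is preserved in logs
async function withErrorLogging(fn, logger) {
  try {
    return await fn();
  } catch (err) {
    // Logging here records the stack trace from where the error was thrown,
    // rather than from wherever it is eventually handled upstream.
    logger.error(err);
    throw err;
  }
}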