Skip to content

Commit

Permalink
🌿 fix: forking a long conversation breaks chat structure (#4778)
Browse files Browse the repository at this point in the history
* fix: branching and forking sometimes break conversation structure

* fix test for forking.

* chore: message type issues

* test: add conversation structure tests for message handling

---------

Co-authored-by: xyqyear <[email protected]>
  • Loading branch information
danny-avila and xyqyear authored Nov 22, 2024
1 parent 7d5be68 commit c87a51e
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 16 deletions.
4 changes: 3 additions & 1 deletion api/models/Message.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,17 @@ async function saveMessage(req, params, metadata) {
* @async
* @function bulkSaveMessages
* @param {Object[]} messages - An array of message objects to save.
* @param {boolean} [overrideTimestamp=false] - Indicates whether to override the timestamps of the messages. Defaults to false.
* @returns {Promise<Object>} The result of the bulk write operation.
* @throws {Error} If there is an error in saving messages in bulk.
*/
async function bulkSaveMessages(messages) {
async function bulkSaveMessages(messages, overrideTimestamp=false) {
try {
const bulkOps = messages.map((message) => ({
updateOne: {
filter: { messageId: message.messageId },
update: message,
timestamps: !overrideTimestamp,
upsert: true,
},
}));
Expand Down
223 changes: 223 additions & 0 deletions api/models/convoStructure.spec.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Message, getMessages, bulkSaveMessages } = require('./Message');

// Original version of buildTree function
/**
 * Builds a forest of message nodes from a flat, ordered list of messages.
 *
 * Messages are processed strictly in the order given. A message is attached to
 * its parent only if that parent has already been seen; otherwise it is
 * returned as a root. Input order therefore determines the resulting shape.
 *
 * @param {Object} params
 * @param {Object[]|null} params.messages - Flat list of messages; each is read for
 *   `messageId`, `parentMessageId` and (optionally) `files`.
 * @param {Object} [params.fileMap] - Optional lookup keyed by `file_id`; when given,
 *   each entry of `message.files` is replaced by the mapped value if one exists.
 * @returns {Object[]|null} Root nodes (each a shallow copy of its message plus
 *   `children`, `depth` and `siblingIndex`), or `null` when `messages` is `null`.
 */
function buildTree({ messages, fileMap }) {
  if (messages === null) {
    return null;
  }

  const nodesById = {};
  const roots = [];
  const seenPerParent = {};

  messages.forEach((msg) => {
    const parentKey = msg.parentMessageId ?? '';

    // Position of this message among the siblings seen so far under the same parent.
    const siblingIndex = seenPerParent[parentKey] || 0;
    seenPerParent[parentKey] = siblingIndex + 1;

    const node = { ...msg, children: [], depth: 0, siblingIndex };

    if (msg.files && fileMap) {
      node.files = msg.files.map((file) => fileMap[file.file_id ?? ''] ?? file);
    }

    // Register the node before resolving its parent, as the lookup below reads the same map.
    nodesById[msg.messageId] = node;

    const parent = nodesById[parentKey];
    if (!parent) {
      // Parent unknown (or not yet seen): the node becomes a root.
      roots.push(node);
      return;
    }

    node.depth = parent.depth + 1;
    parent.children.push(node);
  });

  return roots;
}

// Handle to the MongoMemoryServer instance shared by every test in this file.
let mongod;

// Start the MongoMemoryServer once and connect mongoose to the URI it reports.
beforeAll(async () => {
  mongod = await MongoMemoryServer.create();
  await mongoose.connect(mongod.getUri());
});

// Disconnect mongoose first, then stop the server it was connected to.
afterAll(async () => {
  await mongoose.disconnect();
  await mongod.stop();
});

// Each test starts with no Message documents.
beforeEach(async () => {
  await Message.deleteMany({});
});

// Checks how the createdAt values stored by bulkSaveMessages affect the tree that
// buildTree produces from getMessages output. Two tests reproduce the corruption
// (multiple roots); two tests verify a single linear chain of 20 messages.
describe('Conversation Structure Tests', () => {
test('Conversation folding/corrupting with inconsistent timestamps', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';

// Create messages with inconsistent timestamps
const messages = [
{
messageId: 'message0',
parentMessageId: null,
text: 'Message 0',
createdAt: new Date('2023-01-01T00:00:00Z'),
},
{
messageId: 'message1',
parentMessageId: 'message0',
text: 'Message 1',
createdAt: new Date('2023-01-01T00:02:00Z'),
},
{
messageId: 'message2',
parentMessageId: 'message1',
text: 'Message 2',
createdAt: new Date('2023-01-01T00:01:00Z'),
}, // Note: createdAt is earlier than that of its parent (message1)
{
messageId: 'message3',
parentMessageId: 'message1',
text: 'Message 3',
createdAt: new Date('2023-01-01T00:03:00Z'),
},
{
messageId: 'message4',
parentMessageId: 'message2',
text: 'Message 4',
createdAt: new Date('2023-01-01T00:04:00Z'),
},
];

// Add common properties to all messages
messages.forEach((msg) => {
msg.conversationId = conversationId;
msg.user = userId;
msg.isCreatedByUser = false;
msg.error = false;
msg.unfinished = false;
});

// Save messages with overrideTimestamp = true (second argument), so the
// inconsistent createdAt values defined above are expected to be kept as given
await bulkSaveMessages(messages, true);

// Retrieve messages (this will sort by createdAt)
const retrievedMessages = await getMessages({ conversationId, user: userId });

// Build tree
const tree = buildTree({ messages: retrievedMessages });

// Check if the tree is incorrect (folded/corrupted): with createdAt ordering,
// message2 precedes its parent message1, and buildTree only attaches a message
// to a parent it has already processed
expect(tree.length).toBeGreaterThan(1); // Should have multiple root messages, indicating corruption
});

test('Fix: Conversation structure maintained with more than 16 messages', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';

// Create more than 16 messages forming a single chain (message i is the child of
// message i-1); createdAt alternates between future (even i) and past (odd i) offsets
const messages = Array.from({ length: 20 }, (_, i) => ({
messageId: `message${i}`,
parentMessageId: i === 0 ? null : `message${i - 1}`,
conversationId,
user: userId,
text: `Message ${i}`,
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 500000 : -i * 500000)),
}));

// Save messages with overrideTimestamp omitted (defaults to false): new timestamps
// are generated and the createdAt values on the message objects are ignored
await bulkSaveMessages(messages);

// Retrieve messages (this will sort by createdAt, but it shouldn't matter now)
const retrievedMessages = await getMessages({ conversationId, user: userId });

// Build tree
const tree = buildTree({ messages: retrievedMessages });

// Check if the tree is correct
expect(tree.length).toBe(1); // Should have only one root message
let currentNode = tree[0];
// Walk down the chain: every node must have exactly one child, in message order
for (let i = 1; i < 20; i++) {
expect(currentNode.children.length).toBe(1);
currentNode = currentNode.children[0];
expect(currentNode.text).toBe(`Message ${i}`);
}
expect(currentNode.children.length).toBe(0); // Last message should have no children
});

test('Simulate MongoDB ordering issue with more than 16 messages and close timestamps', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';

// Create more than 16 messages with very close timestamps: even indices get
// now + i ms and odd indices get now - i ms, so createdAt order does not follow
// the parent -> child chain
const messages = Array.from({ length: 20 }, (_, i) => ({
messageId: `message${i}`,
parentMessageId: i === 0 ? null : `message${i - 1}`,
conversationId,
user: userId,
text: `Message ${i}`,
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 1 : -i * 1)),
}));

// Add common properties to all messages
messages.forEach((msg) => {
msg.isCreatedByUser = false;
msg.error = false;
msg.unfinished = false;
});

// Save with overrideTimestamp = true so the out-of-order createdAt values are kept
await bulkSaveMessages(messages, true);
const retrievedMessages = await getMessages({ conversationId, user: userId });
const tree = buildTree({ messages: retrievedMessages });
// More than one root means the chain was broken by the retrieval order
expect(tree.length).toBeGreaterThan(1);
});

test('Fix: Preserve order with more than 16 messages by maintaining original timestamps', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';

// Create more than 16 messages with distinct timestamps
const messages = Array.from({ length: 20 }, (_, i) => ({
messageId: `message${i}`,
parentMessageId: i === 0 ? null : `message${i - 1}`,
conversationId,
user: userId,
text: `Message ${i}`,
createdAt: new Date(Date.now() + i * 1000), // Ensure each message has a distinct timestamp
}));

// Add common properties to all messages
messages.forEach((msg) => {
msg.isCreatedByUser = false;
msg.error = false;
msg.unfinished = false;
});

// Save messages with overriding timestamps (preserve original timestamps)
await bulkSaveMessages(messages, true);

// Retrieve messages (this will sort by createdAt)
const retrievedMessages = await getMessages({ conversationId, user: userId });

// Build tree
const tree = buildTree({ messages: retrievedMessages });

// Check if the tree is correct
expect(tree.length).toBe(1); // Should have only one root message
let currentNode = tree[0];
// Walk down the chain: every node must have exactly one child, in message order
for (let i = 1; i < 20; i++) {
expect(currentNode.children.length).toBe(1);
currentNode = currentNode.children[0];
expect(currentNode.text).toBe(`Message ${i}`);
}
expect(currentNode.children.length).toBe(0); // Last message should have no children
});
});
8 changes: 4 additions & 4 deletions api/server/utils/import/fork.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
),
), true,
);
});

Expand All @@ -122,7 +122,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
),
), true,
);
});

Expand All @@ -141,7 +141,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
),
), true,
);
});

Expand All @@ -160,7 +160,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
),
), true,
);
});

Expand Down
2 changes: 1 addition & 1 deletion api/server/utils/import/importBatchBuilder.js
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ImportBatchBuilder {
async saveBatch() {
try {
await bulkSaveConvos(this.conversations);
await bulkSaveMessages(this.messages);
await bulkSaveMessages(this.messages, true);
logger.debug(
`user: ${this.requestUserId} | Added ${this.conversations.length} conversations and ${this.messages.length} messages to the DB.`,
);
Expand Down
4 changes: 2 additions & 2 deletions client/src/components/Chat/Messages/SearchMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ export default function Message({ message }: Pick<TMessageProps, 'message'>) {
let messageLabel = '';
if (isCreatedByUser) {
messageLabel = UsernameDisplay
? (user?.name ?? '') || user?.username
? (user?.name ?? '') || (user?.username ?? '')
: localize('com_user_message');
} else {
messageLabel = message.sender;
messageLabel = message.sender || '';
}

return (
Expand Down
15 changes: 11 additions & 4 deletions client/src/components/Share/Message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,20 @@ export default function Message(props: TMessageProps) {
return null;
}

const { text, children, messageId = null, isCreatedByUser, error, unfinished } = message ?? {};
const {
text = '',
children,
messageId = null,
isCreatedByUser = true,
error = false,
unfinished = false,
} = message;

let messageLabel = '';
if (isCreatedByUser) {
messageLabel = 'anonymous';
} else {
messageLabel = message.sender;
messageLabel = message.sender || '';
}

return (
Expand Down Expand Up @@ -67,12 +74,12 @@ export default function Message(props: TMessageProps) {
error={error}
isLast={false}
ask={() => ({})}
text={text ?? ''}
text={text}
message={message}
isSubmitting={false}
enterEdit={() => ({})}
unfinished={!!unfinished}
isCreatedByUser={isCreatedByUser ?? true}
isCreatedByUser={isCreatedByUser}
siblingIdx={siblingIdx ?? 0}
setSiblingIdx={setSiblingIdx ?? (() => ({}))}
/>
Expand Down
4 changes: 2 additions & 2 deletions client/src/hooks/Conversations/useExportConversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ export default function useExportConversation({
};

if (!message.content) {
return formatText(message.sender, message.text);
return formatText(message.sender || '', message.text);
}

return message.content
.map((content) => getMessageContent(message.sender, content))
.map((content) => getMessageContent(message.sender || '', content))
.map((text) => {
return formatText(text[0], text[1]);
})
Expand Down
4 changes: 2 additions & 2 deletions packages/data-provider/src/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,12 @@ export const tMessageSchema = z.object({
bg: z.string().nullable().optional(),
model: z.string().nullable().optional(),
title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'),
sender: z.string(),
sender: z.string().optional(),
text: z.string(),
generation: z.string().nullable().optional(),
isEdited: z.boolean().optional(),
isCreatedByUser: z.boolean(),
error: z.boolean(),
error: z.boolean().optional(),
createdAt: z
.string()
.optional()
Expand Down

0 comments on commit c87a51e

Please sign in to comment.