Skip to content

Commit

Permalink
[Obs AI Assistant] Register and retrieve get_data_on_screen function …
Browse files Browse the repository at this point in the history
…description as an adhoc instruction (elastic#184214)
  • Loading branch information
viduni94 committed Oct 30, 2024
1 parent 0ce486a commit 9b6cbae
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const getFunctionsRoute = createObservabilityAIAssistantServerRoute({
systemMessage: getSystemMessageFromInstructions({
applicationInstructions: functionClient.getInstructions(),
userInstructions,
adHocInstructions: [],
adHocInstructions: functionClient.getAdhocInstructions(),
availableFunctionNames,
}),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dedent from 'dedent';
import { ChatFunctionClient, GET_DATA_ON_SCREEN_FUNCTION_NAME } from '.';
import { FunctionVisibility } from '../../../common/functions/types';
import { AdHocInstruction } from '../../../common/types';

describe('chatFunctionClient', () => {
describe('when executing a function with invalid arguments', () => {
Expand Down Expand Up @@ -86,6 +87,7 @@ describe('chatFunctionClient', () => {
]);

const functions = client.getFunctions();
const adHocInstructions = client.getAdhocInstructions();

expect(functions[0]).toEqual({
definition: {
Expand All @@ -97,7 +99,7 @@ describe('chatFunctionClient', () => {
respond: expect.any(Function),
});

expect(functions[0].definition.description).toContain(
expect(adHocInstructions[0].text).toContain(
dedent(`my_dummy_data: My dummy data
my_other_dummy_data: My other dummy data
`)
Expand Down Expand Up @@ -128,4 +130,52 @@ describe('chatFunctionClient', () => {
});
});
});

describe('when adhoc instructions are provided', () => {
let client: ChatFunctionClient;

beforeEach(() => {
client = new ChatFunctionClient([]);
});

describe('register an adhoc Instruction', () => {
it('should register a new adhoc instruction', () => {
const adhocInstruction: AdHocInstruction = {
text: 'Test adhoc instruction',
instruction_type: 'application_instruction',
};

client.registerAdhocInstruction(adhocInstruction);

expect(client.getAdhocInstructions()).toContainEqual(adhocInstruction);
});
});

describe('retrieve adHoc instructions', () => {
it('should return all registered adhoc instructions', () => {
const firstAdhocInstruction: AdHocInstruction = {
text: 'First adhoc instruction',
instruction_type: 'application_instruction',
};

const secondAdhocInstruction: AdHocInstruction = {
text: 'Second adhoc instruction',
instruction_type: 'application_instruction',
};

client.registerAdhocInstruction(firstAdhocInstruction);
client.registerAdhocInstruction(secondAdhocInstruction);

const adhocInstructions = client.getAdhocInstructions();

expect(adhocInstructions).toEqual([firstAdhocInstruction, secondAdhocInstruction]);
});

it('should return an empty array if no adhoc instructions are registered', () => {
const adhocInstructions = client.getAdhocInstructions();

expect(adhocInstructions).toEqual([]);
});
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@ import Ajv, { type ErrorObject, type ValidateFunction } from 'ajv';
import dedent from 'dedent';
import { compact, keyBy } from 'lodash';
import { FunctionVisibility, type FunctionResponse } from '../../../common/functions/types';
import type { Message, ObservabilityAIAssistantScreenContextRequest } from '../../../common/types';
import type {
AdHocInstruction,
Message,
ObservabilityAIAssistantScreenContextRequest,
} from '../../../common/types';
import { filterFunctionDefinitions } from '../../../common/utils/filter_function_definitions';
import type {
FunctionCallChatFunction,
FunctionHandler,
FunctionHandlerRegistry,
InstructionOrCallback,
RegisterAdHocInstruction,
RegisterFunction,
RegisterInstruction,
} from '../types';
Expand All @@ -35,6 +40,8 @@ export const GET_DATA_ON_SCREEN_FUNCTION_NAME = 'get_data_on_screen';

export class ChatFunctionClient {
private readonly instructions: InstructionOrCallback[] = [];
private readonly adhocInstructions: AdHocInstruction[] = [];

private readonly functionRegistry: FunctionHandlerRegistry = new Map();
private readonly validators: Map<string, ValidateFunction> = new Map();

Expand All @@ -49,9 +56,7 @@ export class ChatFunctionClient {
this.registerFunction(
{
name: GET_DATA_ON_SCREEN_FUNCTION_NAME,
description: dedent(`Get data that is on the screen:
${allData.map((data) => `${data.name}: ${data.description}`).join('\n')}
`),
description: dedent('Get data that is on the screen'),
visibility: FunctionVisibility.AssistantOnly,
parameters: {
type: 'object',
Expand All @@ -75,6 +80,13 @@ export class ChatFunctionClient {
};
}
);

this.registerAdhocInstruction({
text: `The ${GET_DATA_ON_SCREEN_FUNCTION_NAME} function will ${dedent(`Get data that is on the screen:
${allData.map((data) => `${data.name}: ${data.description}`).join('\n')}
`)}`,
instruction_type: 'application_instruction',
});
}

this.actions.forEach((action) => {
Expand All @@ -95,6 +107,10 @@ export class ChatFunctionClient {
this.instructions.push(instruction);
};

registerAdhocInstruction: RegisterAdHocInstruction = (instruction: AdHocInstruction) => {
this.adhocInstructions.push(instruction);
};

validate(name: string, parameters: unknown) {
const validator = this.validators.get(name)!;
if (!validator) {
Expand All @@ -111,6 +127,10 @@ export class ChatFunctionClient {
return this.instructions;
}

getAdhocInstructions(): AdHocInstruction[] {
return this.adhocInstructions;
}

hasAction(name: string) {
return !!this.actions.find((action) => action.name === name)!;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import { eventsourceStreamIntoObservable } from '../../util/eventsource_stream_i
import { LlmApiAdapterFactory } from './types';
import { parseInlineFunctionCalls } from './simulate_function_calling/parse_inline_function_calls';
import { getMessagesWithSimulatedFunctionCalling } from './simulate_function_calling/get_messages_with_simulated_function_calling';
import { GET_DATA_ON_SCREEN_FUNCTION_NAME } from '../../chat_function_client';

function getOpenAIPromptTokenCount({
messages,
Expand Down Expand Up @@ -124,10 +123,7 @@ export const createOpenAiAdapter: LlmApiAdapterFactory = ({
...(!!functionsForOpenAI?.length
? {
tools: functionsForOpenAI.map((fn) => ({
function:
fn.name === GET_DATA_ON_SCREEN_FUNCTION_NAME
? pick(fn, 'name', 'parameters')
: pick(fn, 'name', 'description', 'parameters'),
function: pick(fn, 'name', 'description', 'parameters'),
type: 'function',
})),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ describe('Observability AI Assistant client', () => {
getActions: jest.fn(),
validate: jest.fn(),
getInstructions: jest.fn(),
getAdhocInstructions: jest.fn(),
} as any;

let llmSimulator: LlmSimulator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ import {
} from '../../../common/conversation_complete';
import { CompatibleJSONSchema } from '../../../common/functions/types';
import {
AdHocInstruction,
type Conversation,
type ConversationCreateRequest,
type ConversationUpdateRequest,
type KnowledgeBaseEntry,
type Message,
type AdHocInstruction,
} from '../../../common/types';
import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
Expand Down Expand Up @@ -210,6 +210,9 @@ export class ObservabilityAIAssistantClient {

const userInstructions$ = from(this.getKnowledgeBaseUserInstructions()).pipe(shareReplay());

const registeredAdhocInstructions = functionClient.getAdhocInstructions();
const allAdHocInstructions = adHocInstructions.concat(registeredAdhocInstructions);

// from the initial messages, override any system message with
// the one that is based on the instructions (registered, request, kb)
const messagesWithUpdatedSystemMessage$ = userInstructions$.pipe(
Expand All @@ -219,7 +222,7 @@ export class ObservabilityAIAssistantClient {
getSystemMessageFromInstructions({
applicationInstructions: functionClient.getInstructions(),
userInstructions,
adHocInstructions,
adHocInstructions: allAdHocInstructions,
availableFunctionNames: functionClient
.getFunctions()
.map((fn) => fn.definition.name),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ export function continueConversation({
chat,
signal,
functionCallsLeft,
adHocInstructions,
adHocInstructions = [],
userInstructions,
logger,
disableFunctions,
Expand Down Expand Up @@ -213,11 +213,14 @@ export function continueConversation({
disableFunctions,
});

const registeredAdhocInstructions = functionClient.getAdhocInstructions();
const allAdHocInstructions = adHocInstructions.concat(registeredAdhocInstructions);

const messagesWithUpdatedSystemMessage = replaceSystemMessage(
getSystemMessageFromInstructions({
applicationInstructions: functionClient.getInstructions(),
userInstructions,
adHocInstructions,
adHocInstructions: allAdHocInstructions,
availableFunctionNames: definitions.map((def) => def.name),
}),
initialMessages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import type {
Message,
ObservabilityAIAssistantScreenContextRequest,
InstructionOrPlainText,
AdHocInstruction,
} from '../../common/types';
import type { ObservabilityAIAssistantRouteHandlerResources } from '../routes/types';
import { ChatFunctionClient } from './chat_function_client';
Expand Down Expand Up @@ -76,6 +77,8 @@ export type RegisterInstructionCallback = ({

export type RegisterInstruction = (...instruction: InstructionOrCallback[]) => void;

export type RegisterAdHocInstruction = (...instruction: AdHocInstruction[]) => void;

export type RegisterFunction = <
TParameters extends CompatibleJSONSchema = any,
TResponse extends FunctionResponse = any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export function getSystemMessageFromInstructions({

const adHocInstructionsWithId = adHocInstructions.map((adHocInstruction) => ({
...adHocInstruction,
doc_id: adHocInstruction.doc_id ?? v4(),
doc_id: adHocInstruction?.doc_id ?? v4(),
}));

// split ad hoc instructions into user instructions and application instructions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ If available, include the link of the conversation at the end of your answer.`
availableFunctionNames: functionClient.getFunctions().map((fn) => fn.definition.name),
applicationInstructions: functionClient.getInstructions(),
userInstructions: [],
adHocInstructions: [],
adHocInstructions: functionClient.getAdhocInstructions(),
}),
},
},
Expand Down

0 comments on commit 9b6cbae

Please sign in to comment.