diff --git a/plugins/bedrock/bedrock.test.ts b/plugins/bedrock/bedrock.test.ts index 17922dc3..98eea010 100644 --- a/plugins/bedrock/bedrock.test.ts +++ b/plugins/bedrock/bedrock.test.ts @@ -1,6 +1,6 @@ -import { PluginContext, PluginParameters } from '../types'; +import { HookEventType, PluginContext, PluginParameters } from '../types'; import { pluginHandler } from './index'; -import creds from './.creds.json'; +import testCreds from './.creds.json'; import { BedrockParameters } from './type'; /** @@ -19,7 +19,18 @@ import { BedrockParameters } from './type'; describe('Credentials check', () => { test('Should fail withuout accessKey or accessKeySecret', async () => { const context = { - request: { text: 'this is a test string for moderations' }, + request: { + text: 'My email is abc@xyz.com and SSN is 123-45-6789', + json: { + messages: [ + { + role: 'user', + content: 'My email is abc@xyz.com and SSN is 123-45-6789', + }, + ], + }, + }, + requestType: 'chatComplete', }; const parameters: PluginParameters = { credentials: { @@ -46,7 +57,18 @@ describe('Credentials check', () => { test('Should fail with wrong creds', async () => { const context = { - request: { text: 'this is a test string for moderations' }, + request: { + text: 'My email is abc@xyz.com and SSN is 123-45-6789', + json: { + messages: [ + { + role: 'user', + content: 'My email is abc@xyz.com and SSN is 123-45-6789', + }, + ], + }, + }, + requestType: 'chatComplete', }; const parameters: PluginParameters = { credentials: { @@ -71,119 +93,209 @@ describe('Credentials check', () => { expect(result.data).toBeNull(); }); - test('Should be working with word_filter', async () => { - // coding is a blocked custom word. + it('should only detect PII', async () => { + const eventType = 'beforeRequestHook' as HookEventType; const context = { - request: { text: `Hi, do you know coding?` }, + request: { + text: 'My email is abc@xyz.com and SSN is 123-45-6789', + json: { + messages: [ + { + role: 'user', + content: 'My email is abc@xyz.com and SSN is 123-45-6789', + }, + ], + }, + }, + requestType: 'chatComplete', }; - const parameters: PluginParameters = { - ...creds, + const parameters = { + credentials: testCreds, + guardrailId: testCreds.guardrailId, + guardrailVersion: testCreds.guardrailVersion, }; - const result = await pluginHandler.bind({ fn: 'wordFilter' })( - context as unknown as PluginContext, + const result = await pluginHandler( + context as PluginContext, parameters, - 'beforeRequestHook', - { env: {} } + eventType, + { + env: {}, + } ); - expect(result).toBeDefined(); expect(result.verdict).toBe(false); - expect(result.error).toBe(null); + expect(result.error).toBeNull(); expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json).toBeNull(); }); - test('Should be working with content_filter', async () => { - // `kill` is a word that should be blocked under `contentFilter` - // Violence words are not allowed. + it('should detect and redact PII in request text', async () => { const context = { - request: { text: `Can you kill a person?` }, + request: { + text: 'My SSN is 123-45-6789 and some random text', + json: { + messages: [ + { + role: 'user', + content: 'My SSN is 123-45-6789 and some random text', + }, + ], + }, + }, + requestType: 'chatComplete', }; - const parameters: PluginParameters = { - ...creds, + const parameters = { + credentials: testCreds, + redact: true, + guardrailId: testCreds.guardrailId, + guardrailVersion: testCreds.guardrailVersion, }; - const result = await pluginHandler.bind({ fn: 'contentFilter' })( - context as unknown as PluginContext, + const result = await pluginHandler( + context as PluginContext, parameters, 'beforeRequestHook', - { env: {} } + { + env: {}, + } ); + expect(result.error).toBeNull(); + expect(result.verdict).toBe(true); + expect(result.data).toBeDefined(); + expect(result.transformedData?.request?.json?.messages?.[0]?.content).toBe( + 'My SSN is {US_SOCIAL_SECURITY_NUMBER} and some random text' + ); + }); - expect(result).toBeDefined(); - expect(result.verdict).toBe(false); - expect(result.error).toBe(null); + it('should detect and redact PII in request text with multiple content parts', async () => { + const context = { + request: { + text: 'My SSN is 123-45-6789 My SSN is 123-45-6789 and some random text', + json: { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'My SSN is 123-45-6789', + }, + { + type: 'text', + text: 'My SSN is 123-45-6789 and some random text', + }, + ], + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + credentials: testCreds, + redact: true, + guardrailId: testCreds.guardrailId, + guardrailVersion: testCreds.guardrailVersion, + }; + + const result = await pluginHandler( + context as PluginContext, + parameters, + 'beforeRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(true); + expect(result.data).toBeDefined; + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[0]?.text + ).toBe('My SSN is {US_SOCIAL_SECURITY_NUMBER}'); + expect( + result.transformedData?.request?.json?.messages?.[0]?.content?.[1]?.text + ).toBe('My SSN is {US_SOCIAL_SECURITY_NUMBER} and some random text'); + }); + + it('should detect and redact PII in response text', async () => { + const context = { + response: { + text: 'My SSN is 123-45-6789 and some random text', + json: { + choices: [ + { + message: { + role: 'assistant', + content: 'My SSN is 123-45-6789 and some random text', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + credentials: testCreds, + redact: true, + guardrailId: testCreds.guardrailId, + guardrailVersion: testCreds.guardrailVersion, + }; + + const result = await pluginHandler( + context as PluginContext, + parameters, + 'afterRequestHook', + { + env: {}, + } + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(true); expect(result.data).toBeDefined(); + expect( + result.transformedData?.response?.json?.choices?.[0]?.message?.content + ).toBe('My SSN is {US_SOCIAL_SECURITY_NUMBER} and some random text'); }); - // test('Should work fine with redaction for sensitive info', async () => { - // const context = { - // response: { - // json: { - // choices: [ - // { - // message: { - // content: - // 'Hello, John doe. How are you doing?. I see your email is john@doe.com', - // }, - // }, - // ], - // }, - // }, - // requestType: 'chatComplete', - // }; - - // const parameters: PluginParameters = { - // ...creds, - // }; - - // const result = await bedrockPIIHandler( - // context as unknown as PluginContext, - // parameters, - // 'afterRequestHook', - // { env: {} } - // ); - - // const outputMessage = - // result.transformedData?.response.json.choices[0].message.content; - // expect(result).toBeDefined(); - // expect(result.verdict).toBe(true); - // expect(outputMessage).toEqual( - // 'Hello, {NAME}. How are you doing?. I see your email is {EMAIL}\n' - // ); - // }); - - // test('Should work fine with regex redaction for sensitive info', async () => { - // const context = { - // response: { - // json: { - // choices: [ - // { - // message: { - // content: 'bedrock-12121, bedrock-12121', - // }, - // }, - // ], - // }, - // }, - // requestType: 'chatComplete', - // }; - - // const parameters: PluginParameters = { - // ...creds, - // }; - - // const result = await bedrockPIIHandler( - // context as unknown as PluginContext, - // parameters, - // 'afterRequestHook', - // { env: {} } - // ); - - // const outputMessage = - // result.transformedData?.response.json.choices[0].message.content; - // expect(result).toBeDefined(); - // expect(result.verdict).toBe(true); - // expect(outputMessage).toBe('{bedrock-id}, {bedrock-id}\n'); - // }); + it('should pass text without PII', async () => { + const eventType = 'afterRequestHook' as HookEventType; + const context = { + response: { + text: 'Hello world', + json: { + choices: [ + { + message: { + role: 'assistant', + content: 'Hello world', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + }; + const parameters = { + credentials: testCreds, + guardrailId: testCreds.guardrailId, + guardrailVersion: testCreds.guardrailVersion, + }; + + const result = await pluginHandler( + context as PluginContext, + parameters, + eventType, + { + env: {}, + } + ); + expect(result).toBeDefined(); + expect(result.verdict).toBe(true); + expect(result.error).toBeNull(); + expect(result.data).toBeDefined(); + expect(result.transformedData?.response?.json).toBeNull(); + }); }); diff --git a/plugins/bedrock/index.ts b/plugins/bedrock/index.ts index f32b778e..cff6c78a 100644 --- a/plugins/bedrock/index.ts +++ b/plugins/bedrock/index.ts @@ -1,7 +1,6 @@ -import { HookEventType, PluginContext, PluginHandler } from '../types'; +import { PluginHandler } from '../types'; import { getCurrentContentPart, - getText, HttpError, setCurrentContentPart, } from '../utils'; @@ -18,46 +17,24 @@ export const validateCreds = ( ); }; -const transformedData = { - request: { - json: null, - }, - response: { - json: null, - }, -}; - -const handleRedaction = async ( - context: PluginContext, - hookType: HookEventType, - credentials: Record -) => { - const { content, textArray } = getCurrentContentPart(context, hookType); - - if (!content) { - return []; - } - const redactPromises = textArray.map(async (text) => { - const result = await redactPii(text, hookType, credentials); - - if (result) { - setCurrentContentPart(context, hookType, transformedData, result); - } - }); - - await Promise.all(redactPromises); -}; - export const pluginHandler: PluginHandler< BedrockParameters['credentials'] > = async (context, parameters, eventType) => { + const transformedData: Record = { + request: { + json: null, + }, + response: { + json: null, + }, + }; const credentials = parameters.credentials; const validate = validateCreds(credentials); const guardrailVersion = parameters.guardrailVersion; const guardrailId = parameters.guardrailId; - const pii = parameters?.piiCheck as boolean; + const redact = parameters?.redact as boolean; let verdict = true; let error = null; @@ -65,21 +42,11 @@ export const pluginHandler: PluginHandler< if (!validate || !guardrailVersion || !guardrailId) { return { verdict, - error: 'Missing required credentials', + error: { message: 'Missing required credentials' }, data, }; } - if (pii) { - await handleRedaction(context, eventType, { - ...credentials, - guardrailId, - guardrailVersion, - }); - - return { error, data, verdict: true, transformedData }; - } - const body = {} as BedrockBody; if (eventType === 'beforeRequestHook') { @@ -88,37 +55,90 @@ export const pluginHandler: PluginHandler< body.source = 'OUTPUT'; } - body.content = [ - { - text: { - text: getText(context, eventType), - }, - }, - ]; - try { - const response = await bedrockPost( - { ...(credentials as any), guardrailId, guardrailVersion }, - body + const { content, textArray } = getCurrentContentPart(context, eventType); + + if (!content) { + return { + error: { message: 'request or response json is empty' }, + verdict: true, + data: null, + transformedData, + }; + } + + const results = await Promise.all( + textArray.map((text) => + text + ? bedrockPost( + { ...(credentials as any), guardrailId, guardrailVersion }, + { + content: [{ text: { text } }], + source: body.source, + } + ) + : null + ) ); - if (response.action === 'GUARDRAIL_INTERVENED') { - verdict = false; - // Send assessments - data = response.assessments[0] as any; - delete data['invocationMetrics']; - delete data['usage']; + const interventionData = + results.find( + (result) => result && result.action === 'GUARDRAIL_INTERVENED' + ) ?? results[0]; + + const flaggedCategories = new Set(); + + results.forEach((result) => { + if (!result) return; + if (result.assessments[0].contentPolicy?.filters?.length > 0) { + flaggedCategories.add('contentFilter'); + } + if (result.assessments[0].wordPolicy?.customWords?.length > 0) { + flaggedCategories.add('wordFilter'); + } + if (result.assessments[0].wordPolicy?.managedWordLists?.length > 0) { + flaggedCategories.add('wordFilter'); + } + if ( + result.assessments[0].sensitiveInformationPolicy?.piiEntities?.length > + 0 + ) { + flaggedCategories.add('piiFilter'); + } + }); + + let hasPii = flaggedCategories.has('piiFilter'); + if (hasPii && redact) { + const maskedTexts = textArray.map((text, index) => + redactPii(text, results[index]) + ); + + setCurrentContentPart( + context, + eventType, + transformedData, + null, + maskedTexts + ); + } + + if (hasPii && flaggedCategories.size === 1 && redact) { + verdict = true; + } else if (flaggedCategories.size > 0) { + verdict = false; } + data = interventionData; } catch (e) { if (e instanceof HttpError) { - error = e.response.body; + error = { message: e.response.body }; } else { - error = (e as Error).message; + error = { message: (e as Error).message }; } } return { verdict, error, data, + transformedData, }; }; diff --git a/plugins/bedrock/util.ts b/plugins/bedrock/util.ts index fcd82275..2643d54d 100644 --- a/plugins/bedrock/util.ts +++ b/plugins/bedrock/util.ts @@ -1,13 +1,7 @@ import { Sha256 } from '@aws-crypto/sha256-js'; import { SignatureV4 } from '@smithy/signature-v4'; -import { - BedrockBody, - BedrockParameters, - BedrockResponse, - PIIFilter, -} from './type'; +import { BedrockBody, BedrockResponse, PIIFilter } from './type'; import { post } from '../utils'; -import { HookEventType } from '../types'; export const generateAWSHeaders = async ( body: Record, @@ -98,44 +92,27 @@ const replaceMatches = ( * @param credentials * @returns */ -export const redactPii = async ( - text: string, - eventType: HookEventType, - credentials: Record -) => { - const body = {} as BedrockBody; - - if (eventType === 'beforeRequestHook') { - body.source = 'INPUT'; - } else { - body.source = 'OUTPUT'; - } - - body.content = [ - { - text: { - text, - }, - }, - ]; - +export const redactPii = (text: string, result: BedrockResponse | null) => { try { - const response = await bedrockPost({ ...(credentials as any) }, body); + if (!result) return null; + if (!result.assessments[0]?.sensitiveInformationPolicy || !text) { + return null; + } // `ANONYMIZED` means text is already masked by api invokation const isMasked = - response.assessments[0].sensitiveInformationPolicy.piiEntities?.find( + result.assessments[0].sensitiveInformationPolicy.piiEntities?.find( (entity) => entity.action === 'ANONYMIZED' ); - let maskedText = text; + let maskedText: string = text; if (isMasked) { // Use the invoked text directly. - const data = response.output?.[0]; + const data = result.output?.[0]; maskedText = data?.text; } else { // Replace the all entires of each filter sent from api. - response.assessments[0].sensitiveInformationPolicy.piiEntities.forEach( + result.assessments[0].sensitiveInformationPolicy.piiEntities.forEach( (filter) => { maskedText = replaceMatches(filter, maskedText, false); } @@ -144,9 +121,9 @@ export const redactPii = async ( // Replace the all entires of each filter sent from api for regex const isRegexMatch = - response.assessments[0].sensitiveInformationPolicy?.regexes?.length > 0; + result.assessments[0].sensitiveInformationPolicy?.regexes?.length > 0; if (isRegexMatch) { - response.assessments[0].sensitiveInformationPolicy.regexes.forEach( + result.assessments[0].sensitiveInformationPolicy.regexes.forEach( (regex) => { maskedText = replaceMatches(regex as any, maskedText, true); }