Skip to content

Commit

Permalink
working version by patching deserializedr
Browse files Browse the repository at this point in the history
  • Loading branch information
sabrenner committed Jan 24, 2025
1 parent 29c26b0 commit f212b42
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 2 deletions.
57 changes: 57 additions & 0 deletions packages/datadog-instrumentations/src/aws-sdk.js
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,60 @@ addHook({ name: 'aws-sdk', file: 'lib/core.js', versions: ['>=2.1.35'] }, AWS =>
shimmer.wrap(AWS.Request.prototype, 'send', wrapRequest)
return AWS
})
// hooks for bedrock model token counts
// later to add: converse, streamed
const commands = new Set(['InvokeModelCommand'])

function wrapBedrockCommandDeserialize (deserialize) {
return function (response) {
const tokenCh = channel('apm:aws:token:bedrockruntime')

const requestId = response.headers['x-amzn-requestid']
const inputTokenCount = response.headers['x-amzn-bedrock-input-token-count']
const outputTokenCount = response.headers['x-amzn-bedrock-output-token-count']

tokenCh.publish({ requestId, inputTokenCount, outputTokenCount })

return deserialize.apply(this, arguments)
}
}

/**
* TL;DR we want to access the deserialize middleware to intercept the headers before they are stripped from
* the response. This deserialize function is located in different place for different versions of bedrock
*/
addHook({
name: '@aws-sdk/client-bedrock-runtime',
versions: ['>=3.422.0']
}, BedrockRuntime => {
for (const command of commands) {
const Command = BedrockRuntime[command]
shimmer.wrap(Command.prototype, 'deserialize', wrapBedrockCommandDeserialize)
}
return BedrockRuntime
})

// duplicate hook for now
addHook({
name: '@smithy/smithy-client',
versions: ['>=1.0.3']
}, client => {
shimmer.wrap(client.Command, 'classBuilder', classBuilder => {
return function () {
const builder = classBuilder.apply(this, arguments)
shimmer.wrap(builder, 'de', de => {
return function () {
const deserializerName = arguments[0]?.name?.split('de_')[1]
if (commands.has(deserializerName)) {
const originalDeserializer = arguments[0]
arguments[0] = shimmer.wrapFunction(originalDeserializer, wrapBedrockCommandDeserialize)
}

return de.apply(this, arguments)
}
})
return builder
}
})
return client
})
1 change: 1 addition & 0 deletions packages/datadog-instrumentations/src/helpers/hooks.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module.exports = {
'@apollo/gateway': () => require('../apollo'),
'apollo-server-core': () => require('../apollo-server-core'),
'@aws-sdk/smithy-client': () => require('../aws-sdk'),
'@aws-sdk/client-bedrock-runtime': () => require('../aws-sdk'),
'@azure/functions': () => require('../azure-functions'),
'@cucumber/cucumber': () => require('../cucumber'),
'@playwright/test': () => require('../playwright'),
Expand Down
32 changes: 32 additions & 0 deletions packages/dd-trace/src/llmobs/plugins/bedrockruntime.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ const {

const enabledOperations = ['invokeModel']

const requestIdsToTokens = {}

class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
constructor () {
super(...arguments)
Expand All @@ -30,6 +32,13 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
const span = storage.getStore()?.span
this.setLLMObsTags({ request, span, response, modelProvider, modelName })
})

this.addSub('apm:aws:token:bedrockruntime', ({ requestId, inputTokenCount, outputTokenCount }) => {
requestIdsToTokens[requestId] = {
inputTokenCount,
outputTokenCount
}
})
}

setLLMObsTags ({ request, span, response, modelProvider, modelName }) {
Expand All @@ -53,6 +62,29 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {

// add I/O tags
this._tagger.tagLLMIO(span, requestParams.prompt, textAndResponseReason.message)

// add token metrics
const { inputTokens, outputTokens, totalTokens } = this.extractTokens({ response })
this._tagger.tagMetrics(span, {
inputTokens,
outputTokens,
totalTokens
})
}

extractTokens ({ response }) {
const requestId = response.$metadata.requestId
const { inputTokenCount, outputTokenCount } = requestIdsToTokens[requestId] || {}
delete requestIdsToTokens[requestId]

const inputTokens = parseInt(inputTokenCount) || 0
const outputTokens = parseInt(outputTokenCount) || 0

return {
inputTokens,
outputTokens,
totalTokens: inputTokens + outputTokens
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ describe('Plugin', () => {

nock('http://127.0.0.1:4566')
.post(`/model/${model.modelId}/invoke`)
.reply(200, response)
.reply(200, response, {
'x-amzn-bedrock-input-token-count': 50,
'x-amzn-bedrock-output-token-count': 70,
'x-amzn-requestid': Date.now().toString()
})

const command = new AWS.InvokeModelCommand(request)

Expand All @@ -93,7 +97,7 @@ describe('Plugin', () => {
{ content: model.userPrompt }
],
outputMessages: MOCK_ANY,
tokenMetrics: { input_tokens: 0, output_tokens: 0, total_tokens: 0 },
tokenMetrics: { input_tokens: 50, output_tokens: 70, total_tokens: 120 },
modelName: model.modelId.split('.')[1].toLowerCase(),
modelProvider: model.provider.toLowerCase(),
metadata: {
Expand Down

0 comments on commit f212b42

Please sign in to comment.