Skip to content

Commit

Permalink
Merge pull request #1630 from gchq/feature/import-model-cards-auditing
Browse files Browse the repository at this point in the history
Allow for auditing of new model cards during import
  • Loading branch information
ARADDCC012 authored Nov 25, 2024
2 parents daad354 + 0cc6435 commit 5f6b52d
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 111 deletions.
5 changes: 3 additions & 2 deletions backend/src/connectors/audit/Base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,18 @@ export abstract class BaseAuditConnector {
abstract onCreateS3Export(req: Request, modelId: string, semvers?: string[])
abstract onCreateImport(
req: Request,
mirroredModelId: string,
mirroredModel: ModelInterface,
sourceModelId: string,
modelCardVersions: number[],
exporter: string,
newModelCards: ModelCardInterface[],
)

abstract onError(req: Request, error: BailoError)

checkEventType(auditInfo: AuditInfoKeys, req: Request) {
if (auditInfo.typeId !== req.audit.typeId && auditInfo.description !== req.audit.description) {
throw new Error(`Audit: Expected type '${JSON.stringify(auditInfo)}' but recieved '${JSON.stringify(req.audit)}'`)
throw new Error(`Audit: Expected type '${JSON.stringify(auditInfo)}' but received '${JSON.stringify(req.audit)}'`)
}
}
}
4 changes: 3 additions & 1 deletion backend/src/connectors/audit/silly.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { AccessRequestDoc } from '../../models/AccessRequest.js'
import { FileInterface, FileInterfaceDoc } from '../../models/File.js'
import { InferenceDoc } from '../../models/Inference.js'
import { ModelCardInterface, ModelDoc, ModelInterface } from '../../models/Model.js'
import { ModelCardRevisionInterface } from '../../models/ModelCardRevision.js'
import { ReleaseDoc } from '../../models/Release.js'
import { ResponseInterface } from '../../models/Response.js'
import { ReviewInterface } from '../../models/Review.js'
Expand Down Expand Up @@ -59,10 +60,11 @@ export class SillyAuditConnector extends BaseAuditConnector {
onCreateS3Export(_req: Request, _modelId: string, _semvers?: string[]) {}
onCreateImport(
_req: Request,
_mirroredModelId: string,
_mirroredModel: ModelInterface,
_sourceModelId: string,
_modelCardVersions: number[],
_exporter: string,
_newModelCards: ModelCardRevisionInterface[],
) {}
onError(_req: Request, _error: BailoError) {}
onCreateCommentResponse(_req: Request, _responseInterface: ResponseInterface) {}
Expand Down
6 changes: 4 additions & 2 deletions backend/src/connectors/audit/stdout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { AccessRequestDoc } from '../../models/AccessRequest.js'
import { FileInterface, FileInterfaceDoc } from '../../models/File.js'
import { InferenceDoc } from '../../models/Inference.js'
import { ModelCardInterface, ModelDoc, ModelInterface } from '../../models/Model.js'
import { ModelCardRevisionInterface } from '../../models/ModelCardRevision.js'
import { ReleaseDoc } from '../../models/Release.js'
import { ResponseInterface } from '../../models/Response.js'
import { ReviewInterface } from '../../models/Review.js'
Expand Down Expand Up @@ -342,13 +343,14 @@ export class StdoutAuditConnector extends BaseAuditConnector {

onCreateImport(
req: Request,
mirroredModelId: string,
mirroredModel: ModelInterface,
sourceModelId: string,
modelCardVersions: number[],
exporter: string,
newModelCards: ModelCardRevisionInterface[],
) {
this.checkEventType(AuditInfo.CreateImport, req)
const event = this.generateEvent(req, { mirroredModelId, sourceModelId, modelCardVersions, exporter })
const event = this.generateEvent(req, { mirroredModel, sourceModelId, modelCardVersions, exporter, newModelCards })
req.log.info(event, req.audit.description)
}
}
12 changes: 2 additions & 10 deletions backend/src/models/ModelCardRevision.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
import { Document, model, Schema } from 'mongoose'

import { ModelMetadata } from './Model.js'
import { ModelCardInterface } from './Model.js'

// This interface stores information about the properties on the base object.
// It should be used for plain object representations, e.g. for sending to the
// client.
export interface ModelCardRevisionInterface {
export interface ModelCardRevisionInterface extends ModelCardInterface {
modelId: string
schemaId: string

version: number
metadata: ModelMetadata

createdBy: string
createdAt: Date
updatedAt: Date
}

// The doc type includes all values in the plain interface, as well as all the
Expand Down
6 changes: 3 additions & 3 deletions backend/src/routes/v2/model/modelcard/getModelCardHtml.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { z } from 'zod'

import { AuditInfo } from '../../../../connectors/audit/Base.js'
import audit from '../../../../connectors/audit/index.js'
import { renderToHtml } from '../../../../services/modelCardExport.js'
import { getModelCardHtml as getModelCardHtmlService } from '../../../../services/modelCardExport.js'
import { registerPath } from '../../../../services/specification.js'
import { GetModelCardVersionOptions } from '../../../../types/enums.js'
import { parse } from '../../../../utils/validate.js'
Expand Down Expand Up @@ -46,8 +46,8 @@ export const getModelCardHtml = [
params: { modelId, version },
} = parse(req, getModelCardHtmlSchema)

const { html, card } = await renderToHtml(req.user, modelId, version)
await audit.onViewModelCard(req, modelId, card)
const { html, modelCard } = await getModelCardHtmlService(req.user, modelId, version)
await audit.onViewModelCard(req, modelId, modelCard)

return res.send(html)
},
Expand Down
17 changes: 9 additions & 8 deletions backend/src/routes/v2/model/postRequestImport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@ export const postRequestImportFromS3 = [
body: { payloadUrl, mirroredModelId, exporter },
} = parse(req, postRequestImportFromS3Schema)

const importInfo = await importModel(req.user, mirroredModelId, payloadUrl)
await audit.onCreateImport(
req,
importInfo.mirroredModelId,
importInfo.sourceModelId,
importInfo.modelCardVersions,
exporter,
const { mirroredModel, sourceModelId, modelCardVersions, newModelCards } = await importModel(
mirroredModelId,
payloadUrl,
)
await audit.onCreateImport(req, mirroredModel, sourceModelId, modelCardVersions, exporter, newModelCards)

return res.json(importInfo)
return res.json({
mirroredModelId: mirroredModel.id,
sourceModelId,
modelCardVersions,
})
},
]
21 changes: 15 additions & 6 deletions backend/src/services/mirroredModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ export async function exportModel(
log.debug({ modelId, semvers }, 'Successfully finalized zip file.')
}

export async function importModel(_user: UserInterface, mirroredModelId: string, payloadUrl: string) {
export async function importModel(mirroredModelId: string, payloadUrl: string) {
if (!config.ui.modelMirror.import.enabled) {
throw BadReq('Importing models has not been enabled.')
}

if (mirroredModelId === '') {
throw BadReq('Missing mirrored model ID.')
}
let sourceModelId
let sourceModelId = ''

log.info({ mirroredModelId, payloadUrl }, 'Received a request to import a model.')

Expand Down Expand Up @@ -135,14 +135,23 @@ export async function importModel(_user: UserInterface, mirroredModelId: string,

log.info({ mirroredModelId, payloadUrl, sourceModelId }, 'Finished parsing the collection of model cards.')

await Promise.all(modelCards.map((card) => saveImportedModelCard(card, sourceModelId)))
await setLatestImportedModelCard(mirroredModelId)
const newModelCards = (
await Promise.all(modelCards.map((card) => saveImportedModelCard(card, sourceModelId)))
).filter((card): card is ModelCardRevisionInterface => !!card)

const mirroredModel = await setLatestImportedModelCard(mirroredModelId)

log.info(
{ mirroredModelId, payloadUrl, sourceModelId, modelCardVersions: modelCards.map((modelCard) => modelCard.version) },
'Finished importing the collection of model cards.',
)

return { mirroredModelId, sourceModelId, modelCardVersions: modelCards.map((modelCard) => modelCard.version) }
return {
mirroredModel,
sourceModelId,
modelCardVersions: modelCards.map((modelCard) => modelCard.version),
newModelCards,
}
}

function parseModelCard(modelCardJson: string, mirroredModelId: string, sourceModelId?: string) {
Expand All @@ -159,7 +168,7 @@ function parseModelCard(modelCardJson: string, mirroredModelId: string, sourceMo
if (sourceModelId !== modelId) {
throw InternalError('Zip file contains model cards for multiple models.', { modelIds: [sourceModelId, modelId] })
}
return { modelCard }
return { modelCard, sourceModelId }
}

async function uploadToS3(
Expand Down
35 changes: 24 additions & 11 deletions backend/src/services/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,12 @@ export async function createModelCardFromTemplate(
return revision
}

export async function saveImportedModelCard(modelCard: ModelCardRevisionInterface, sourceModelId: string) {
export async function saveImportedModelCard(modelCardRevision: ModelCardRevisionInterface, sourceModelId: string) {
const model = await Model.findOne({
id: modelCard.modelId,
id: modelCardRevision.modelId,
})
if (!model) {
throw NotFound(`Cannot find model to import model card.`, { modelId: modelCard.modelId })
throw NotFound(`Cannot find model to import model card.`, { modelId: modelCardRevision.modelId })
}
if (!model.settings.mirror.sourceModelId) {
throw InternalError('Cannot import model card to non mirrored model.')
Expand All @@ -406,44 +406,57 @@ export async function saveImportedModelCard(modelCard: ModelCardRevisionInterfac
})
}

const schema = await getSchemaById(modelCard.schemaId)
const schema = await getSchemaById(modelCardRevision.schemaId)
try {
new Validator().validate(modelCard.metadata, schema.jsonSchema, { throwAll: true, required: true })
new Validator().validate(modelCardRevision.metadata, schema.jsonSchema, { throwAll: true, required: true })
} catch (error) {
if (isValidatorResultError(error)) {
throw BadReq('Model metadata could not be validated against the schema.', {
schemaId: modelCard.schemaId,
schemaId: modelCardRevision.schemaId,
validationErrors: error.errors,
})
}
throw error
}

return await ModelCardRevisionModel.findOneAndUpdate(
{ modelId: modelCard.modelId, version: modelCard.version },
modelCard,
const foundModelCardRevision = await ModelCardRevisionModel.findOneAndUpdate(
{ modelId: modelCardRevision.modelId, version: modelCardRevision.version },
modelCardRevision,
{
upsert: true,
},
)

if (!foundModelCardRevision && modelCardRevision.version !== 1) {
// This model card did not already exist in Mongo, so it is a new model card. Return it to be audited.
// Ignore model cards with a version number of 1 as these will always be blank.
return modelCardRevision
}
}

/**
* Note that we do not authorise that the user can access the model here.
* This function should only be used during the import model card process.
* Do not expose this functionality to users.
*/
export async function setLatestImportedModelCard(modelId: string) {
const latestModelCard = await ModelCardRevisionModel.findOne({ modelId }, undefined, { sort: { version: -1 } })
if (!latestModelCard) {
throw NotFound('Cannot find latest model card.', { modelId })
}

const result = await ModelModel.findOneAndUpdate(
const updatedModel = await ModelModel.findOneAndUpdate(
{ id: modelId, 'settings.mirror.sourceModelId': { $exists: true, $ne: '' } },
{ $set: { card: latestModelCard } },
)
if (!result) {
if (!updatedModel) {
throw InternalError('Unable to set latest model card of mirrored model.', {
modelId,
version: latestModelCard.version,
})
}

return updatedModel
}

export function isModelCardRevision(data: unknown): data is ModelCardRevisionInterface {
Expand Down
43 changes: 26 additions & 17 deletions backend/src/services/modelCardExport.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { outdent } from 'outdent'
import showdown from 'showdown'

import { ModelInterface } from '../models/Model.js'
import { ModelCardRevisionInterface } from '../models/ModelCardRevision.js'
import { UserInterface } from '../models/User.js'
import { GetModelCardVersionOptionsKeys } from '../types/enums.js'
import { getModelById, getModelCard } from './model.js'
Expand Down Expand Up @@ -33,22 +35,33 @@ type Fragment = (
) &
Common

export async function renderToMarkdown(
export async function getModelCardHtml(
user: UserInterface,
modelId: string,
modelId: ModelInterface['id'],
version: number | GetModelCardVersionOptionsKeys,
) {
const model = await getModelById(user, modelId)
if (!model || !model.card) {
throw new Error('Trying to export model with no corresponding card')
if (!model) {
throw new Error('Failed to export model card. Model not found.')
}

const card = await getModelCard(user, modelId, version)
if (!card) {
throw new Error('Could not find specified model card')
const modelCard = await getModelCard(user, modelId, version)
if (!modelCard) {
throw new Error('Failed to find model card to export.')
}

const schema = await getSchemaById(card.schemaId)
const modelCardRevision: ModelCardRevisionInterface = { ...modelCard, modelId }
const html = await renderToHtml(model, modelCardRevision)

return { html, modelCard }
}

export async function renderToMarkdown(model: ModelInterface, modelCardRevision: ModelCardRevisionInterface) {
if (!model.card) {
throw new Error('Trying to export model with no corresponding card')
}

const schema = await getSchemaById(modelCardRevision.schemaId)
if (!schema) {
throw new Error('Trying to export model with no corresponding card')
}
Expand All @@ -59,16 +72,12 @@ export async function renderToMarkdown(
`

// 'Fragment' is a more strictly typed version of 'JsonSchema'.
output = recursiveRender(card.metadata, schema.jsonSchema as Fragment, output)
return { markdown: output, card }
output = recursiveRender(modelCardRevision.metadata, schema.jsonSchema as Fragment, output)
return output
}

export async function renderToHtml(
user: UserInterface,
modelId: string,
version: number | GetModelCardVersionOptionsKeys,
) {
const { markdown, card } = await renderToMarkdown(user, modelId, version)
export async function renderToHtml(model: ModelInterface, modelCardRevision: ModelCardRevisionInterface) {
const markdown = await renderToMarkdown(model, modelCardRevision)
const converter = new showdown.Converter()
converter.setFlavor('github')
const body = converter.makeHtml(markdown)
Expand Down Expand Up @@ -100,7 +109,7 @@ export async function renderToHtml(
</html>
`

return { html, card }
return html
}

function recursiveRender(obj: any, schema: Fragment, output = '', depth = 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ exports[`routes > model > postRequestImport > 200 > ok 1`] = `
}
`;

exports[`routes > model > postRequestImport > audit > expected call 1`] = `"abc"`;
exports[`routes > model > postRequestImport > audit > expected call 1`] = `
{
"id": "abc",
}
`;
Loading

0 comments on commit 5f6b52d

Please sign in to comment.