From e9eaec0d641b67c56bc1c3a0ffe0f6c37614f054 Mon Sep 17 00:00:00 2001 From: ARADDCC012 <110473008+ARADDCC012@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:13:07 +0000 Subject: [PATCH 1/2] Updated model import flow and html rendering flow to allow for auditing of new model cards during import --- backend/src/connectors/audit/Base.ts | 5 +- backend/src/connectors/audit/silly.ts | 4 +- backend/src/connectors/audit/stdout.ts | 6 +- backend/src/models/ModelCardRevision.ts | 12 +--- .../v2/model/modelcard/getModelCardHtml.ts | 6 +- .../src/routes/v2/model/postRequestImport.ts | 17 +++--- backend/src/services/mirroredModel.ts | 21 +++++-- backend/src/services/model.ts | 35 ++++++++---- backend/src/services/modelCardExport.ts | 39 +++++++------ .../model/modelcard/getModelCardHtml.spec.ts | 30 +++++++--- .../routes/model/postRequestImport.spec.ts | 7 ++- backend/test/services/mirroredModel.spec.ts | 18 +++--- backend/test/services/modelCardExport.spec.ts | 56 +++++++++---------- 13 files changed, 146 insertions(+), 110 deletions(-) diff --git a/backend/src/connectors/audit/Base.ts b/backend/src/connectors/audit/Base.ts index a920f65e4..2f23dd155 100644 --- a/backend/src/connectors/audit/Base.ts +++ b/backend/src/connectors/audit/Base.ts @@ -183,17 +183,18 @@ export abstract class BaseAuditConnector { abstract onCreateS3Export(req: Request, modelId: string, semvers?: string[]) abstract onCreateImport( req: Request, - mirroredModelId: string, + mirroredModel: ModelInterface, sourceModelId: string, modelCardVersions: number[], exporter: string, + newModelCards: ModelCardInterface[], ) abstract onError(req: Request, error: BailoError) checkEventType(auditInfo: AuditInfoKeys, req: Request) { if (auditInfo.typeId !== req.audit.typeId && auditInfo.description !== req.audit.description) { - throw new Error(`Audit: Expected type '${JSON.stringify(auditInfo)}' but recieved '${JSON.stringify(req.audit)}'`) + throw new Error(`Audit: Expected type '${JSON.stringify(auditInfo)}' but received '${JSON.stringify(req.audit)}'`) } } } diff --git a/backend/src/connectors/audit/silly.ts b/backend/src/connectors/audit/silly.ts index 0302b5945..c7f26cd7a 100644 --- a/backend/src/connectors/audit/silly.ts +++ b/backend/src/connectors/audit/silly.ts @@ -5,6 +5,7 @@ import { AccessRequestDoc } from '../../models/AccessRequest.js' import { FileInterface, FileInterfaceDoc } from '../../models/File.js' import { InferenceDoc } from '../../models/Inference.js' import { ModelCardInterface, ModelDoc, ModelInterface } from '../../models/Model.js' +import { ModelCardRevisionInterface } from '../../models/ModelCardRevision.js' import { ReleaseDoc } from '../../models/Release.js' import { ResponseInterface } from '../../models/Response.js' import { ReviewInterface } from '../../models/Review.js' @@ -59,10 +60,11 @@ export class SillyAuditConnector extends BaseAuditConnector { onCreateS3Export(_req: Request, _modelId: string, _semvers?: string[]) {} onCreateImport( _req: Request, - _mirroredModelId: string, + _mirroredModel: ModelInterface, _sourceModelId: string, _modelCardVersions: number[], _exporter: string, + _newModelCards: ModelCardRevisionInterface[], ) {} onError(_req: Request, _error: BailoError) {} onCreateCommentResponse(_req: Request, _responseInterface: ResponseInterface) {} diff --git a/backend/src/connectors/audit/stdout.ts b/backend/src/connectors/audit/stdout.ts index 65464eefc..58365c8a8 100644 --- a/backend/src/connectors/audit/stdout.ts +++ b/backend/src/connectors/audit/stdout.ts @@ -4,6 +4,7 @@ import { AccessRequestDoc } from '../../models/AccessRequest.js' import { FileInterface, FileInterfaceDoc } from '../../models/File.js' import { InferenceDoc } from '../../models/Inference.js' import { ModelCardInterface, ModelDoc, ModelInterface } from '../../models/Model.js' +import { ModelCardRevisionInterface } from '../../models/ModelCardRevision.js' import { ReleaseDoc } from '../../models/Release.js' import { ResponseInterface } from '../../models/Response.js' import { ReviewInterface } from '../../models/Review.js' @@ -342,13 +343,14 @@ export class StdoutAuditConnector extends BaseAuditConnector { onCreateImport( req: Request, - mirroredModelId: string, + mirroredModel: ModelInterface, sourceModelId: string, modelCardVersions: number[], exporter: string, + newModelCards: ModelCardRevisionInterface[], ) { this.checkEventType(AuditInfo.CreateImport, req) - const event = this.generateEvent(req, { mirroredModelId, sourceModelId, modelCardVersions, exporter }) + const event = this.generateEvent(req, { mirroredModel, sourceModelId, modelCardVersions, exporter, newModelCards }) req.log.info(event, req.audit.description) } } diff --git a/backend/src/models/ModelCardRevision.ts b/backend/src/models/ModelCardRevision.ts index 7432d7708..ce026a571 100644 --- a/backend/src/models/ModelCardRevision.ts +++ b/backend/src/models/ModelCardRevision.ts @@ -1,20 +1,12 @@ import { Document, model, Schema } from 'mongoose' -import { ModelMetadata } from './Model.js' +import { ModelCardInterface } from './Model.js' // This interface stores information about the properties on the base object. // It should be used for plain object representations, e.g. for sending to the // client. -export interface ModelCardRevisionInterface { +export interface ModelCardRevisionInterface extends ModelCardInterface { modelId: string - schemaId: string - - version: number - metadata: ModelMetadata - - createdBy: string - createdAt: Date - updatedAt: Date } // The doc type includes all values in the plain interface, as well as all the diff --git a/backend/src/routes/v2/model/modelcard/getModelCardHtml.ts b/backend/src/routes/v2/model/modelcard/getModelCardHtml.ts index bb5728ddb..45655db98 100644 --- a/backend/src/routes/v2/model/modelcard/getModelCardHtml.ts +++ b/backend/src/routes/v2/model/modelcard/getModelCardHtml.ts @@ -4,7 +4,7 @@ import { z } from 'zod' import { AuditInfo } from '../../../../connectors/audit/Base.js' import audit from '../../../../connectors/audit/index.js' -import { renderToHtml } from '../../../../services/modelCardExport.js' +import { getModelCardHtml as getModelCardHtmlService } from '../../../../services/modelCardExport.js' import { registerPath } from '../../../../services/specification.js' import { GetModelCardVersionOptions } from '../../../../types/enums.js' import { parse } from '../../../../utils/validate.js' @@ -46,8 +46,8 @@ export const getModelCardHtml = [ params: { modelId, version }, } = parse(req, getModelCardHtmlSchema) - const { html, card } = await renderToHtml(req.user, modelId, version) - await audit.onViewModelCard(req, modelId, card) + const { html, modelCard } = await getModelCardHtmlService(req.user, modelId, version) + await audit.onViewModelCard(req, modelId, modelCard) return res.send(html) }, diff --git a/backend/src/routes/v2/model/postRequestImport.ts b/backend/src/routes/v2/model/postRequestImport.ts index b0f78edd2..958769f53 100644 --- a/backend/src/routes/v2/model/postRequestImport.ts +++ b/backend/src/routes/v2/model/postRequestImport.ts @@ -52,15 +52,16 @@ export const postRequestImportFromS3 = [ body: { payloadUrl, mirroredModelId, exporter }, } = parse(req, postRequestImportFromS3Schema) - const importInfo = await importModel(req.user, mirroredModelId, payloadUrl) - await audit.onCreateImport( - req, - importInfo.mirroredModelId, - importInfo.sourceModelId, - importInfo.modelCardVersions, - exporter, + const { mirroredModel, sourceModelId, modelCardVersions, newModelCards } = await importModel( + mirroredModelId, + payloadUrl, ) + await audit.onCreateImport(req, mirroredModel, sourceModelId, modelCardVersions, exporter, newModelCards) - return res.json(importInfo) + return res.json({ + mirroredModelId: mirroredModel.id, + sourceModelId, + modelCardVersions, + }) }, ] diff --git a/backend/src/services/mirroredModel.ts b/backend/src/services/mirroredModel.ts index debd4128a..3e0399129 100644 --- a/backend/src/services/mirroredModel.ts +++ b/backend/src/services/mirroredModel.ts @@ -77,7 +77,7 @@ export async function exportModel( log.debug({ modelId, semvers }, 'Successfully finalized zip file.') } -export async function importModel(_user: UserInterface, mirroredModelId: string, payloadUrl: string) { +export async function importModel(mirroredModelId: string, payloadUrl: string) { if (!config.ui.modelMirror.import.enabled) { throw BadReq('Importing models has not been enabled.') } @@ -85,7 +85,7 @@ export async function importModel(_user: UserInterface, mirroredModelId: string, if (mirroredModelId === '') { throw BadReq('Missing mirrored model ID.') } - let sourceModelId + let sourceModelId = '' log.info({ mirroredModelId, payloadUrl }, 'Received a request to import a model.') @@ -135,14 +135,23 @@ export async function importModel(_user: UserInterface, mirroredModelId: string, log.info({ mirroredModelId, payloadUrl, sourceModelId }, 'Finished parsing the collection of model cards.') - await Promise.all(modelCards.map((card) => saveImportedModelCard(card, sourceModelId))) - await setLatestImportedModelCard(mirroredModelId) + const newModelCards = ( + await Promise.all(modelCards.map((card) => saveImportedModelCard(card, sourceModelId))) + ).filter((card): card is ModelCardRevisionInterface => !!card) + + const mirroredModel = await setLatestImportedModelCard(mirroredModelId) + log.info( { mirroredModelId, payloadUrl, sourceModelId, modelCardVersions: modelCards.map((modelCard) => modelCard.version) }, 'Finished importing the collection of model cards.', ) - return { mirroredModelId, sourceModelId, modelCardVersions: modelCards.map((modelCard) => modelCard.version) } + return { + mirroredModel, + sourceModelId, + modelCardVersions: modelCards.map((modelCard) => modelCard.version), + newModelCards, + } } function parseModelCard(modelCardJson: string, mirroredModelId: string, sourceModelId?: string) { @@ -159,7 +168,7 @@ function parseModelCard(modelCardJson: string, mirroredModelId: string, sourceMo if (sourceModelId !== modelId) { throw InternalError('Zip file contains model cards for multiple models.', { modelIds: [sourceModelId, modelId] }) } - return { modelCard } + return { modelCard, sourceModelId } } async function uploadToS3( diff --git a/backend/src/services/model.ts b/backend/src/services/model.ts index a6e200511..e363168d0 100644 --- a/backend/src/services/model.ts +++ b/backend/src/services/model.ts @@ -389,12 +389,12 @@ export async function createModelCardFromTemplate( return revision } -export async function saveImportedModelCard(modelCard: ModelCardRevisionInterface, sourceModelId: string) { +export async function saveImportedModelCard(modelCardRevision: ModelCardRevisionInterface, sourceModelId: string) { const model = await Model.findOne({ - id: modelCard.modelId, + id: modelCardRevision.modelId, }) if (!model) { - throw NotFound(`Cannot find model to import model card.`, { modelId: modelCard.modelId }) + throw NotFound(`Cannot find model to import model card.`, { modelId: modelCardRevision.modelId }) } if (!model.settings.mirror.sourceModelId) { throw InternalError('Cannot import model card to non mirrored model.') @@ -406,44 +406,57 @@ export async function saveImportedModelCard(modelCard: ModelCardRevisionInterfac }) } - const schema = await getSchemaById(modelCard.schemaId) + const schema = await getSchemaById(modelCardRevision.schemaId) try { - new Validator().validate(modelCard.metadata, schema.jsonSchema, { throwAll: true, required: true }) + new Validator().validate(modelCardRevision.metadata, schema.jsonSchema, { throwAll: true, required: true }) } catch (error) { if (isValidatorResultError(error)) { throw BadReq('Model metadata could not be validated against the schema.', { - schemaId: modelCard.schemaId, + schemaId: modelCardRevision.schemaId, validationErrors: error.errors, }) } throw error } - return await ModelCardRevisionModel.findOneAndUpdate( - { modelId: modelCard.modelId, version: modelCard.version }, - modelCard, + const foundModelCardRevision = await ModelCardRevisionModel.findOneAndUpdate( + { modelId: modelCardRevision.modelId, version: modelCardRevision.version }, + modelCardRevision, { upsert: true, }, ) + + if (!foundModelCardRevision && modelCardRevision.version !== 1) { + // This model card did not already exist in Mongo, so it is a new model card. Return it to be audited. + // Ignore model cards with a version number of 1 as these will always be blank. + return modelCardRevision + } } +/** + * Note that we do not authorise that the user can access the model here. + * This function should only be used during the import model card process. + * Do not expose this functionality to users. + */ export async function setLatestImportedModelCard(modelId: string) { const latestModelCard = await ModelCardRevisionModel.findOne({ modelId }, undefined, { sort: { version: -1 } }) if (!latestModelCard) { throw NotFound('Cannot find latest model card.', { modelId }) } - const result = await ModelModel.findOneAndUpdate( + const updatedModel = await ModelModel.findOneAndUpdate( { id: modelId, 'settings.mirror.sourceModelId': { $exists: true, $ne: '' } }, { $set: { card: latestModelCard } }, ) - if (!result) { + if (!updatedModel) { throw InternalError('Unable to set latest model card of mirrored model.', { modelId, version: latestModelCard.version, }) } + + return updatedModel } export function isModelCardRevision(data: unknown): data is ModelCardRevisionInterface { diff --git a/backend/src/services/modelCardExport.ts b/backend/src/services/modelCardExport.ts index 711c5ba30..703176bf3 100644 --- a/backend/src/services/modelCardExport.ts +++ b/backend/src/services/modelCardExport.ts @@ -1,6 +1,8 @@ import { outdent } from 'outdent' import showdown from 'showdown' +import { ModelInterface } from '../models/Model.js' +import { ModelCardRevisionInterface } from '../models/ModelCardRevision.js' import { UserInterface } from '../models/User.js' import { GetModelCardVersionOptionsKeys } from '../types/enums.js' import { getModelById, getModelCard } from './model.js' @@ -33,22 +35,29 @@ type Fragment = ( ) & Common -export async function renderToMarkdown( +export async function getModelCardHtml( user: UserInterface, - modelId: string, + modelId: ModelInterface['id'], version: number | GetModelCardVersionOptionsKeys, ) { const model = await getModelById(user, modelId) - if (!model || !model.card) { - throw new Error('Trying to export model with no corresponding card') + if (!model) { + throw new Error('Failed to export model card. Model not found.') } - const card = await getModelCard(user, modelId, version) - if (!card) { - throw new Error('Could not find specified model card') + const modelCard = await getModelCard(user, modelId, version) + if (!modelCard) { + throw new Error('Failed to find model card to export.') } - const schema = await getSchemaById(card.schemaId) + const modelCardRevision: ModelCardRevisionInterface = { ...modelCard, modelId } + const html = await renderToHtml(model, modelCardRevision) + + return { html, modelCard } +} + +export async function renderToMarkdown(model: ModelInterface, modelCardRevision: ModelCardRevisionInterface) { + const schema = await getSchemaById(modelCardRevision.schemaId) if (!schema) { throw new Error('Trying to export model with no corresponding card') } @@ -59,16 +68,12 @@ export async function renderToMarkdown( ` // 'Fragment' is a more strictly typed version of 'JsonSchema'. - output = recursiveRender(card.metadata, schema.jsonSchema as Fragment, output) - return { markdown: output, card } + output = recursiveRender(modelCardRevision.metadata, schema.jsonSchema as Fragment, output) + return output } -export async function renderToHtml( - user: UserInterface, - modelId: string, - version: number | GetModelCardVersionOptionsKeys, -) { - const { markdown, card } = await renderToMarkdown(user, modelId, version) +export async function renderToHtml(model: ModelInterface, modelCardRevision: ModelCardRevisionInterface) { + const markdown = await renderToMarkdown(model, modelCardRevision) const converter = new showdown.Converter() converter.setFlavor('github') const body = converter.makeHtml(markdown) @@ -100,7 +105,7 @@ export async function renderToHtml( ` - return { html, card } + return html } function recursiveRender(obj: any, schema: Fragment, output = '', depth = 1) { diff --git a/backend/test/routes/model/modelcard/getModelCardHtml.spec.ts b/backend/test/routes/model/modelcard/getModelCardHtml.spec.ts index 628d1c07d..267d7b874 100644 --- a/backend/test/routes/model/modelcard/getModelCardHtml.spec.ts +++ b/backend/test/routes/model/modelcard/getModelCardHtml.spec.ts @@ -1,32 +1,44 @@ import { describe, expect, test, vi } from 'vitest' import audit from '../../../../src/connectors/audit/index.js' +import { ModelCardInterface } from '../../../../src/models/Model.js' import { UserInterface } from '../../../../src/models/User.js' import { getModelCardHtmlSchema } from '../../../../src/routes/v2/model/modelcard/getModelCardHtml.js' -import { renderToHtml } from '../../../../src/services/modelCardExport.js' +import { getModelCardHtml as getModelCardHtmlService } from '../../../../src/services/modelCardExport.js' import { createFixture, testGet } from '../../../testUtils/routes.js' vi.mock('../../../../src/utils/user.js') vi.mock('../../../../src/connectors/audit/index.js') +const mockModelCard: ModelCardInterface = { + schemaId: 'schema123', + version: 1, + createdBy: 'Joe Bloggs', + metadata: {}, +} +const mockUser: UserInterface = { dn: 'user' } + const mockModelCardExportService = vi.hoisted(() => { return { - renderToHtml: vi.fn(() => ({ html: 'test', card: 'card' })), + getModelCardHtml: vi.fn(() => ({ html: 'html', modelCard: mockModelCard })), } }) vi.mock('../../../../src/services/modelCardExport.js', () => mockModelCardExportService) describe('routes > model > modelcard > getModelCardHtml', () => { - test('should return HTML and call audit', async () => { - const testUser = { dn: 'user' } as UserInterface - mockModelCardExportService.renderToHtml.mockResolvedValueOnce({ html: 'test', card: 'card' }) - + test('should return html', async () => { const fixture = createFixture(getModelCardHtmlSchema) const res = await testGet(`/api/v2/model/${fixture.params.modelId}/model-card/${fixture.params.version}/html`) - expect(renderToHtml).toHaveBeenCalledWith(testUser, fixture.params.modelId, fixture.params.version) - expect(audit.onViewModelCard).toHaveBeenCalled() + expect(getModelCardHtmlService).toHaveBeenCalledWith(mockUser, fixture.params.modelId, fixture.params.version) expect(res.statusCode).toBe(200) - expect(res.text).toBe('test') + expect(res.text).toBe('html') + }) + + test('should call audit', async () => { + const fixture = createFixture(getModelCardHtmlSchema) + await testGet(`/api/v2/model/${fixture.params.modelId}/model-card/${fixture.params.version}/html`) + + expect(audit.onViewModelCard).toHaveBeenCalled() }) }) diff --git a/backend/test/routes/model/postRequestImport.spec.ts b/backend/test/routes/model/postRequestImport.spec.ts index 83c5e71bb..507330061 100644 --- a/backend/test/routes/model/postRequestImport.spec.ts +++ b/backend/test/routes/model/postRequestImport.spec.ts @@ -10,7 +10,12 @@ vi.mock('../../../src/connectors/audit/index.js') describe('routes > model > postRequestImport', () => { test('200 > ok', async () => { vi.mock('../../../src/services/mirroredModel.js', () => ({ - importModel: vi.fn(() => ({ mirroredModelId: 'abc', sourceModelId: 'cba', modelCardVersions: [1, 2, 3] })), + importModel: vi.fn(() => ({ + mirroredModel: { id: 'abc' }, + sourceModelId: 'cba', + modelCardVersions: [1, 2, 3], + newModelCards: [], + })), })) const fixture = createFixture(postRequestImportFromS3Schema) diff --git a/backend/test/services/mirroredModel.spec.ts b/backend/test/services/mirroredModel.spec.ts index 1ed2d11f1..578ce5a6a 100644 --- a/backend/test/services/mirroredModel.spec.ts +++ b/backend/test/services/mirroredModel.spec.ts @@ -334,34 +334,34 @@ describe('services > mirroredModel', () => { test('importModel > not enabled', async () => { vi.spyOn(configMock, 'ui', 'get').mockReturnValueOnce({ modelMirror: { import: { enabled: false } } }) - const result = importModel({} as UserInterface, '', 'https://test.com') + const result = importModel('', 'https://test.com') await expect(result).rejects.toThrowError('Importing models has not been enabled.') }) test('importModel > mirrored model Id empty', async () => { - const result = importModel({} as UserInterface, '', 'https://test.com') + const result = importModel('', 'https://test.com') await expect(result).rejects.toThrowError('Missing mirrored model ID.') }) test('importModel > error when getting zip file', async () => { fetchMock.default.mockRejectedValueOnce('a') - const result = importModel({} as UserInterface, 'model-id', 'https://test.com') + const result = importModel('model-id', 'https://test.com') await expect(result).rejects.toThrowError('Unable to get the file.') }) test('importModel > non 200 response when getting zip file', async () => { fetchMock.default.mockResolvedValueOnce({ ok: false, body: vi.fn(), text: vi.fn() }) - const result = importModel({} as UserInterface, 'model-id', 'https://test.com') + const result = importModel('model-id', 'https://test.com') await expect(result).rejects.toThrowError('Unable to get zip file.') }) test('importModel > file missing from body', async () => { fetchMock.default.mockResolvedValueOnce({ ok: true, text: vi.fn() } as any) - const result = importModel({} as UserInterface, 'model-id', 'https://test.com') + const result = importModel('model-id', 'https://test.com') await expect(result).rejects.toThrowError('Unable to get the file.') }) @@ -372,7 +372,7 @@ describe('services > mirroredModel', () => { file1: Buffer.from(JSON.stringify({ modelId: 'abc' })), file2: Buffer.from(JSON.stringify({ modelId: 'abc' })), }) - await importModel({} as UserInterface, 'model-id', 'https://test.com') + await importModel('model-id', 'https://test.com') await expect(modelMocks.saveImportedModelCard.mock.calls.length).toBe(2) }) @@ -383,7 +383,7 @@ describe('services > mirroredModel', () => { file1: Buffer.from(JSON.stringify({})), }) modelMocks.isModelCardRevision.mockReturnValueOnce(false) - const result = importModel({} as UserInterface, 'model-id', 'https://test.com') + const result = importModel('model-id', 'https://test.com') await expect(result).rejects.toThrowError(/^Data cannot be converted into a model card./) }) @@ -394,7 +394,7 @@ describe('services > mirroredModel', () => { file1: Buffer.from(JSON.stringify({ modelId: 'abc' })), file2: Buffer.from(JSON.stringify({ modelId: 'cba' })), }) - const result = importModel({} as UserInterface, 'model-id', 'https://test.com') + const result = importModel('model-id', 'https://test.com') await expect(result).rejects.toThrowError(/^Zip file contains model cards for multiple models./) }) @@ -404,7 +404,7 @@ describe('services > mirroredModel', () => { fflateMock.unzipSync.mockImplementationOnce(() => { throw Error('Cannot import file.') }) - const result = importModel({} as UserInterface, 'model-id', 'https://test.com') + const result = importModel('model-id', 'https://test.com') await expect(result).rejects.toThrowError(/^Unable to read zip file./) }) diff --git a/backend/test/services/modelCardExport.spec.ts b/backend/test/services/modelCardExport.spec.ts index 3c12bc48c..7615706c6 100644 --- a/backend/test/services/modelCardExport.spec.ts +++ b/backend/test/services/modelCardExport.spec.ts @@ -1,6 +1,8 @@ import { beforeEach, describe, expect, test, vi } from 'vitest' -import { getModelById, getModelCard } from '../../src/services/model.js' +import { ModelInterface } from '../../src/models/Model.js' +import { ModelCardRevisionInterface } from '../../src/models/ModelCardRevision.js' +import { getModelById } from '../../src/services/model.js' import { renderToHtml, renderToMarkdown } from '../../src/services/modelCardExport.js' import { getSchemaById } from '../../src/services/schema.js' @@ -8,54 +10,46 @@ vi.mock('../../src/services/model.js') vi.mock('../../src/services/schema.js') describe('services > export', () => { - const mockUser = { dn: 'testUser' } as any - const mockModelId = '123' - const mockVersion = 1 - const mockModel = { name: 'Test Model', description: 'Test Description', card: true } - const mockCard = { schemaId: 'schema123', metadata: {} } + const mockModelId = 'model123' + const mockSchemaId = 'schema123' + const mockModel = { name: 'Test Model', description: 'Test Description', card: {} } + const mockModelCardRevision: ModelCardRevisionInterface = { + modelId: mockModelId, + schemaId: mockSchemaId, + version: 1, + createdBy: 'Joe Bloggs', + metadata: {}, + } const mockSchema = { jsonSchema: { type: 'object', properties: {} } } beforeEach(() => { vi.mocked(getModelById).mockResolvedValue(mockModel as any) - vi.mocked(getModelCard).mockResolvedValue(mockCard as any) vi.mocked(getSchemaById).mockResolvedValue(mockSchema as any) }) - test('renderToMarkdown > should throw error if model has no card', async () => { - vi.mocked(getModelById).mockResolvedValueOnce({ ...mockModel, card: false } as any) + test('renderToMarkdown > should return markdown', async () => { + const result = await renderToMarkdown(mockModel as ModelInterface, mockModelCardRevision) - await expect(renderToMarkdown(mockUser, mockModelId, mockVersion)).rejects.toThrow( - 'Trying to export model with no corresponding card', - ) + expect(result).toContain('> Test Description') }) - test('renderToMarkdown > should throw error if card is not found', async () => { - vi.mocked(getModelCard).mockResolvedValueOnce(undefined as any) - - await expect(renderToMarkdown(mockUser, mockModelId, mockVersion)).rejects.toThrow( - 'Could not find specified model card', - ) + test('renderToHtml > should throw error if model has no card', async () => { + await expect( + renderToHtml({ ...mockModel, card: undefined } as ModelInterface, mockModelCardRevision), + ).rejects.toThrow('Trying to export model with no corresponding card') }) - test('renderToMarkdown > should throw error if schema is not found', async () => { + test('renderToHtml > should throw error if schema is not found', async () => { vi.mocked(getSchemaById).mockResolvedValueOnce(undefined as any) - await expect(renderToMarkdown(mockUser, mockModelId, mockVersion)).rejects.toThrow( + await expect(renderToHtml(mockModel as ModelInterface, mockModelCardRevision)).rejects.toThrow( 'Trying to export model with no corresponding card', ) }) - test('renderToMarkdown > should return markdown and card', async () => { - const result = await renderToMarkdown(mockUser, mockModelId, mockVersion) - - expect(result).toHaveProperty('markdown') - expect(result).toHaveProperty('card', mockCard) - }) - - test('renderToHtml > should return html and card', async () => { - const result = await renderToHtml(mockUser, mockModelId, mockVersion) + test('renderToHtml > should return html', async () => { + const result = await renderToHtml(mockModel as ModelInterface, mockModelCardRevision) - expect(result).toHaveProperty('html') - expect(result).toHaveProperty('card', mockCard) + expect(result).toContain('

Test Description

') }) }) From 0cc6435f0e1566b98c7417828fda2eb2ffb33f12 Mon Sep 17 00:00:00 2001 From: ARADDCC012 <110473008+ARADDCC012@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:20:04 +0000 Subject: [PATCH 2/2] Added check for model card during markdown rendering --- backend/src/services/modelCardExport.ts | 4 ++++ .../model/__snapshots__/postRequestImport.spec.ts.snap | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/backend/src/services/modelCardExport.ts b/backend/src/services/modelCardExport.ts index 703176bf3..ca2e824ff 100644 --- a/backend/src/services/modelCardExport.ts +++ b/backend/src/services/modelCardExport.ts @@ -57,6 +57,10 @@ export async function getModelCardHtml( } export async function renderToMarkdown(model: ModelInterface, modelCardRevision: ModelCardRevisionInterface) { + if (!model.card) { + throw new Error('Trying to export model with no corresponding card') + } + const schema = await getSchemaById(modelCardRevision.schemaId) if (!schema) { throw new Error('Trying to export model with no corresponding card') diff --git a/backend/test/routes/model/__snapshots__/postRequestImport.spec.ts.snap b/backend/test/routes/model/__snapshots__/postRequestImport.spec.ts.snap index 6d7215124..7b2c1e492 100644 --- a/backend/test/routes/model/__snapshots__/postRequestImport.spec.ts.snap +++ b/backend/test/routes/model/__snapshots__/postRequestImport.spec.ts.snap @@ -12,4 +12,8 @@ exports[`routes > model > postRequestImport > 200 > ok 1`] = ` } `; -exports[`routes > model > postRequestImport > audit > expected call 1`] = `"abc"`; +exports[`routes > model > postRequestImport > audit > expected call 1`] = ` +{ + "id": "abc", +} +`;