Skip to content

Commit

Permalink
Merge pull request #1282 from gchq/BAI-1247-create-a-basic-read-only-…
Browse files Browse the repository at this point in the history
…model

Bai 1247 create a basic read only model
  • Loading branch information
TT38665 authored Jun 4, 2024
2 parents e16a34f + bf6fb63 commit 8b53693
Show file tree
Hide file tree
Showing 13 changed files with 342 additions and 52 deletions.
12 changes: 7 additions & 5 deletions backend/src/models/Model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ export interface ModelCardInterface {
metadata: ModelMetadata
}

export interface Settings {
ungovernedAccess: boolean
mirror: { sourceModelId?: string; destinationModelId?: string }
}

// This interface stores information about the properties on the base object.
// It should be used for plain object representations, e.g. for sending to the
// client.
Expand All @@ -51,10 +56,7 @@ export interface ModelInterface {
card?: ModelCardInterface

collaborators: Array<CollaboratorEntry>
settings: {
ungovernedAccess: boolean
mirroredModelId?: string
}
settings: Settings

visibility: EntryVisibilityKeys
deleted: boolean
Expand Down Expand Up @@ -92,7 +94,7 @@ const ModelSchema = new Schema<ModelInterface>(
],
settings: {
ungovernedAccess: { type: Boolean, required: true, default: false },
mirroredModelId: { type: String },
mirror: { sourceModelId: { type: String }, destinationModelId: { type: String } },
},

visibility: { type: String, enum: Object.values(EntryVisibility), default: EntryVisibility.Public },
Expand Down
9 changes: 7 additions & 2 deletions backend/src/routes/v2/model/patchModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ export const patchModelSchema = z.object({
visibility: z.nativeEnum(EntryVisibility).optional().openapi({ example: 'private' }),
settings: z
.object({
ungovernedAccess: z.boolean().optional().default(false).openapi({ example: true }),
mirroredModelId: z.string().optional().openapi({ example: 'yolo-v4-abcdef' }),
ungovernedAccess: z.boolean().optional().openapi({ example: true }),
mirror: z
.object({
sourceModelId: z.string().optional().openapi({ example: 'yolo-v4-abcdef' }),
destinationModelId: z.string().optional().openapi({ example: 'yolo-v4-abcdef' }),
})
.optional(),
})
.optional(),
collaborators: z
Expand Down
10 changes: 9 additions & 1 deletion backend/src/routes/v2/model/postModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@ export const postModelSchema = z.object({
settings: z
.object({
ungovernedAccess: z.boolean().optional().default(false).openapi({ example: true }),
mirror: z
.object({
sourceModelId: z.string().openapi({ example: 'yolo-v4-abcdef' }).optional(),
destinationModelId: z.string().openapi({ example: 'yolo-v4-abcdef' }).optional(),
})
.optional()
.default({}),
})
.optional(),
.optional()
.default({ ungovernedAccess: false }),
}),
})

Expand Down
1 change: 1 addition & 0 deletions backend/src/seeds/data/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export const model: ModelInterface = {
],
settings: {
ungovernedAccess: true,
mirror: {},
},

card: {
Expand Down
1 change: 1 addition & 0 deletions backend/src/seeds/disable_ungoverned.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default async function DisableUngovernedModel() {
'This model does not allow ungoverned access, users should create an access request to get the artefacts from this.',
settings: {
ungovernedAccess: false,
mirror: {},
},
})
}
8 changes: 7 additions & 1 deletion backend/src/services/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import authorisation from '../connectors/authorisation/index.js'
import FileModel, { ScanState } from '../models/File.js'
import { UserInterface } from '../models/User.js'
import config from '../utils/config.js'
import { Forbidden, NotFound } from '../utils/error.js'
import { BadReq, Forbidden, NotFound } from '../utils/error.js'
import { longId } from '../utils/id.js'
import log from './log.js'
import { getModelById } from './model.js'
Expand All @@ -23,6 +23,9 @@ export async function uploadFile(
stream: ReadableStream,
) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot upload files to a mirrored model.`)
}

const fileId = longId()

Expand Down Expand Up @@ -135,6 +138,9 @@ export async function getFilesByIds(user: UserInterface, modelId: string, fileId
export async function removeFile(user: UserInterface, modelId: string, fileId: string) {
const model = await getModelById(user, modelId)
const file = await getFileById(user, fileId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot remove file from a mirrored model`)
}

const auth = await authorisation.file(user, model, file, FileAction.Delete)
if (!auth.success) {
Expand Down
6 changes: 3 additions & 3 deletions backend/src/services/mirroredModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ export async function exportModel(
throw BadReq('You must agree to the disclaimer agreement before being able to export a model.')
}
const model = await getModelById(user, modelId)
if (!model.settings.mirroredModelId || model.settings.mirroredModelId === '') {
if (!model.settings.mirror.destinationModelId) {
throw BadReq('The ID of the mirrored model has not been set on this model.')
}
const mirroredModelId = model.settings.mirroredModelId
const mirroredModelId = model.settings.mirror.destinationModelId
const auth = await authorisation.model(user, model, ModelAction.Update)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn, model: model.id })
Expand Down Expand Up @@ -133,7 +133,7 @@ async function uploadToTemporaryS3Location(

async function getObjectFromTemporaryS3Location(modelId: string, semvers: string[] | undefined) {
const bucket = config.s3.buckets.uploads
const object = `exportQueue/${modelId}.zi`
const object = `exportQueue/${modelId}.zip`
try {
const stream = (await getObjectStream(bucket, object)).Body as Readable
log.debug(
Expand Down
44 changes: 32 additions & 12 deletions backend/src/services/model.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Validator } from 'jsonschema'
import _ from 'lodash'

import authentication from '../connectors/authentication/index.js'
import { ModelAction, ModelActionKeys } from '../connectors/authorisation/actions.js'
Expand All @@ -14,7 +15,13 @@ import { BadReq, Forbidden, NotFound } from '../utils/error.js'
import { convertStringToId } from '../utils/id.js'
import { findSchemaById } from './schema.js'

export type CreateModelParams = Pick<ModelInterface, 'name' | 'kind' | 'teamId' | 'description' | 'visibility'>
export function checkModelRestriction(model: ModelInterface) {
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot alter a mirrored model.`)
}
}

export type CreateModelParams = Pick<ModelInterface, 'name' | 'teamId' | 'description' | 'visibility' | 'settings'>
export async function createModel(user: UserInterface, modelParams: CreateModelParams) {
const modelId = convertStringToId(modelParams.name)

Expand All @@ -32,10 +39,15 @@ export async function createModel(user: UserInterface, modelParams: CreateModelP
})

const auth = await authorisation.model(user, model, ModelAction.Create)

if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn })
}

if (modelParams?.settings?.mirror?.destinationModelId && modelParams?.settings?.mirror?.sourceModelId) {
throw BadReq('You cannot select both mirror settings simultaneously.')
}

await model.save()

return model
Expand Down Expand Up @@ -208,12 +220,15 @@ export async function _setModelCard(
// It is assumed that this race case will occur infrequently.
const model = await getModelById(user, modelId)

checkModelRestriction(model)

const auth = await authorisation.model(user, model, ModelAction.Write)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn, modelId })
}

// We don't want to copy across other values

const newDocument = {
schemaId: schemaId,

Expand All @@ -235,11 +250,7 @@ export async function updateModelCard(
metadata: unknown,
): Promise<ModelCardRevisionDoc> {
const model = await getModelById(user, modelId)

const auth = await authorisation.model(user, model, ModelAction.Update)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn, modelId })
}
checkModelRestriction(model)

if (!model.card) {
throw BadReq(`This model must first be instantiated before it can be `, { modelId })
Expand All @@ -262,19 +273,27 @@ export async function updateModelCard(
return revision
}

export type UpdateModelParams = Pick<
ModelInterface,
'name' | 'description' | 'visibility' | 'collaborators' | 'settings'
>
export async function updateModel(user: UserInterface, modelId: string, diff: Partial<UpdateModelParams>) {
export type UpdateModelParams = Pick<ModelInterface, 'name' | 'teamId' | 'description' | 'visibility'> & {
settings: Partial<ModelInterface['settings']>
}
export async function updateModel(user: UserInterface, modelId: string, modelDiff: Partial<UpdateModelParams>) {
const model = await getModelById(user, modelId)
if (modelDiff.settings?.mirror?.sourceModelId) {
throw BadReq('Cannot change standard model to be a mirrored model.')
}
if (model.settings.mirror.sourceModelId && modelDiff.settings?.mirror?.destinationModelId) {
throw BadReq('Cannot set a destination model ID for a mirrored model.')
}
if (modelDiff.settings?.mirror?.destinationModelId && modelDiff.settings?.mirror?.sourceModelId) {
throw BadReq('You cannot select both mirror settings simultaneously.')
}

const auth = await authorisation.model(user, model, ModelAction.Update)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn })
}

Object.assign(model, diff)
_.mergeWith(model, modelDiff, (a, b) => (_.isArray(b) ? b : undefined))
await model.save()

return model
Expand All @@ -286,6 +305,7 @@ export async function createModelCardFromSchema(
schemaId: string,
): Promise<ModelCardRevisionDoc> {
const model = await getModelById(user, modelId)
checkModelRestriction(model)

const auth = await authorisation.model(user, model, ModelAction.Write)
if (!auth.success) {
Expand Down
22 changes: 21 additions & 1 deletion backend/src/services/release.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ export type CreateReleaseParams = Optional<
>
export async function createRelease(user: UserInterface, releaseParams: CreateReleaseParams) {
const model = await getModelById(user, releaseParams.modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot create a release from a mirrored model.`)
}

if (releaseParams.modelCardVersion) {
// Ensure that the requested model card version exists.
Expand Down Expand Up @@ -175,6 +178,9 @@ export async function createRelease(user: UserInterface, releaseParams: CreateRe
export type UpdateReleaseParams = Pick<ReleaseInterface, 'notes' | 'draft' | 'fileIds' | 'images'>
export async function updateRelease(user: UserInterface, modelId: string, semver: string, delta: UpdateReleaseParams) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot update a release on a mirrored model.`)
}
const release = await getReleaseBySemver(user, modelId, semver)

Object.assign(release, delta)
Expand All @@ -198,8 +204,11 @@ export async function updateRelease(user: UserInterface, modelId: string, semver
}

export async function newReleaseComment(user: UserInterface, modelId: string, semver: string, message: string) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot create a new comment on a mirrored model.`)
}
const release = await Release.findOne({ modelId, semver })

if (!release) {
throw NotFound(`The requested release was not found.`, { modelId, semver })
}
Expand Down Expand Up @@ -234,6 +243,10 @@ export async function updateReleaseComment(
message: string,
) {
const release = await Release.findOne({ modelId, semver })
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot update comments on a mirrored model.`)
}

if (!release) {
throw NotFound(`The requested release was not found.`, { modelId, semver })
Expand Down Expand Up @@ -336,6 +349,9 @@ export async function getReleaseBySemver(user: UserInterface, modelId: string, s

export async function deleteRelease(user: UserInterface, modelId: string, semver: string) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot delete a release on a mirrored model.`)
}
const release = await getReleaseBySemver(user, modelId, semver)

const auth = await authorisation.release(user, model, release, ReleaseAction.Delete)
Expand All @@ -353,6 +369,10 @@ export function getReleaseName(release: ReleaseDoc): string {
}

export async function removeFileFromReleases(user: UserInterface, model: ModelDoc, fileId: string) {
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot remove a file from a mirrored model.`)
}

const query = {
modelId: model.id,
// Match documents where the element exists in the array
Expand Down
34 changes: 33 additions & 1 deletion backend/test/services/file.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ const s3Mocks = vi.hoisted(() => ({
vi.mock('../../src/clients/s3.js', () => s3Mocks)

const modelMocks = vi.hoisted(() => ({
getModelById: vi.fn(),
getModelById: vi.fn(() => ({ settings: { mirror: { sourceModelId: '' } } })),
}))
vi.mock('../../src/services/model.js', () => modelMocks)

Expand Down Expand Up @@ -128,6 +128,24 @@ describe('services > file', () => {
)
})

test('uploadFile > should throw an error when attempting to upload a file to a mirrored model', async () => {
const user = { dn: 'testUser' } as UserInterface
const modelId = 'testModelId'
const name = 'testFile'
const mime = 'text/plain'
const stream = new Readable() as any

vi.mocked(authorisation.file).mockResolvedValue({
info: 'Cannot upload files to a mirrored model.',
success: false,
id: '',
})

expect(() => uploadFile(user, modelId, name, mime, stream)).rejects.toThrowError(
/^Cannot upload files to a mirrored model./,
)
expect(fileModelMocks.save).not.toBeCalled()
})
test('removeFile > success', async () => {
const user = { dn: 'testUser' } as UserInterface
const modelId = 'testModelId'
Expand Down Expand Up @@ -171,6 +189,20 @@ describe('services > file', () => {
expect(fileModelMocks.delete).not.toBeCalled()
})

test('removeFile > should throw an error when attempting to remove a file from a mirrored model', async () => {
const user = { dn: 'testUser' } as UserInterface
const modelId = 'testModelId'
const fileId = 'testFileId'

vi.mocked(authorisation.file).mockResolvedValue({
info: 'Cannot remove file from a mirrored model.',
success: false,
id: '',
})
expect(() => removeFile(user, modelId, fileId)).rejects.toThrowError(/^Cannot remove file from a mirrored model./)
expect(fileModelMocks.delete).not.toBeCalled()
})

test('getFilesByModel > success', async () => {
fileModelMocks.find.mockResolvedValueOnce([{ example: 'file' }])

Expand Down
6 changes: 4 additions & 2 deletions backend/test/services/mirroredModel.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ vi.mock('../../src/services/log.js', async () => ({
}))

const modelMocks = vi.hoisted(() => ({
getModelById: vi.fn(() => ({ settings: { mirroredModelId: 'abc' } })),
getModelById: vi.fn(() => ({ settings: { mirror: { destinationModelId: '123' } } })),
getModelCardRevisions: vi.fn(() => [{ toJSON: vi.fn(), version: 123 }]),
}))
vi.mock('../../src/services/model.js', () => modelMocks)
Expand Down Expand Up @@ -183,7 +183,9 @@ describe('services > mirroredModel', () => {
})

test('exportModel > missing mirrored model ID', async () => {
modelMocks.getModelById.mockReturnValueOnce({ settings: { mirroredModelId: '' } })
modelMocks.getModelById.mockReturnValueOnce({
settings: { mirror: { destinationModelId: '' } },
})
const response = exportModel({} as UserInterface, 'modelId', true, ['1.2.3'])
expect(response).rejects.toThrowError(/^The ID of the mirrored model has not been set on this model./)
expect(s3Mocks.putObjectStream).toHaveBeenCalledTimes(0)
Expand Down
Loading

0 comments on commit 8b53693

Please sign in to comment.