Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bai 1247 create a basic read only model #1282

Merged
merged 22 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions backend/src/models/Model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ export interface ModelCardInterface {
metadata: ModelMetadata
}

export interface Settings {
ungovernedAccess: boolean
mirror: { sourceModelId?: string; destinationModelId?: string }
}

// This interface stores information about the properties on the base object.
// It should be used for plain object representations, e.g. for sending to the
// client.
Expand All @@ -51,10 +56,7 @@ export interface ModelInterface {
card?: ModelCardInterface

collaborators: Array<CollaboratorEntry>
settings: {
ungovernedAccess: boolean
mirroredModelId?: string
}
settings: Settings

visibility: EntryVisibilityKeys
deleted: boolean
Expand Down Expand Up @@ -92,7 +94,7 @@ const ModelSchema = new Schema<ModelInterface>(
],
settings: {
ungovernedAccess: { type: Boolean, required: true, default: false },
mirroredModelId: { type: String },
mirror: { sourceModelId: { type: String }, destinationModelId: { type: String } },
},

visibility: { type: String, enum: Object.values(EntryVisibility), default: EntryVisibility.Public },
Expand Down
12 changes: 9 additions & 3 deletions backend/src/routes/v2/model/patchModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ export const patchModelSchema = z.object({
visibility: z.nativeEnum(EntryVisibility).optional().openapi({ example: 'private' }),
settings: z
.object({
ungovernedAccess: z.boolean().optional().default(false).openapi({ example: true }),
mirroredModelId: z.string().optional().openapi({ example: 'yolo-v4-abcdef' }),
ungovernedAccess: z.boolean().optional().openapi({ example: true }),
mirror: z
.object({
sourceModelId: z.string().optional().openapi({ example: 'yolo-v4-abcdef' }),
destinationModelId: z.string().optional().openapi({ example: 'yolo-v4-abcdef' }),
})
.optional(),
})
.optional(),
collaborators: z
Expand Down Expand Up @@ -66,7 +71,8 @@ export const patchModel = [
params: { modelId },
} = parse(req, patchModelSchema)

const model = await updateModel(req.user, modelId, body)
const { settings: settingsDiff, ...modelDiff } = body
const model = await updateModel(req.user, modelId, modelDiff, settingsDiff)
await audit.onUpdateModel(req, model)

return res.json({
Expand Down
10 changes: 9 additions & 1 deletion backend/src/routes/v2/model/postModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@ export const postModelSchema = z.object({
settings: z
.object({
ungovernedAccess: z.boolean().optional().default(false).openapi({ example: true }),
mirror: z
.object({
sourceModelId: z.string().openapi({ example: 'yolo-v4-abcdef' }).optional(),
destinationModelId: z.string().openapi({ example: 'yolo-v4-abcdef' }).optional(),
})
.optional()
.default({}),
})
.optional(),
.optional()
.default({ ungovernedAccess: false }),
}),
})

Expand Down
1 change: 1 addition & 0 deletions backend/src/seeds/data/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export const model: ModelInterface = {
],
settings: {
ungovernedAccess: true,
mirror: {},
},

card: {
Expand Down
1 change: 1 addition & 0 deletions backend/src/seeds/disable_ungoverned.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default async function DisableUngovernedModel() {
'This model does not allow ungoverned access, users should create an access request to get the artefacts from this.',
settings: {
ungovernedAccess: false,
mirror: {},
},
})
}
8 changes: 7 additions & 1 deletion backend/src/services/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import authorisation from '../connectors/authorisation/index.js'
import FileModel, { ScanState } from '../models/File.js'
import { UserInterface } from '../models/User.js'
import config from '../utils/config.js'
import { Forbidden, NotFound } from '../utils/error.js'
import { BadReq, Forbidden, NotFound } from '../utils/error.js'
import { longId } from '../utils/id.js'
import log from './log.js'
import { getModelById } from './model.js'
Expand All @@ -23,6 +23,9 @@ export async function uploadFile(
stream: ReadableStream,
) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot upload files to a mirrored model.`)
}

const fileId = longId()

Expand Down Expand Up @@ -135,6 +138,9 @@ export async function getFilesByIds(user: UserInterface, modelId: string, fileId
export async function removeFile(user: UserInterface, modelId: string, fileId: string) {
const model = await getModelById(user, modelId)
const file = await getFileById(user, fileId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot remove file from a mirrored model`)
}

const auth = await authorisation.file(user, model, file, FileAction.Delete)
if (!auth.success) {
Expand Down
6 changes: 3 additions & 3 deletions backend/src/services/mirroredModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ export async function exportModel(
throw BadReq('You must agree to the disclaimer agreement before being able to export a model.')
}
const model = await getModelById(user, modelId)
if (!model.settings.mirroredModelId || model.settings.mirroredModelId === '') {
if (!model.settings.mirror.destinationModelId) {
throw BadReq('The ID of the mirrored model has not been set on this model.')
}
const mirroredModelId = model.settings.mirroredModelId
const mirroredModelId = model.settings.mirror.destinationModelId
const auth = await authorisation.model(user, model, ModelAction.Update)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn, model: model.id })
Expand Down Expand Up @@ -133,7 +133,7 @@ async function uploadToTemporaryS3Location(

async function getObjectFromTemporaryS3Location(modelId: string, semvers: string[] | undefined) {
const bucket = config.s3.buckets.uploads
const object = `exportQueue/${modelId}.zi`
const object = `exportQueue/${modelId}.zip`
try {
const stream = (await getObjectStream(bucket, object)).Body as Readable
log.debug(
Expand Down
54 changes: 42 additions & 12 deletions backend/src/services/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Validator } from 'jsonschema'
import authentication from '../connectors/authentication/index.js'
import { ModelAction, ModelActionKeys } from '../connectors/authorisation/actions.js'
import authorisation from '../connectors/authorisation/index.js'
import ModelModel, { EntryKindKeys } from '../models/Model.js'
import ModelModel, { EntryKindKeys, Settings } from '../models/Model.js'
import Model, { ModelInterface } from '../models/Model.js'
import ModelCardRevisionModel, { ModelCardRevisionDoc } from '../models/ModelCardRevision.js'
import { UserInterface } from '../models/User.js'
Expand All @@ -14,7 +14,7 @@ import { BadReq, Forbidden, NotFound } from '../utils/error.js'
import { convertStringToId } from '../utils/id.js'
import { findSchemaById } from './schema.js'

export type CreateModelParams = Pick<ModelInterface, 'name' | 'kind' | 'teamId' | 'description' | 'visibility'>
export type CreateModelParams = Pick<ModelInterface, 'name' | 'teamId' | 'description' | 'visibility' | 'settings'>
export async function createModel(user: UserInterface, modelParams: CreateModelParams) {
const modelId = convertStringToId(modelParams.name)

Expand All @@ -32,10 +32,21 @@ export async function createModel(user: UserInterface, modelParams: CreateModelP
})

const auth = await authorisation.model(user, model, ModelAction.Create)

if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn })
}

if (
(model.settings.mirror.sourceModelId && model.settings.mirror.destinationModelId) ||
(modelParams.settings &&
modelParams.settings.mirror &&
modelParams.settings.mirror.destinationModelId &&
modelParams.settings.mirror.sourceModelId)
) {
throw BadReq('You cannot select both mirror settings simultaneously.')
}

await model.save()

return model
Expand Down Expand Up @@ -207,13 +218,17 @@ export async function _setModelCard(
//
// It is assumed that this race case will occur infrequently.
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot alter a mirrored model.`)
}

const auth = await authorisation.model(user, model, ModelAction.Write)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn, modelId })
}

// We don't want to copy across other values

const newDocument = {
schemaId: schemaId,

Expand All @@ -235,10 +250,8 @@ export async function updateModelCard(
metadata: unknown,
): Promise<ModelCardRevisionDoc> {
const model = await getModelById(user, modelId)

const auth = await authorisation.model(user, model, ModelAction.Update)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn, modelId })
if (model.settings.mirror.sourceModelId) {
throw BadReq(`This model card cannot be changed.`)
}

if (!model.card) {
Expand All @@ -262,19 +275,33 @@ export async function updateModelCard(
return revision
}

export type UpdateModelParams = Pick<
ModelInterface,
'name' | 'description' | 'visibility' | 'collaborators' | 'settings'
>
export async function updateModel(user: UserInterface, modelId: string, diff: Partial<UpdateModelParams>) {
export type UpdateModelParams = Pick<ModelInterface, 'name' | 'description' | 'visibility' | 'collaborators'>
export async function updateModel(
user: UserInterface,
modelId: string,
modelDiff: Partial<UpdateModelParams>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not include settings in the model? What makes settings special?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. Presumably because we only use Object.assign() below, which isn't recursive. I would prefer a recursive merging here, as opposed to treating 'settings' as a special case. I think something like:

.mergeWith({}, model, modelDiff, (a, b) => 
  _.isArray(b) ? b : undefined
);

Might work.

settingsDiff?: Partial<Settings>,
) {
const model = await getModelById(user, modelId)
if (settingsDiff && settingsDiff.mirror) {
if (settingsDiff.mirror.sourceModelId) {
throw BadReq('Cannot change standard model to be a mirrored model.')
}
if (model.settings.mirror.sourceModelId && settingsDiff.mirror.destinationModelId) {
throw BadReq('Cannot set a destination model ID for a mirrored model.')
}
if (settingsDiff.mirror.destinationModelId && settingsDiff.mirror.sourceModelId) {
throw BadReq('You cannot select both mirror settings simultaneously.')
}
}

const auth = await authorisation.model(user, model, ModelAction.Update)
if (!auth.success) {
throw Forbidden(auth.info, { userDn: user.dn })
}

Object.assign(model, diff)
Object.assign(model, modelDiff)
Object.assign(model.settings, settingsDiff)
await model.save()

return model
Expand All @@ -286,6 +313,9 @@ export async function createModelCardFromSchema(
schemaId: string,
): Promise<ModelCardRevisionDoc> {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`This model card cannot be changed.`)
}

const auth = await authorisation.model(user, model, ModelAction.Write)
if (!auth.success) {
Expand Down
22 changes: 21 additions & 1 deletion backend/src/services/release.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ export type CreateReleaseParams = Optional<
>
export async function createRelease(user: UserInterface, releaseParams: CreateReleaseParams) {
const model = await getModelById(user, releaseParams.modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot create a release from a mirrored model.`)
}

if (releaseParams.modelCardVersion) {
// Ensure that the requested model card version exists.
Expand Down Expand Up @@ -175,6 +178,9 @@ export async function createRelease(user: UserInterface, releaseParams: CreateRe
export type UpdateReleaseParams = Pick<ReleaseInterface, 'notes' | 'draft' | 'fileIds' | 'images'>
export async function updateRelease(user: UserInterface, modelId: string, semver: string, delta: UpdateReleaseParams) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot update a release on a mirrored model.`)
}
const release = await getReleaseBySemver(user, modelId, semver)

Object.assign(release, delta)
Expand All @@ -198,8 +204,11 @@ export async function updateRelease(user: UserInterface, modelId: string, semver
}

export async function newReleaseComment(user: UserInterface, modelId: string, semver: string, message: string) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot create a new comment on a mirrored model.`)
}
const release = await Release.findOne({ modelId, semver })

if (!release) {
throw NotFound(`The requested release was not found.`, { modelId, semver })
}
Expand Down Expand Up @@ -234,6 +243,10 @@ export async function updateReleaseComment(
message: string,
) {
const release = await Release.findOne({ modelId, semver })
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot update comments on a mirrored model.`)
}

if (!release) {
throw NotFound(`The requested release was not found.`, { modelId, semver })
Expand Down Expand Up @@ -336,6 +349,9 @@ export async function getReleaseBySemver(user: UserInterface, modelId: string, s

export async function deleteRelease(user: UserInterface, modelId: string, semver: string) {
const model = await getModelById(user, modelId)
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot delete a release on a mirrored model.`)
}
const release = await getReleaseBySemver(user, modelId, semver)

const auth = await authorisation.release(user, model, release, ReleaseAction.Delete)
Expand All @@ -353,6 +369,10 @@ export function getReleaseName(release: ReleaseDoc): string {
}

export async function removeFileFromReleases(user: UserInterface, model: ModelDoc, fileId: string) {
if (model.settings.mirror.sourceModelId) {
throw BadReq(`Cannot remove a file from a mirrored model.`)
}

const query = {
modelId: model.id,
// Match documents where the element exists in the array
Expand Down
34 changes: 33 additions & 1 deletion backend/test/services/file.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ const s3Mocks = vi.hoisted(() => ({
vi.mock('../../src/clients/s3.js', () => s3Mocks)

const modelMocks = vi.hoisted(() => ({
getModelById: vi.fn(),
getModelById: vi.fn(() => ({ settings: { mirror: { sourceModelId: '' } } })),
}))
vi.mock('../../src/services/model.js', () => modelMocks)

Expand Down Expand Up @@ -128,6 +128,24 @@ describe('services > file', () => {
)
})

test('uploadFile > should throw an error when attempting to upload a file to a mirrored model', async () => {
const user = { dn: 'testUser' } as UserInterface
const modelId = 'testModelId'
const name = 'testFile'
const mime = 'text/plain'
const stream = new Readable() as any

vi.mocked(authorisation.file).mockResolvedValue({
info: 'Cannot upload files to a mirrored model.',
success: false,
id: '',
})

expect(() => uploadFile(user, modelId, name, mime, stream)).rejects.toThrowError(
/^Cannot upload files to a mirrored model./,
)
expect(fileModelMocks.save).not.toBeCalled()
})
test('removeFile > success', async () => {
const user = { dn: 'testUser' } as UserInterface
const modelId = 'testModelId'
Expand Down Expand Up @@ -171,6 +189,20 @@ describe('services > file', () => {
expect(fileModelMocks.delete).not.toBeCalled()
})

test('removeFile > should throw an error when attempting to remove a file from a mirrored model', async () => {
const user = { dn: 'testUser' } as UserInterface
const modelId = 'testModelId'
const fileId = 'testFileId'

vi.mocked(authorisation.file).mockResolvedValue({
info: 'Cannot remove file from a mirrored model.',
success: false,
id: '',
})
expect(() => removeFile(user, modelId, fileId)).rejects.toThrowError(/^Cannot remove file from a mirrored model./)
expect(fileModelMocks.delete).not.toBeCalled()
})

test('getFilesByModel > success', async () => {
fileModelMocks.find.mockResolvedValueOnce([{ example: 'file' }])

Expand Down
Loading
Loading