use inference server when model service image is equivalent to inference server image (#1503)

* feat: use inference server when model service image is equivalent to inference server image

Signed-off-by: lstocchi <[email protected]>

* fix: use backend to decide about inference server usage

Signed-off-by: Luca Stocchi <[email protected]>

* fix: fix failing unit tests

Signed-off-by: Jeff MAURY <[email protected]>

* fix: refactor from @axel7083 review

Signed-off-by: Jeff MAURY <[email protected]>

---------

Signed-off-by: lstocchi <[email protected]>
Signed-off-by: Luca Stocchi <[email protected]>
Signed-off-by: Jeff MAURY <[email protected]>
Co-authored-by: Luca Stocchi <[email protected]>
Co-authored-by: Jeff MAURY <[email protected]>
3 people authored Aug 14, 2024
1 parent 157d4b0 commit bc20c1a
Showing 8 changed files with 169 additions and 65 deletions.
@@ -126,7 +126,7 @@ beforeEach(() => {
vi.resetAllMocks();

vi.mocked(webviewMock.postMessage).mockResolvedValue(true);
- vi.mocked(recipeManager.buildRecipe).mockResolvedValue([recipeImageInfoMock]);
+ vi.mocked(recipeManager.buildRecipe).mockResolvedValue({ images: [recipeImageInfoMock] });
vi.mocked(podManager.createPod).mockResolvedValue({ engineId: 'test-engine-id', Id: 'test-pod-id' });
vi.mocked(podManager.getPod).mockResolvedValue({ engineId: 'test-engine-id', Id: 'test-pod-id' } as PodInfo);
vi.mocked(podManager.getPodsWithLabels).mockResolvedValue([]);
@@ -312,7 +312,7 @@ describe('pullApplication', () => {
'model-id': remoteModelMock.id,
});
// build the recipe
- expect(recipeManager.buildRecipe).toHaveBeenCalledWith(connectionMock, recipeMock, {
+ expect(recipeManager.buildRecipe).toHaveBeenCalledWith(connectionMock, recipeMock, remoteModelMock, {
'test-label': 'test-value',
'recipe-id': recipeMock.id,
'model-id': remoteModelMock.id,
@@ -374,18 +374,20 @@ describe('pullApplication', () => {
test('qemu connection should have specific flag', async () => {
vi.mocked(podManager.findPodByLabelsValues).mockResolvedValue(undefined);

- vi.mocked(recipeManager.buildRecipe).mockResolvedValue([
- recipeImageInfoMock,
- {
- modelService: true,
- ports: ['8888'],
- name: 'llamacpp',
- id: 'llamacpp',
- appName: 'llamacpp',
- engineId: recipeImageInfoMock.engineId,
- recipeId: recipeMock.id,
- },
- ]);
+ vi.mocked(recipeManager.buildRecipe).mockResolvedValue({
+ images: [
+ recipeImageInfoMock,
+ {
+ modelService: true,
+ ports: ['8888'],
+ name: 'llamacpp',
+ id: 'llamacpp',
+ appName: 'llamacpp',
+ engineId: recipeImageInfoMock.engineId,
+ recipeId: recipeMock.id,
+ },
+ ],
+ });

await getInitializedApplicationManager().pullApplication(connectionMock, recipeMock, remoteModelMock);

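The mocks above encode the headline change of this commit: buildRecipe no longer resolves to a bare RecipeImage[] but to a components object that may also carry an inference server. A minimal sketch of that shape, reconstructed from this diff rather than copied from @shared/src/models/IRecipe:

// Sketch reconstructed from this diff, not the actual IRecipe source.
interface RecipeImageSketch {
  modelService: boolean;
  ports: string[];
  name: string;
  id: string;
  appName: string;
  engineId: string;
  recipeId: string;
}

interface InferenceServerSketch {
  status: 'running' | 'stopped';
  connection: { port: number };
}

interface RecipeComponentsSketch {
  images: RecipeImageSketch[];
  // set when an existing (or newly started) inference server is reused
  inferenceServer?: InferenceServerSketch;
}
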
28 changes: 16 additions & 12 deletions packages/backend/src/managers/application/applicationManager.ts
@@ -16,7 +16,7 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

- import type { Recipe, RecipeImage } from '@shared/src/models/IRecipe';
+ import type { Recipe, RecipeComponents, RecipeImage } from '@shared/src/models/IRecipe';
import * as path from 'node:path';
import { containerEngine, Disposable, window, ProgressLocation } from '@podman-desktop/api';
import type {
@@ -187,7 +187,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
});

// build all images, one per container (for a basic sample we should have 2 containers = sample app + model service)
- const images = await this.recipeManager.buildRecipe(connection, recipe, {
+ const recipeComponents = await this.recipeManager.buildRecipe(connection, recipe, model, {
...labels,
'recipe-id': recipe.id,
'model-id': model.id,
@@ -199,7 +199,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
}

// create a pod containing all the containers to run the application
- return this.createApplicationPod(connection, recipe, model, images, modelPath, {
+ return this.createApplicationPod(connection, recipe, model, recipeComponents, modelPath, {
...labels,
'recipe-id': recipe.id,
'model-id': model.id,
@@ -253,7 +253,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
connection: ContainerProviderConnection,
recipe: Recipe,
model: ModelInfo,
- images: RecipeImage[],
+ components: RecipeComponents,
modelPath: string,
labels?: { [key: string]: string },
): Promise<PodInfo> {
@@ -262,7 +262,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
// create empty pod
let podInfo: PodInfo;
try {
- podInfo = await this.createPod(connection, recipe, model, images);
+ podInfo = await this.createPod(connection, recipe, model, components.images);
task.labels = {
...task.labels,
'pod-id': podInfo.Id,
@@ -277,7 +277,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
}

try {
- await this.createContainerAndAttachToPod(connection, podInfo, images, model, modelPath);
+ await this.createContainerAndAttachToPod(connection, podInfo, components, model, modelPath);
task.state = 'success';
} catch (e) {
console.error(`error when creating pod ${podInfo.Id}`, e);
@@ -294,14 +294,14 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
protected async createContainerAndAttachToPod(
connection: ContainerProviderConnection,
podInfo: PodInfo,
- images: RecipeImage[],
+ components: RecipeComponents,
modelInfo: ModelInfo,
modelPath: string,
): Promise<void> {
const vmType = connection.vmType ?? VMType.UNKNOWN;
// temporary check to set Z flag or not - to be removed when switching to podman 5
await Promise.all(
- images.map(async image => {
+ components.images.map(async image => {
let hostConfig: HostConfig | undefined = undefined;
let envs: string[] = [];
let healthcheck: HealthConfig | undefined = undefined;
@@ -321,11 +321,15 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
envs = [`MODEL_PATH=/${modelName}`];
envs.push(...getModelPropertiesForEnvironment(modelInfo));
} else {
- // TODO: remove static port
- const modelService = images.find(image => image.modelService);
- if (modelService && modelService.ports.length > 0) {
- const endPoint = `http://localhost:${modelService.ports[0]}`;
+ if (components.inferenceServer) {
+ const endPoint = `http://host.containers.internal:${components.inferenceServer.connection.port}`;
envs = [`MODEL_ENDPOINT=${endPoint}`];
+ } else {
+ const modelService = components.images.find(image => image.modelService);
+ if (modelService && modelService.ports.length > 0) {
+ const endPoint = `http://localhost:${modelService.ports[0]}`;
+ envs = [`MODEL_ENDPOINT=${endPoint}`];
+ }
+ }
}
}
if (image.ports.length > 0) {
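Condensed, the endpoint selection above boils down to the following; this is a self-contained sketch with simplified stand-in types, not the actual podman-desktop implementation. When a shared inference server exists, application containers reach it through the host gateway alias host.containers.internal; otherwise they fall back to the model-service container published inside the same pod.

// Sketch of the MODEL_ENDPOINT decision above; types are simplified stand-ins.
interface ComponentsSketch {
  images: { modelService: boolean; ports: string[] }[];
  inferenceServer?: { connection: { port: number } };
}

function modelEndpointEnv(components: ComponentsSketch): string[] {
  if (components.inferenceServer) {
    // shared inference server: reachable from inside containers via the host gateway alias
    return [`MODEL_ENDPOINT=http://host.containers.internal:${components.inferenceServer.connection.port}`];
  }
  // fallback: the model-service container running in the same pod
  const modelService = components.images.find(image => image.modelService);
  if (modelService && modelService.ports.length > 0) {
    return [`MODEL_ENDPOINT=http://localhost:${modelService.ports[0]}`];
  }
  return [];
}
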
16 changes: 16 additions & 0 deletions packages/backend/src/managers/inference/inferenceManager.ts
@@ -101,6 +101,22 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
return this.#servers.get(containerId);
}

+ /**
+ * Return the first inference server that is using the given model.
+ * Throws if the model's backend is not currently supported.
+ */
+ public findServerByModel(model: ModelInfo): InferenceServer | undefined {
+ // check if the model backend is supported
+ const backend: InferenceType = getInferenceType([model]);
+ const providers: InferenceProvider[] = this.inferenceProviderRegistry
+ .getByType(backend)
+ .filter(provider => provider.enabled());
+ if (providers.length === 0) {
+ throw new Error('no enabled provider could be found.');
+ }
+ return this.getServers().find(s => s.models.some(m => m.id === model.id));
+ }
+
/**
* Creating an inference server can be a heavy task (pulling image, uploading model to WSL etc.)
* The frontend cannot wait endlessly, therefore we provide a method returning a tracking identifier
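For illustration, the lookup that findServerByModel performs once the backend check passes can be sketched standalone; the types below are simplified assumptions, and the real method first throws when no enabled provider supports the model's backend.

// Simplified sketch of the model-to-server lookup above.
type ModelRefSketch = { id: string };
type ServerSketch = { status: 'running' | 'stopped'; models: ModelRefSketch[] };

function findServerByModelSketch(
  servers: ServerSketch[],
  model: ModelRefSketch,
): ServerSketch | undefined {
  // a server qualifies as soon as it already serves the requested model
  return servers.find(s => s.models.some(m => m.id === model.id));
}
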
15 changes: 13 additions & 2 deletions packages/backend/src/managers/recipes/RecipeManager.spec.ts
@@ -28,6 +28,8 @@ import { existsSync, statSync } from 'node:fs';
import { AIConfigFormat, parseYamlFile } from '../../models/AIConfig';
import { goarch } from '../../utils/arch';
import { VMType } from '@shared/src/models/IPodman';
+ import type { InferenceManager } from '../inference/inferenceManager';
+ import type { ModelInfo } from '@shared/src/models/IModelInfo';

const taskRegistryMock = {
createTask: vi.fn(),
@@ -46,6 +48,8 @@ const localRepositoriesMock = {
register: vi.fn(),
} as unknown as LocalRepositoryRegistry;

+ const inferenceManagerMock = {} as unknown as InferenceManager;
+
const recipeMock: Recipe = {
id: 'recipe-test',
name: 'Test Recipe',
@@ -60,6 +64,12 @@ const connectionMock: ContainerProviderConnection = {
vmType: VMType.UNKNOWN,
} as unknown as ContainerProviderConnection;

+ const modelInfoMock: ModelInfo = {
+ id: 'modelId',
+ name: 'Model',
+ description: 'model to test',
+ } as unknown as ModelInfo;
+
vi.mock('../../models/AIConfig', () => ({
AIConfigFormat: {
CURRENT: 'current',
@@ -123,6 +133,7 @@ async function getInitializedRecipeManager(): Promise<RecipeManager> {
taskRegistryMock,
builderManagerMock,
localRepositoriesMock,
+ inferenceManagerMock,
);
manager.init();
return manager;
@@ -180,14 +191,14 @@ describe('buildRecipe', () => {
const manager = await getInitializedRecipeManager();

await expect(() => {
- return manager.buildRecipe(connectionMock, recipeMock);
+ return manager.buildRecipe(connectionMock, recipeMock, modelInfoMock);
}).rejects.toThrowError('build error');
});

test('labels should be propagated', async () => {
const manager = await getInitializedRecipeManager();

- await manager.buildRecipe(connectionMock, recipeMock, {
+ await manager.buildRecipe(connectionMock, recipeMock, modelInfoMock, {
'test-label': 'test-value',
});

77 changes: 69 additions & 8 deletions packages/backend/src/managers/recipes/RecipeManager.ts
@@ -17,7 +17,7 @@
***********************************************************************/
import type { GitCloneInfo, GitManager } from '../gitManager';
import type { TaskRegistry } from '../../registries/TaskRegistry';
- import type { Recipe, RecipeImage } from '@shared/src/models/IRecipe';
+ import type { Recipe, RecipeComponents } from '@shared/src/models/IRecipe';
import path from 'node:path';
import type { Task } from '@shared/src/models/ITask';
import type { LocalRepositoryRegistry } from '../../registries/LocalRepositoryRegistry';
@@ -28,6 +28,10 @@ import { goarch } from '../../utils/arch';
import type { BuilderManager } from './BuilderManager';
import type { ContainerProviderConnection, Disposable } from '@podman-desktop/api';
import { CONFIG_FILENAME } from '../../utils/RecipeConstants';
+ import type { InferenceManager } from '../inference/inferenceManager';
+ import type { ModelInfo } from '@shared/src/models/IModelInfo';
+ import { withDefaultConfiguration } from '../../utils/inferenceUtils';
+ import type { InferenceServer } from '@shared/src/models/IInference';

export interface AIContainers {
aiConfigFile: AIConfigFile;
@@ -41,6 +45,7 @@ export class RecipeManager implements Disposable {
private taskRegistry: TaskRegistry,
private builderManager: BuilderManager,
private localRepositories: LocalRepositoryRegistry,
+ private inferenceManager: InferenceManager,
) {}

dispose(): void {}
@@ -94,17 +99,63 @@ export class RecipeManager implements Disposable {
public async buildRecipe(
connection: ContainerProviderConnection,
recipe: Recipe,
+ model: ModelInfo,
labels?: { [key: string]: string },
- ): Promise<RecipeImage[]> {
+ ): Promise<RecipeComponents> {
const localFolder = path.join(this.appUserDirectory, recipe.id);

+ let inferenceServer: InferenceServer | undefined;
+ // if the recipe has a defined backend, we give priority to using an inference server
+ if (recipe.backend && recipe.backend === model.backend) {
+ let task: Task | undefined;
+ try {
+ inferenceServer = this.inferenceManager.findServerByModel(model);
+ task = this.taskRegistry.createTask('Starting Inference server', 'loading', labels);
+ if (!inferenceServer) {
+ const inferenceContainerId = await this.inferenceManager.createInferenceServer(
+ await withDefaultConfiguration({
+ modelsInfo: [model],
+ }),
+ );
+ inferenceServer = this.inferenceManager.get(inferenceContainerId);
+ this.taskRegistry.updateTask({
+ ...task,
+ labels: {
+ ...task.labels,
+ containerId: inferenceContainerId,
+ },
+ });
+ } else if (inferenceServer.status === 'stopped') {
+ await this.inferenceManager.startInferenceServer(inferenceServer.container.containerId);
+ }
+ task.state = 'success';
+ } catch (e) {
+ // we only skip the task update if the error is that we do not support this backend.
+ // If so, we build the image for the model service instead
+ if (task && String(e) !== 'no enabled provider could be found.') {
+ task.state = 'error';
+ task.error = `Something went wrong while starting the inference server: ${String(e)}`;
+ throw e;
+ }
+ } finally {
+ if (task) {
+ this.taskRegistry.updateTask(task);
+ }
+ }
+ }
+
// load and parse the recipe configuration file and filter containers based on architecture
- const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.basedir, localFolder, {
- ...labels,
- 'recipe-id': recipe.id,
- });
+ const configAndFilteredContainers = this.getConfigAndFilterContainers(
+ recipe.basedir,
+ localFolder,
+ !!inferenceServer,
+ {
+ ...labels,
+ 'recipe-id': recipe.id,
+ },
+ );

- return await this.builderManager.build(
+ const images = await this.builderManager.build(
connection,
recipe,
configAndFilteredContainers.containers,
@@ -114,11 +165,17 @@ export class RecipeManager implements Disposable {
'recipe-id': recipe.id,
},
);

+ return {
+ images,
+ inferenceServer,
+ };
}

private getConfigAndFilterContainers(
recipeBaseDir: string | undefined,
localFolder: string,
+ useInferenceServer: boolean,
labels?: { [key: string]: string },
): AIContainers {
// Adding loading configuration task
@@ -135,7 +192,11 @@
}

// filter the containers based on architecture, GPU accelerator and backend (which determine the models a container supports)
- const filteredContainers: ContainerConfig[] = this.filterContainers(aiConfigFile.aiConfig);
+ let filteredContainers: ContainerConfig[] = this.filterContainers(aiConfigFile.aiConfig);
+ // if we are using the inference server we can remove the model service
+ if (useInferenceServer) {
+ filteredContainers = filteredContainers.filter(c => !c.modelService);
+ }
if (filteredContainers.length > 0) {
// Mark as success.
task.state = 'success';
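Taken together, the RecipeManager changes implement the decision flow below. This is a hedged pseudocode-style sketch with assumed helper names (reuseOrStartServer, buildImages); it compresses the task bookkeeping and error handling of the real code.

// Sketch of the flow introduced in buildRecipe above (assumed helper names):
// 1. if the recipe declares a backend matching the model's backend, reuse a
//    running inference server for that model or start a new one;
// 2. when a server is used, drop the recipe's own model-service container
//    from the build set;
// 3. build the remaining images and return both pieces.
declare function reuseOrStartServer(model: { id: string; backend?: string }): Promise<object | undefined>;
declare function buildImages(recipe: { backend?: string }, useInferenceServer: boolean): Promise<object[]>;

async function buildRecipeSketch(
  recipe: { backend?: string },
  model: { id: string; backend?: string },
): Promise<{ images: object[]; inferenceServer?: object }> {
  let inferenceServer: object | undefined;
  if (recipe.backend && recipe.backend === model.backend) {
    // may stay undefined when no enabled provider supports this backend
    inferenceServer = await reuseOrStartServer(model);
  }
  const images = await buildImages(recipe, !!inferenceServer);
  return { images, inferenceServer };
}
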
2 changes: 2 additions & 0 deletions packages/backend/src/models/AIConfig.ts
@@ -28,6 +28,7 @@ export interface ContainerConfig {
gpu_env: string[];
ports?: number[];
image?: string;
+ backend?: string[];
}

export enum AIConfigFormat {
@@ -130,6 +131,7 @@ export function parseYamlFile(filepath: string, defaultArch: string): AIConfig {
? container['ports'].map(port => parseInt(port))
: [],
image: 'image' in container && isString(container['image']) ? container['image'] : undefined,
+ backend: 'backend' in container && Array.isArray(container['backend']) ? container['backend'] : undefined,
};
}),
},
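To illustrate the new optional backend field, here is how a parsed container entry flows through the guard added above; the entry values are hypothetical, not taken from this repository.

// Hypothetical container entry as it would arrive from a parsed ai-lab.yaml;
// the guard mirrors the one added to parseYamlFile above.
const containerEntry: Record<string, unknown> = {
  name: 'llamacpp-server',
  backend: ['llama-cpp'], // hypothetical backend identifiers
};

const backend: string[] | undefined =
  'backend' in containerEntry && Array.isArray(containerEntry['backend'])
    ? (containerEntry['backend'] as string[])
    : undefined; // => ['llama-cpp']
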
(Diffs for the two remaining changed files are not shown.)
