
[inference] fold openai support into provider param #1205

Merged (12 commits) on Feb 28, 2025
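With this change, OpenAI becomes reachable through the same `provider` parameter as the HF-routed providers: the model ID must carry an `openai/` prefix and the access token must be an OpenAI provider key (an `hf_` token is rejected, since requests to OpenAI cannot be routed through Hugging Face). A minimal usage sketch, adapted from the updated test further down — the key and prompt are placeholders:

import { HfInference } from "@huggingface/inference";

// OpenAI is client-side-routing only: pass a provider key, never an hf_ token.
const hf = new HfInference("sk-...");

const output = await hf.chatCompletion({
  provider: "openai",
  model: "openai/gpt-3.5-turbo", // the "openai/" prefix is mandatory
  messages: [{ role: "user", content: "Complete the equation one + one =" }],
});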
packages/inference/src/lib/makeRequestOptions.ts (38 additions, 13 deletions)

@@ -9,6 +9,7 @@ import { NOVITA_CONFIG } from "../providers/novita";
 import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
+import { OPENAI_CONFIG } from "../providers/openai";
 import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
@@ -31,6 +32,7 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
   "fireworks-ai": FIREWORKS_AI_CONFIG,
   "hf-inference": HF_INFERENCE_CONFIG,
   hyperbolic: HYPERBOLIC_CONFIG,
+  openai: OPENAI_CONFIG,
   nebius: NEBIUS_CONFIG,
   novita: NOVITA_CONFIG,
   replicate: REPLICATE_CONFIG,
@@ -70,22 +72,38 @@ export async function makeRequestOptions(
   if (!providerConfig) {
     throw new Error(`No provider config found for provider ${provider}`);
   }
+  if (providerConfig.clientSideRoutingOnly && !maybeModel) {
+    throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
+  }
   // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
   const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-  const model = await getProviderModelId({ model: hfModel, provider }, args, {
-    task,
-    chatCompletion,
-    fetch: options?.fetch,
-  });
+  const model = providerConfig.clientSideRoutingOnly
+    ? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+      removeProviderPrefix(maybeModel!, provider)
+    : // For closed-models API providers, one needs to pass the model ID directly (e.g. "gpt-3.5-turbo")
+      await getProviderModelId({ model: hfModel, provider }, args, {
+        task,
+        chatCompletion,
+        fetch: options?.fetch,
+      });
 
-  /// If accessToken is passed, it should take precedence over includeCredentials
-  const authMethod = accessToken
-    ? accessToken.startsWith("hf_")
-      ? "hf-token"
-      : "provider-key"
-    : includeCredentials === "include"
-      ? "credentials-include"
-      : "none";
+  const authMethod = (() => {
+    if (providerConfig.clientSideRoutingOnly) {
+      // Closed-source providers require an accessToken (cannot be routed).
+      if (accessToken && accessToken.startsWith("hf_")) {
+        throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
+      }
+      return "provider-key";
+    }
+    if (accessToken) {
+      return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
+    }
+    if (includeCredentials === "include") {
+      // If accessToken is passed, it should take precedence over includeCredentials
+      return "credentials-include";
+    }
+    return "none";
+  })();
 
   // Make URL
   const url = endpointUrl
@@ -174,3 +192,10 @@ async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[]
   }
   return await res.json();
 }
+
+function removeProviderPrefix(model: string, provider: string): string {
+  if (!model.startsWith(`${provider}/`)) {
+    throw new Error(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
+  }
+  return model.slice(provider.length + 1);
+}
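The net effect of the two new code paths above: for a clientSideRoutingOnly provider, the model ID must arrive with the provider prefix, which is then stripped before the request is built. For example (hypothetical values):

removeProviderPrefix("openai/gpt-3.5-turbo", "openai"); // => "gpt-3.5-turbo"
removeProviderPrefix("gpt-3.5-turbo", "openai"); // throws: Models from openai must be prefixed by "openai/".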
packages/inference/src/providers/consts.ts (1 addition, 0 deletions)

@@ -23,6 +23,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
   hyperbolic: {},
   nebius: {},
   novita: {},
+  openai: {},
   replicate: {},
   sambanova: {},
   together: {},
packages/inference/src/providers/openai.ts (new file, 35 additions)

@@ -0,0 +1,35 @@
+/**
+ * Special case: provider configuration for a private models provider (OpenAI in this case).
+ */
+import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
+
+const OPENAI_API_BASE_URL = "https://api.openai.com";
+
+const makeBody = (params: BodyParams): Record<string, unknown> => {
+  if (!params.chatCompletion) {
+    throw new Error("OpenAI only supports chat completions.");
+  }
+  return {
+    ...params.args,
+    model: params.model,
+  };
+};
+
+const makeHeaders = (params: HeaderParams): Record<string, string> => {
+  return { Authorization: `Bearer ${params.accessToken}` };
+};
+
+const makeUrl = (params: UrlParams): string => {
+  if (!params.chatCompletion) {
+    throw new Error("OpenAI only supports chat completions.");
+  }
+  return `${params.baseUrl}/v1/chat/completions`;
+};
+
+export const OPENAI_CONFIG: ProviderConfig = {
+  baseUrl: OPENAI_API_BASE_URL,
+  makeBody,
+  makeHeaders,
+  makeUrl,
+  clientSideRoutingOnly: true,
+};
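For reference, a rough sketch of how these hooks get exercised by makeRequestOptions above. The param objects are narrowed to the fields the OpenAI config actually reads — the real UrlParams / HeaderParams / BodyParams types carry more fields, hence the casts:

// Illustrative only; shows the results of each hook, not the exact call sites.
const url = makeUrl({ baseUrl: OPENAI_API_BASE_URL, chatCompletion: true } as UrlParams);
// => "https://api.openai.com/v1/chat/completions"

const headers = makeHeaders({ accessToken: "sk-..." } as HeaderParams);
// => { Authorization: "Bearer sk-..." }

const body = makeBody({ args: { messages: [{ role: "user", content: "hi" }] }, model: "gpt-3.5-turbo", chatCompletion: true } as BodyParams);
// => { messages: [...], model: "gpt-3.5-turbo" }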
packages/inference/src/types.ts (2 additions, 0 deletions)

@@ -36,6 +36,7 @@ export const INFERENCE_PROVIDERS = [
   "hyperbolic",
   "nebius",
   "novita",
+  "openai",
   "replicate",
   "sambanova",
   "together",
@@ -96,6 +97,7 @@ export interface ProviderConfig {
   makeBody: (params: BodyParams) => Record<string, unknown>;
   makeHeaders: (params: HeaderParams) => Record<string, string>;
   makeUrl: (params: UrlParams) => string;
+  clientSideRoutingOnly?: boolean;
 }
 
 export interface HeaderParams {
packages/inference/test/HfInference.spec.ts (12 additions, 3 deletions)

@@ -755,9 +755,9 @@ describe.concurrent("HfInference", () => {
   it("custom openai - OpenAI Specs", async () => {
     const OPENAI_KEY = env.OPENAI_KEY;
     const hf = new HfInference(OPENAI_KEY);
-    const ep = hf.endpoint("https://api.openai.com");
-    const stream = ep.chatCompletionStream({
-      model: "gpt-3.5-turbo",
+    const stream = hf.chatCompletionStream({
+      provider: "openai",
+      model: "openai/gpt-3.5-turbo",
       messages: [{ role: "user", content: "Complete the equation one + one =" }],
     }) as AsyncGenerator<ChatCompletionStreamOutput>;
     let out = "";
@@ -768,6 +768,15 @@
     }
     expect(out).toContain("two");
   });
+  it("OpenAI client side routing - model should have provider as prefix", async () => {
+    await expect(
+      new HfInference("dummy_token").chatCompletion({
+        model: "gpt-3.5-turbo", // must be "openai/gpt-3.5-turbo"
+        provider: "openai",
+        messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+      })
+    ).rejects.toThrowError(`Models from openai must be prefixed by "openai/". Got "gpt-3.5-turbo".`);
+  });
 },
 TIMEOUT
);
packages/inference/test/tapes.json (31 additions, 0 deletions)

@@ -7386,5 +7386,36 @@
         "content-type": "image/jpeg"
       }
     }
   },
+  "ad463aeaa0a4222600cd3fe0ad34ec1bbee5f4fa9a12beeb40fac29922c7e6a5": {
+    "url": "https://api.openai.com/v1/chat/completions",
+    "init": {
+      "headers": {
+        "Content-Type": "application/json"
+      },
+      "method": "POST",
+      "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"gpt-3.5-turbo\"}"
+    },
+    "response": {
+      "body": "{\"id\":\"chatcmpl-B5Ybe7M1V7N9aqQtrjTGfSWGebrrb\",\"object\":\"chat.completion\",\"created\":1740664366,\"model\":\"gpt-3.5-turbo-0125\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"to two.\",\"refusal\":null},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":19,\"completion_tokens\":4,\"total_tokens\":23,\"prompt_tokens_details\":{\"cached_tokens\":0,\"audio_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":0,\"audio_tokens\":0,\"accepted_prediction_tokens\":0,\"rejected_prediction_tokens\":0}},\"service_tier\":\"default\",\"system_fingerprint\":null}",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "access-control-expose-headers": "X-Request-ID",
+        "alt-svc": "h3=\":443\"; ma=86400",
+        "cf-cache-status": "DYNAMIC",
+        "cf-ray": "9188a84109a36f48-CDG",
+        "connection": "keep-alive",
+        "content-encoding": "br",
+        "content-type": "application/json",
+        "openai-organization": "user-b0yxwesrrz0borhxlqfmuwqf",
+        "openai-processing-ms": "249",
+        "openai-version": "2020-10-01",
+        "server": "cloudflare",
+        "set-cookie": "_cfuvid=PTKpHnf8DDZxJ.gVO5N_KdQ0qmMr8idtxlnmeGX05kw-1740664366628-0.0.1.1-604800000; path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None",
+        "strict-transport-security": "max-age=31536000; includeSubDomains; preload",
+        "transfer-encoding": "chunked"
+      }
+    }
+  }
 }