[inference] fold openai support into provider param #1205

Merged (12 commits) on Feb 28, 2025
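With this change, an OpenAI model can be called through the regular inference client by passing provider: "openai" together with the OpenAI model id and your own OpenAI API key; no Hugging Face routing or model mapping is applied. A minimal sketch, mirroring the updated test at the bottom of this diff (the non-streaming chatCompletion call is assumed to behave the same way as the streaming one exercised there):

import { HfInference } from "@huggingface/inference";

// OpenAI is registered as a closed-source provider: pass the provider's
// model id and an OpenAI API key directly.
const hf = new HfInference(process.env.OPENAI_KEY);

const out = await hf.chatCompletion({
    provider: "openai",
    model: "gpt-3.5-turbo",
    messages: [{ role: "user", content: "Complete the equation one + one =" }],
});
console.log(out.choices[0]?.message?.content);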
34 changes: 21 additions & 13 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -9,6 +9,7 @@ import { NOVITA_CONFIG } from "../providers/novita";
 import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
+import { OPENAI_CONFIG } from "../providers/openai";
 import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
@@ -31,6 +32,7 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 	"fireworks-ai": FIREWORKS_AI_CONFIG,
 	"hf-inference": HF_INFERENCE_CONFIG,
 	hyperbolic: HYPERBOLIC_CONFIG,
+	openai: OPENAI_CONFIG,
 	nebius: NEBIUS_CONFIG,
 	novita: NOVITA_CONFIG,
 	replicate: REPLICATE_CONFIG,
@@ -72,20 +74,26 @@ export async function makeRequestOptions(
 	}
 	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 	const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-	const model = await getProviderModelId({ model: hfModel, provider }, args, {
-		task,
-		chatCompletion,
-		fetch: options?.fetch,
-	});
+	const model = providerConfig.closedSource
+		? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+		  maybeModel! // For closed-source providers, the model ID must be passed in directly
+		: await getProviderModelId({ model: hfModel, provider }, args, {
+				task,
+				chatCompletion,
+				fetch: options?.fetch,
+		  });

-	/// If accessToken is passed, it should take precedence over includeCredentials
-	const authMethod = accessToken
-		? accessToken.startsWith("hf_")
-			? "hf-token"
-			: "provider-key"
-		: includeCredentials === "include"
-			? "credentials-include"
-			: "none";
+	// If accessToken is passed, it should take precedence over includeCredentials
+	// Closed-source providers require an accessToken (cannot be routed).
+	const authMethod = providerConfig.closedSource
+		? "provider-key"
+		: accessToken
+			? accessToken.startsWith("hf_")
+				? "hf-token"
+				: "provider-key"
+			: includeCredentials === "include"
+				? "credentials-include"
+				: "none";

 	// Make URL
 	const url = endpointUrl
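For readability, the new authMethod resolution above can be read as the following equivalent if/else sketch (resolveAuthMethod is a hypothetical helper written for illustration, not part of this PR):

function resolveAuthMethod(
    closedSource: boolean | undefined,
    accessToken: string | undefined,
    includeCredentials: string | boolean | undefined
): "provider-key" | "hf-token" | "credentials-include" | "none" {
    // Closed-source providers cannot be routed through Hugging Face, so the
    // caller's token is always treated as the provider's own API key.
    if (closedSource) {
        return "provider-key";
    }
    // An explicit accessToken takes precedence over includeCredentials.
    if (accessToken) {
        return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
    }
    return includeCredentials === "include" ? "credentials-include" : "none";
}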
35 changes: 35 additions & 0 deletions packages/inference/src/providers/openai.ts
@@ -0,0 +1,35 @@
+/**
+ * Special case: provider configuration for a private models provider (OpenAI in this case).
+ */
+import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
+
+const OPENAI_API_BASE_URL = "https://api.openai.com";
+
+const makeBody = (params: BodyParams): Record<string, unknown> => {
+	if (!params.chatCompletion) {
+		throw new Error("OpenAI only supports chat completions.");
+	}
+	return {
+		...params.args,
+		model: params.model,
+	};
+};
+
+const makeHeaders = (params: HeaderParams): Record<string, string> => {
+	return { Authorization: `Bearer ${params.accessToken}` };
+};
+
+const makeUrl = (params: UrlParams): string => {
+	if (!params.chatCompletion) {
+		throw new Error("OpenAI only supports chat completions.");
+	}
+	return `${params.baseUrl}/v1/chat/completions`;
+};
+
+export const OPENAI_CONFIG: ProviderConfig = {
+	baseUrl: OPENAI_API_BASE_URL,
+	makeBody,
+	makeHeaders,
+	makeUrl,
+	closedSource: true,
+};
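To make the mapping concrete, here is what this config produces for a chat-completion request. Illustration only: the helpers are module-private, and the parameter objects below are trimmed to the fields the functions actually read; the real UrlParams/HeaderParams/BodyParams interfaces may carry additional members.

// url === "https://api.openai.com/v1/chat/completions"
const url = makeUrl({ baseUrl: OPENAI_API_BASE_URL, model: "gpt-3.5-turbo", chatCompletion: true });

// headers.Authorization === "Bearer sk-..."
const headers = makeHeaders({ accessToken: "sk-..." });

// body === { messages: [...], model: "gpt-3.5-turbo" }
const body = makeBody({
    args: { messages: [{ role: "user", content: "Hello" }] },
    model: "gpt-3.5-turbo",
    chatCompletion: true,
});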
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
@@ -36,6 +36,7 @@ export const INFERENCE_PROVIDERS = [
 	"hyperbolic",
 	"nebius",
 	"novita",
+	"openai",
 	"replicate",
 	"sambanova",
 	"together",
@@ -95,6 +96,7 @@ export interface ProviderConfig {
 	makeBody: (params: BodyParams) => Record<string, unknown>;
 	makeHeaders: (params: HeaderParams) => Record<string, string>;
 	makeUrl: (params: UrlParams) => string;
+	closedSource?: boolean;
 }

 export interface HeaderParams {
4 changes: 2 additions & 2 deletions packages/inference/test/HfInference.spec.ts
@@ -755,8 +755,8 @@ describe.concurrent("HfInference", () => {
 	it("custom openai - OpenAI Specs", async () => {
 		const OPENAI_KEY = env.OPENAI_KEY;
 		const hf = new HfInference(OPENAI_KEY);
-		const ep = hf.endpoint("https://api.openai.com");
-		const stream = ep.chatCompletionStream({
+		const stream = hf.chatCompletionStream({
+			provider: "openai",
 			model: "gpt-3.5-turbo",
 			messages: [{ role: "user", content: "Complete the equation one + one =" }],
 		}) as AsyncGenerator<ChatCompletionStreamOutput>;
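The stream is then consumed as an async generator; a sketch of the rest of the test, assuming the usual ChatCompletionStreamOutput chunk shape (choices[].delta.content):

let out = "";
for await (const chunk of stream) {
    if (chunk.choices && chunk.choices.length > 0) {
        out += chunk.choices[0].delta?.content ?? "";
    }
}
// e.g. assert that out contains the expected completion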