[inference] fold openai support into provider param #1205

Merged (12 commits) on Feb 28, 2025
34 changes: 21 additions & 13 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -9,6 +9,7 @@ import { NOVITA_CONFIG } from "../providers/novita";
 import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
+import { OPENAI_CONFIG } from "../providers/openai";
 import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
@@ -31,6 +32,7 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 	"fireworks-ai": FIREWORKS_AI_CONFIG,
 	"hf-inference": HF_INFERENCE_CONFIG,
 	hyperbolic: HYPERBOLIC_CONFIG,
+	openai: OPENAI_CONFIG,
 	nebius: NEBIUS_CONFIG,
 	novita: NOVITA_CONFIG,
 	replicate: REPLICATE_CONFIG,
@@ -72,20 +74,26 @@ export async function makeRequestOptions(
 	}
 	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 	const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-	const model = await getProviderModelId({ model: hfModel, provider }, args, {
-		task,
-		chatCompletion,
-		fetch: options?.fetch,
-	});
+	const model = providerConfig.closedSource
+		? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+		  maybeModel! // For closed-source API providers, the model ID must be passed directly
+		: await getProviderModelId({ model: hfModel, provider }, args, {
+				task,
+				chatCompletion,
+				fetch: options?.fetch,
+		  });
 
-	/// If accessToken is passed, it should take precedence over includeCredentials
-	const authMethod = accessToken
-		? accessToken.startsWith("hf_")
-			? "hf-token"
-			: "provider-key"
-		: includeCredentials === "include"
-			? "credentials-include"
-			: "none";
+	// If accessToken is passed, it should take precedence over includeCredentials.
+	// Closed-source providers require an accessToken (cannot be routed).
+	const authMethod = providerConfig.closedSource
+		? "provider-key"
+		: accessToken
+			? accessToken.startsWith("hf_")
+				? "hf-token"
+				: "provider-key"
+			: includeCredentials === "include"
+				? "credentials-include"
+				: "none";
 
 	// Make URL
 	const url = endpointUrl
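The new precedence can be hard to read in a nested ternary, so here is a simplified standalone sketch of the same resolution; the function name and flattened signature are illustrative, not part of the package's exports:

```ts
// Sketch of the auth-method resolution above (illustrative, not the library API).
type AuthMethod = "provider-key" | "hf-token" | "credentials-include" | "none";

function resolveAuthMethod(
	closedSource: boolean | undefined,
	accessToken?: string,
	includeCredentials?: string | boolean
): AuthMethod {
	// Closed-source providers can only be called with the provider's own key.
	if (closedSource) {
		return "provider-key";
	}
	// An explicit token takes precedence over includeCredentials.
	if (accessToken) {
		return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
	}
	return includeCredentials === "include" ? "credentials-include" : "none";
}

resolveAuthMethod(true, "sk-placeholder");       // "provider-key"
resolveAuthMethod(false, "hf_abc");              // "hf-token"
resolveAuthMethod(false, undefined, "include");  // "credentials-include"
```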
35 changes: 35 additions & 0 deletions packages/inference/src/providers/openai.ts
@@ -0,0 +1,35 @@
/**
 * Special case: provider configuration for a closed-source model provider (OpenAI in this case).
 */
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";

const OPENAI_API_BASE_URL = "https://api.openai.com";

const makeBody = (params: BodyParams): Record<string, unknown> => {
if (!params.chatCompletion) {
throw new Error("OpenAI only supports chat completions.");
}
return {
...params.args,
model: params.model,
};
};

const makeHeaders = (params: HeaderParams): Record<string, string> => {
return { Authorization: `Bearer ${params.accessToken}` };
};

const makeUrl = (params: UrlParams): string => {
if (!params.chatCompletion) {
throw new Error("OpenAI only supports chat completions.");
}
return `${params.baseUrl}/v1/chat/completions`;
};

export const OPENAI_CONFIG: ProviderConfig = {
baseUrl: OPENAI_API_BASE_URL,
makeBody,
makeHeaders,
makeUrl,
closedSource: true,
};
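To see how the three helpers fit together, here is a minimal, self-contained sketch of the request they would produce; the token, model, and message are placeholders, and the real flow goes through `makeRequestOptions`, which layers auth resolution and defaults on top:

```ts
// Placeholder values for illustration only.
const baseUrl = "https://api.openai.com";
const accessToken = "sk-placeholder"; // an OpenAI provider key, not an hf_ token
const model = "gpt-3.5-turbo";

const url = `${baseUrl}/v1/chat/completions`;               // mirrors makeUrl
const headers = { Authorization: `Bearer ${accessToken}` }; // mirrors makeHeaders
const body = {                                              // mirrors makeBody
	messages: [{ role: "user", content: "Complete the equation one + one =" }],
	model,
};

const response = await fetch(url, {
	method: "POST",
	headers: { ...headers, "Content-Type": "application/json" },
	body: JSON.stringify(body),
});
console.log(await response.json());
```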
2 changes: 2 additions & 0 deletions packages/inference/src/types.ts
@@ -36,6 +36,7 @@ export const INFERENCE_PROVIDERS = [
 	"hyperbolic",
 	"nebius",
 	"novita",
+	"openai",
 	"replicate",
 	"sambanova",
 	"together",
@@ -95,6 +96,7 @@ export interface ProviderConfig {
 	makeBody: (params: BodyParams) => Record<string, unknown>;
 	makeHeaders: (params: HeaderParams) => Record<string, string>;
 	makeUrl: (params: UrlParams) => string;
+	closedSource?: boolean;
 }

 export interface HeaderParams {
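As an illustration of how the new optional flag would be used by a future closed-source integration, here is a hypothetical config; the provider name and URL are invented:

```ts
import type { ProviderConfig } from "../types";

// Hypothetical closed-source provider; name and URL are invented for illustration.
const EXAMPLE_CONFIG: ProviderConfig = {
	baseUrl: "https://api.example.com",
	makeBody: (params) => ({ ...params.args, model: params.model }),
	makeHeaders: (params) => ({ Authorization: `Bearer ${params.accessToken}` }),
	makeUrl: (params) => `${params.baseUrl}/v1/chat/completions`,
	closedSource: true, // bypass HF model-ID mapping; a provider key is required
};
```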
2 changes: 1 addition & 1 deletion packages/inference/test/HfInference.spec.ts
@@ -757,7 +757,7 @@ describe.concurrent("HfInference", () => {
 		const hf = new HfInference(OPENAI_KEY);
 		const stream = hf.chatCompletionStream({
 			provider: "openai",
-			model: "openai/gpt-3.5-turbo",
+			model: "gpt-3.5-turbo",
 			messages: [{ role: "user", content: "Complete the equation one + one =" }],
 		}) as AsyncGenerator<ChatCompletionStreamOutput>;
 		let out = "";

Inline review thread on this change:

Member (PR author): i'm wondering if we should still namespace model names and strip it in the implem? (just to keep consistency with HF repos..) 🤷

@Wauplin (Contributor, Feb 25, 2025): I find it a bit redundant but ok to switch if you prefer

Member (PR author): (what if they OS it at some point in the future? etc)
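For reference, a sketch of the namespaced alternative discussed in the thread, using a hypothetical helper name (the merged change instead keeps bare OpenAI model IDs):

```ts
// Hypothetical helper illustrating the alternative discussed above:
// accept "openai/gpt-3.5-turbo" for consistency with HF repo IDs and
// strip the provider prefix before calling the OpenAI API.
function stripProviderPrefix(model: string, provider: string): string {
	const prefix = `${provider}/`;
	return model.startsWith(prefix) ? model.slice(prefix.length) : model;
}

stripProviderPrefix("openai/gpt-3.5-turbo", "openai"); // => "gpt-3.5-turbo"
```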