Add Support: OpenAI base url #5

Open · wants to merge 3 commits into base: main
Changes from all commits
28 changes: 24 additions & 4 deletions src/components/modal-auth/modal-auth.ts
@@ -14,6 +14,7 @@ type Model = 'palm' | 'gpt';
export interface ModelAuthMessage {
model: Model;
apiKey: string;
baseURL: string;
}

/**
@@ -26,6 +27,9 @@ export class WordflowModalAuth extends LitElement {
@property({ type: String })
apiKey = '';

@property({ type: String })
baseURL = '';

@query('.modal-auth')
modalElement!: HTMLElement | null;

@@ -63,6 +67,10 @@ export class WordflowModalAuth extends LitElement {
const apiKey = USE_CACHE
? localStorage.getItem(`${model}APIKey`)
: null;

const baseURL = USE_CACHE
? localStorage.getItem(`${model}baseURL`)
: null;

if (apiKey === null) {
this.modelSetMap[model] = false;
@@ -71,7 +79,8 @@
const event = new CustomEvent<ModelAuthMessage>('api-key-added', {
detail: {
model,
apiKey
apiKey,
baseURL: baseURL || 'https://api.openai.com/v1'
}
});
this.dispatchEvent(event);
@@ -120,17 +129,20 @@ export class WordflowModalAuth extends LitElement {
authVerificationSucceeded = (
model: Model,
messageElement: HTMLElement,
apiKey: string
apiKey: string,
baseURL: string
) => {
// Add the api key to the local storage
if (USE_CACHE) {
localStorage.setItem(`${model}APIKey`, apiKey);
localStorage.setItem(`${model}baseURL`, baseURL);
}

const event = new CustomEvent<ModelAuthMessage>('api-key-added', {
detail: {
model,
apiKey
apiKey,
baseURL
}
});
this.dispatchEvent(event);
@@ -150,6 +162,9 @@ export class WordflowModalAuth extends LitElement {
const apiInputElement = this.renderRoot.querySelector<HTMLInputElement>(
`#api-input-${model}`
);
const baseURLInputElement = this.renderRoot.querySelector<HTMLInputElement>(
`#base-url-input-${model}`
);

if (messageElement === null || apiInputElement === null) {
throw Error("Can't locate the input elements");
@@ -160,6 +175,8 @@
return;
}

const baseURL = baseURLInputElement?.value || 'https://api.openai.com/v1';

// Start to verify the given key
messageElement.classList.remove('error');
messageElement.classList.remove('success');
@@ -185,6 +202,7 @@
case 'gpt': {
textGenGpt(
apiKey,
baseURL,
requestID,
prompt,
temperature,
@@ -221,7 +239,8 @@ export class WordflowModalAuth extends LitElement {
this.authVerificationSucceeded(
model,
messageElement,
message.payload.apiKey
message.payload.apiKey,
message.payload.baseURL
);
}
break;
@@ -294,6 +313,7 @@ export class WordflowModalAuth extends LitElement {
</span>
<div class="input-form">
<input id="api-input-gpt" placeholder="API Key" />
<input id="base-url-input-gpt" placeholder="base URL" />
<button
class="primary"
@click="${() => this.submitButtonClicked('gpt')}"
51 changes: 46 additions & 5 deletions src/components/panel-setting/panel-setting.ts
@@ -105,6 +105,9 @@ export class WordflowPanelSetting extends LitElement {
@state()
apiInputValue = '';

@state()
baseURLInputValue = '';

@state()
showModelLoader = false;

@@ -214,6 +217,13 @@ export class WordflowPanelSetting extends LitElement {
this.curDeviceSupportsLocalModel = false;
this.localModelMessage = LOCAL_MODEL_MESSAGES.incompatible;
}

// Initialize the Base URL input
const inputElement = this.shadowRoot?.querySelector('#base-url-input') as HTMLInputElement;
if (inputElement) {
const event = new Event('input', { bubbles: true, composed: true });
inputElement.dispatchEvent(event);
}
}

//==========================================================================||
@@ -225,6 +235,7 @@
textGenMessageHandler = (
model: SupportedRemoteModel,
apiKey: string,
baseURL: string,
message: TextGenMessage
) => {
switch (message.command) {
@@ -237,6 +248,10 @@
// Add the api key to the storage
this.userConfigManager.setAPIKey(modelFamily, apiKey);

// Add the base URL to the storage
console.log('baseURL'+baseURL);
this.userConfigManager.setBaseURL(modelFamily, baseURL);

// Also set this model as the preferred model
if (this.selectedModel === model) {
this.userConfigManager.setPreferredLLM(model);
@@ -339,13 +354,22 @@ export class WordflowPanelSetting extends LitElement {
e.preventDefault();

if (
(// Check if the API key is the same as the one in the storage
this.userConfig.llmAPIKeys[this.selectedModelFamily] ===
this.apiInputValue ||
this.apiInputValue === ''
this.apiInputValue === ''
) &&
(// Check if the base URL is the same as the one in the storage
this.baseURLInputValue === this.userConfig.baseURL[this.selectedModelFamily]
|| this.baseURLInputValue === '')
) {
return;
}

// If only the base URL changed and the key input still matches the stored key, reuse the stored key
const do_not_update_key = this.baseURLInputValue !== this.userConfig.baseURL[this.selectedModelFamily] &&
this.userConfig.llmAPIKeys[this.selectedModelFamily] === this.apiInputValue;

if (this.shadowRoot === null) {
throw Error('shadowRoot is null');
}
@@ -354,7 +378,7 @@
const temperature = 0.8;

// Parse the api key
const apiKey = this.apiInputValue;
const apiKey = do_not_update_key ? this.userConfig.llmAPIKeys[this.selectedModelFamily] : this.apiInputValue;
this.showModelLoader = true;

switch (this.selectedModelFamily) {
@@ -365,6 +389,7 @@
this.textGenMessageHandler(
this.selectedModel as SupportedRemoteModel,
apiKey,
this.baseURLInputValue,
value
);
}
@@ -375,6 +400,7 @@
case ModelFamily.openAI: {
textGenGpt(
apiKey,
this.baseURLInputValue,
requestID,
prompt,
temperature,
@@ -385,6 +411,7 @@
this.textGenMessageHandler(
this.selectedModel as SupportedRemoteModel,
apiKey,
this.baseURLInputValue,
value
);
});
@@ -609,12 +636,25 @@ export class WordflowPanelSetting extends LitElement {
value="${this.userConfig.llmAPIKeys[
this.selectedModelFamily
]}"
placeholder=""
placeholder="sk-xxxx"
@input=${(e: InputEvent) => {
const element = e.currentTarget as HTMLInputElement;
this.apiInputValue = element.value;
}}
/>
<input
type="text"
class="content-text api-input"
id="base-url-input"
value="${this.userConfig.baseURL[
this.selectedModelFamily
]}"
placeholder="(Optional) base url, default https://api.openai.com/v1"
@input=${(e: InputEvent) => {
const element = e.currentTarget as HTMLInputElement;
this.baseURLInputValue = element.value;
}}
/>
<div class="right-loader">
<div
class="prompt-loader"
@@ -628,9 +668,10 @@
</div>
<button
class="add-button"
?has-set=${this.userConfig.llmAPIKeys[
?has-set=${(this.userConfig.llmAPIKeys[
this.selectedModelFamily
] === this.apiInputValue || this.apiInputValue === ''}
] === this.apiInputValue || this.apiInputValue === '') &&
this.baseURLInputValue === this.userConfig.baseURL[this.selectedModelFamily]}
@click=${(e: MouseEvent) => this.addButtonClicked(e)}
>
${this.userConfig.llmAPIKeys[this.selectedModelFamily] === ''
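
The synthetic `input` event dispatched during initialization above exists to copy the pre-filled base URL into `baseURLInputValue`. A minimal sketch of that pattern outside Lit, with illustrative names only:

```ts
// Plain-DOM sketch: an input whose value was set programmatically does not
// fire 'input' on its own, so component state would stay empty until the user
// types. Dispatching a synthetic event runs the same handler once.
const input = document.createElement('input');
input.value = 'https://api.openai.com/v1'; // what the template binding renders

let baseURLInputValue = '';
input.addEventListener('input', e => {
  baseURLInputValue = (e.currentTarget as HTMLInputElement).value;
});

input.dispatchEvent(new Event('input', { bubbles: true, composed: true }));
console.log(baseURLInputValue); // 'https://api.openai.com/v1'
```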
3 changes: 3 additions & 0 deletions src/components/text-editor/text-editor.ts
@@ -970,6 +970,7 @@ export class WordflowTextEditor extends LitElement {
case SupportedRemoteModel['gpt-3.5']: {
runRequest = textGenGpt(
this.userConfig.llmAPIKeys[ModelFamily.openAI],
this.userConfig.baseURL[ModelFamily.openAI],
'text-gen',
curPrompt,
promptData.temperature,
@@ -982,12 +983,14 @@
case SupportedRemoteModel['gpt-4']: {
runRequest = textGenGpt(
this.userConfig.llmAPIKeys[ModelFamily.openAI],
this.userConfig.baseURL[ModelFamily.openAI],
'text-gen',
curPrompt,
promptData.temperature,
'gpt-4-1106-preview',
USE_CACHE
);
console.log('Running GPT-4 with base URL: ' + this.userConfig.baseURL[ModelFamily.openAI]);
break;
}

14 changes: 14 additions & 0 deletions src/components/wordflow/user-config.ts
@@ -55,6 +55,7 @@ export const modelFamilyMap: Record<

export interface UserConfig {
llmAPIKeys: Record<ModelFamily, string>;
baseURL: Record<ModelFamily, string>;
preferredLLM: SupportedRemoteModel | SupportedLocalModel;
}

@@ -63,6 +64,7 @@ export class UserConfigManager {
updateUserConfig: (userConfig: UserConfig) => void;

#llmAPIKeys: Record<ModelFamily, string>;
#baseURL: Record<ModelFamily, string>;
#preferredLLM: SupportedRemoteModel | SupportedLocalModel;

constructor(updateUserConfig: (userConfig: UserConfig) => void) {
@@ -73,6 +75,11 @@
[ModelFamily.google]: '',
[ModelFamily.local]: ''
};
this.#baseURL = {
[ModelFamily.openAI]: 'https://api.openai.com/v1',
[ModelFamily.google]: 'https://gemini-prod.googleapis.com/v1',
[ModelFamily.local]: ''
};
this.#preferredLLM = SupportedRemoteModel['gpt-3.5-free'];
this._broadcastUserConfig();

@@ -87,6 +94,12 @@
this._broadcastUserConfig();
}

setBaseURL(modelFamily: ModelFamily, url: string) {
this.#baseURL[modelFamily] = url;
this._syncStorage();
this._broadcastUserConfig();
}

setPreferredLLM(model: SupportedRemoteModel | SupportedLocalModel) {
this.#preferredLLM = model;
this._syncStorage();
@@ -121,6 +134,7 @@
_constructConfig(): UserConfig {
const config: UserConfig = {
llmAPIKeys: this.#llmAPIKeys,
baseURL: this.#baseURL,
preferredLLM: this.#preferredLLM
};
return config;
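
A short usage sketch of the new `baseURL` record, assuming a `UserConfigManager` constructed with an update callback as its constructor requires; the proxy URL and import path are illustrative.

```ts
import { ModelFamily, UserConfigManager } from './user-config'; // import path illustrative

const manager = new UserConfigManager(config => {
  // Subscribers see the base URL alongside the API keys and preferred model.
  console.log(config.baseURL[ModelFamily.openAI]);
});

// Point the OpenAI family at an OpenAI-compatible endpoint; setBaseURL persists
// it via _syncStorage() and broadcasts the updated config to subscribers.
manager.setBaseURL(ModelFamily.openAI, 'https://my-proxy.example.com/v1');
```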
2 changes: 2 additions & 0 deletions src/llms/gemini.ts
@@ -72,6 +72,7 @@ export const textGenGemini = async (
payload: {
requestID,
apiKey,
baseURL: '',
result: cachedValue,
prompt: prompt,
detail: detail
@@ -115,6 +116,7 @@
payload: {
requestID,
apiKey,
baseURL: '',
result,
prompt: prompt,
detail: detail
8 changes: 7 additions & 1 deletion src/llms/gpt.ts
@@ -10,6 +10,7 @@ export type TextGenMessage =
payload: {
requestID: string;
apiKey: string;
baseURL: string;
result: string;
prompt: string;
detail: string;
@@ -27,6 +28,7 @@
/**
* Use GPT API to generate text based on a given prompt
* @param apiKey GPT API key
* @param baseURL Base URL for the GPT API
* @param requestID Worker request ID
* @param prompt Prompt to give to the GPT model
* @param temperature Model temperature
@@ -36,6 +38,7 @@
*/
export const textGenGpt = async (
apiKey: string,
baseURL: string,
requestID: string,
prompt: string,
temperature: number,
@@ -68,6 +71,7 @@
payload: {
requestID,
apiKey,
baseURL,
result: cachedValue,
prompt: prompt,
detail: detail
@@ -76,7 +80,7 @@
return message;
}

const url = 'https://api.openai.com/v1/chat/completions';
// const url = 'https://api.openai.com/v1/chat/completions';

const requestOptions: RequestInit = {
method: 'POST',
@@ -89,6 +93,7 @@
};

try {
const url = baseURL + '/chat/completions';
const response = await fetch(url, requestOptions);
const data = (await response.json()) as ChatCompletion;
if (response.status !== 200) {
@@ -105,6 +110,7 @@
payload: {
requestID,
apiKey,
baseURL,
result: data.choices[0].message.content,
prompt: prompt,
detail: detail
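
With the widened signature, a call site looks roughly like the sketch below. The key, endpoint, and model string are placeholders, and the argument order follows the updated `textGenGpt` in this diff.

```ts
import { textGenGpt } from './gpt'; // import path illustrative

const message = await textGenGpt(
  'sk-xxxx',                          // apiKey (placeholder)
  'https://my-proxy.example.com/v1',  // baseURL; '/chat/completions' is appended internally
  'text-gen',                         // requestID
  'Rewrite this sentence to be more concise.',
  0.8,                                // temperature
  'gpt-3.5-turbo',                    // model (placeholder)
  false                               // cache flag (USE_CACHE at real call sites)
);

// The returned TextGenMessage now carries baseURL back in its payload,
// next to the request ID, prompt, and generated result.
console.log(message);
```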