From 6bc63f12dcfce207dacdb0a99791422a96d4423a Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Thu, 17 Oct 2024 17:34:21 +0530 Subject: [PATCH] LM Studio support --- .../github/jeddict/ai/JeddictChatModel.java | 7 + .../jeddict/ai/models/LMStudioChatModel.java | 194 ++++++++++++++++++ .../ai/models/LMStudioModelFetcher.java | 86 ++++++++ .../OllamaModelFetcher.java | 20 +- .../ai/settings/AIAssistancePanel.java | 23 ++- .../jeddict/ai/settings/GenAIProvider.java | 3 +- 6 files changed, 323 insertions(+), 10 deletions(-) create mode 100644 src/main/java/io/github/jeddict/ai/models/LMStudioChatModel.java create mode 100644 src/main/java/io/github/jeddict/ai/models/LMStudioModelFetcher.java rename src/main/java/io/github/jeddict/ai/{settings => models}/OllamaModelFetcher.java (72%) diff --git a/src/main/java/io/github/jeddict/ai/JeddictChatModel.java b/src/main/java/io/github/jeddict/ai/JeddictChatModel.java index b62eba6..ca6df92 100644 --- a/src/main/java/io/github/jeddict/ai/JeddictChatModel.java +++ b/src/main/java/io/github/jeddict/ai/JeddictChatModel.java @@ -29,6 +29,9 @@ import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel; import dev.langchain4j.model.ollama.OllamaChatModel; import dev.langchain4j.model.openai.OpenAiChatModel; +import static io.github.jeddict.ai.settings.GenAIProvider.ANTHROPIC; +import static io.github.jeddict.ai.settings.GenAIProvider.OLLAMA; +import io.github.jeddict.ai.models.LMStudioChatModel; import io.github.jeddict.ai.settings.PreferencesManager; import static io.github.jeddict.ai.util.MimeUtil.MIME_TYPE_DESCRIPTIONS; import static io.github.jeddict.ai.util.StringUtil.removeCodeBlockMarkers; @@ -66,6 +69,10 @@ public JeddictChatModel() { .baseUrl(preferencesManager.getProviderLocation()) .modelName(preferencesManager.getModelName()) .build(); + case LM_STUDIO -> model = LMStudioChatModel.builder() + .baseUrl(preferencesManager.getProviderLocation()) + .modelName(preferencesManager.getModelName()) + .build(); case ANTHROPIC -> model = AnthropicChatModel.builder() .apiKey(preferencesManager.getApiKey()) .modelName(preferencesManager.getModelName()) diff --git a/src/main/java/io/github/jeddict/ai/models/LMStudioChatModel.java b/src/main/java/io/github/jeddict/ai/models/LMStudioChatModel.java new file mode 100644 index 0000000..c4a4648 --- /dev/null +++ b/src/main/java/io/github/jeddict/ai/models/LMStudioChatModel.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package io.github.jeddict.ai.models; + +import dev.ai4j.openai4j.OpenAiClient; +import dev.ai4j.openai4j.chat.ChatCompletionRequest; +import dev.ai4j.openai4j.chat.ChatCompletionResponse; +import dev.ai4j.openai4j.shared.Usage; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import dev.langchain4j.model.chat.ChatLanguageModel; +import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom; +import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom; +import static dev.langchain4j.model.openai.InternalOpenAiHelper.toFunctions; +import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import java.time.Duration; +import static java.time.Duration.ofSeconds; +import static java.util.Collections.singletonList; +import java.util.List; + +public class LMStudioChatModel implements ChatLanguageModel { + + + public static final String LMSTUDIO_MODEL_URL = "http://localhost:1234/v1/"; + private final OpenAiClient client; + private final String modelName; + private final Double temperature; + private final Double topP; + private final Integer maxTokens; + private final Integer maxRetries; + + private LMStudioChatModel(String baseUrl, + String modelName, + Double temperature, + Double topP, + Integer maxTokens, + Duration timeout, + Integer maxRetries, + Boolean logRequests, + Boolean logResponses) { + + temperature = temperature == null ? 0.7 : temperature; + timeout = timeout == null ? ofSeconds(60) : timeout; + maxRetries = maxRetries == null ? 3 : maxRetries; + + this.client = OpenAiClient.builder() + .openAiApiKey("ignored") + .baseUrl(ensureNotBlank(baseUrl, "baseUrl")) + .callTimeout(timeout) + .connectTimeout(timeout) + .readTimeout(timeout) + .writeTimeout(timeout) + .logRequests(logRequests) + .logResponses(logResponses) + .build(); + this.modelName = ensureNotBlank(modelName, "modelName"); + this.temperature = temperature; + this.topP = topP; + this.maxTokens = maxTokens; + this.maxRetries = maxRetries; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String baseUrl; + private String modelName; + private Double temperature; + private Double topP; + private Integer maxTokens; + private Duration timeout; + private Integer maxRetries; + private Boolean logRequests; + private Boolean logResponses; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(Integer maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder logRequests(Boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(Boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + public LMStudioChatModel build() { + return new LMStudioChatModel(baseUrl, modelName, temperature, topP, maxTokens, timeout, maxRetries, logRequests, logResponses); + } + } + + @Override + public Response generate(List messages) { + return generate(messages, null, null); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + return generate(messages, toolSpecifications, null); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, singletonList(toolSpecification), toolSpecification); + } + + private Response generate(List messages, + List toolSpecifications, + ToolSpecification toolThatMustBeExecuted + ) { + ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder() + .model(modelName) + .messages(toOpenAiMessages(messages)) + .temperature(temperature) + .topP(topP) + .maxTokens(maxTokens); + + if (toolSpecifications != null && !toolSpecifications.isEmpty()) { + requestBuilder.functions(toFunctions(toolSpecifications)); + } + if (toolThatMustBeExecuted != null) { + requestBuilder.functionCall(toolThatMustBeExecuted.name()); + } + + ChatCompletionRequest request = requestBuilder.build(); + + ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries); + + Usage usage = response.usage(); + + return Response.from( + aiMessageFrom(response), + new TokenUsage(usage.promptTokens(), usage.completionTokens()), + finishReasonFrom(response.choices().get(0).finishReason()) + ); + } +} diff --git a/src/main/java/io/github/jeddict/ai/models/LMStudioModelFetcher.java b/src/main/java/io/github/jeddict/ai/models/LMStudioModelFetcher.java new file mode 100644 index 0000000..f494d80 --- /dev/null +++ b/src/main/java/io/github/jeddict/ai/models/LMStudioModelFetcher.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package io.github.jeddict.ai.models; + +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; +import org.json.JSONArray; +import org.json.JSONObject; + +/** + * + * @author Shiwani Gupta + */ +public class LMStudioModelFetcher { + + private static final String API_URL = "http://localhost:1234/v1/models"; + + public String getAPIUrl() { + return API_URL; + } + + public List fetchModelNames(String apiUrl) { + List modelIds = new ArrayList<>(); + + try { + URL url = new URL(apiUrl); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setRequestProperty("Accept", "application/json"); + + int responseCode = connection.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream())); + StringBuilder response = new StringBuilder(); + String line; + + while ((line = reader.readLine()) != null) { + response.append(line); + } + reader.close(); + + JSONObject jsonResponse = new JSONObject(response.toString()); + JSONArray models = jsonResponse.getJSONArray("data"); + + for (int i = 0; i < models.length(); i++) { + JSONObject model = models.getJSONObject(i); + String id = model.getString("id"); + modelIds.add(id); + } + } else { + System.err.println("GET request failed. Response Code: " + responseCode); + } + + } catch (Exception e) { + e.printStackTrace(); + } + + return modelIds; + } + + public static void main(String[] args) { + LMStudioModelFetcher fetcher = new LMStudioModelFetcher(); + List ids = fetcher.fetchModelNames(fetcher.getAPIUrl()); + System.out.println("Model IDs: " + ids); + } +} \ No newline at end of file diff --git a/src/main/java/io/github/jeddict/ai/settings/OllamaModelFetcher.java b/src/main/java/io/github/jeddict/ai/models/OllamaModelFetcher.java similarity index 72% rename from src/main/java/io/github/jeddict/ai/settings/OllamaModelFetcher.java rename to src/main/java/io/github/jeddict/ai/models/OllamaModelFetcher.java index 9dc909e..5f49dda 100644 --- a/src/main/java/io/github/jeddict/ai/settings/OllamaModelFetcher.java +++ b/src/main/java/io/github/jeddict/ai/models/OllamaModelFetcher.java @@ -1,8 +1,22 @@ /* - * Click nbfs://nbhost/SystemFileSystem/Templates/Licenses/license-default.txt to change this license - * Click nbfs://nbhost/SystemFileSystem/Templates/Classes/Class.java to edit this template + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ -package io.github.jeddict.ai.settings; +package io.github.jeddict.ai.models; import java.io.BufferedReader; import java.io.InputStreamReader; diff --git a/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java b/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java index 398af6c..595c94a 100644 --- a/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java +++ b/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java @@ -18,6 +18,8 @@ */ package io.github.jeddict.ai.settings; +import io.github.jeddict.ai.models.OllamaModelFetcher; +import io.github.jeddict.ai.models.LMStudioModelFetcher; import io.github.jeddict.ai.scanner.ProjectClassScanner; import java.util.List; import java.util.stream.Collectors; @@ -300,6 +302,12 @@ private void providerComboBoxActionPerformed(java.awt.event.ActionEvent evt) {// providerLocationField.setText(fetcher.getAPIUrl()); providerLocationField.setVisible(true); apiKeyField.setVisible(false); + } else if (selectedProvider == GenAIProvider.LM_STUDIO) { + providerKeyLabel.setText("Location:"); + LMStudioModelFetcher fetcher = new LMStudioModelFetcher(); + providerLocationField.setText(fetcher.getAPIUrl()); + providerLocationField.setVisible(true); + apiKeyField.setVisible(false); } if (selectedProvider != null) { updateModelComboBox(selectedProvider); @@ -320,10 +328,11 @@ private List getModelList(GenAIProvider selectedProvider) { if (selectedProvider == GenAIProvider.OLLAMA && !providerLocationField.getText().isEmpty()) { OllamaModelFetcher fetcher = new OllamaModelFetcher(); - List names = fetcher.fetchModelNames(providerLocationField.getText()); - if (!names.isEmpty()) { - return names; - } + return fetcher.fetchModelNames(providerLocationField.getText()); + } else if (selectedProvider == GenAIProvider.LM_STUDIO + && !providerLocationField.getText().isEmpty()) { + LMStudioModelFetcher fetcher = new LMStudioModelFetcher(); + return fetcher.fetchModelNames(providerLocationField.getText()); } return MODELS.values().stream() .filter(model -> model.getProvider().equals(selectedProvider)) @@ -359,7 +368,8 @@ void load() { || selectedProvider == GenAIProvider.OPEN_AI || selectedProvider == GenAIProvider.ANTHROPIC) { apiKeyField.setText(preferencesManager.getApiKey(true)); - } else if (selectedProvider == GenAIProvider.OLLAMA) { + } else if (selectedProvider == GenAIProvider.OLLAMA + || selectedProvider == GenAIProvider.LM_STUDIO) { providerLocationField.setText(preferencesManager.getProviderLocation()); } } @@ -378,7 +388,8 @@ void store() { || selectedProvider == GenAIProvider.OPEN_AI || selectedProvider == GenAIProvider.ANTHROPIC) { preferencesManager.setApiKey(new String(apiKeyField.getPassword())); - } else if (selectedProvider == GenAIProvider.OLLAMA) { + } else if (selectedProvider == GenAIProvider.OLLAMA + || selectedProvider == GenAIProvider.LM_STUDIO) { preferencesManager.setProviderLocation(providerLocationField.getText()); } } diff --git a/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java b/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java index 7a70d0d..2123b18 100644 --- a/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java +++ b/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java @@ -26,6 +26,7 @@ public enum GenAIProvider { OPEN_AI, GOOGLE, OLLAMA, - ANTHROPIC; + ANTHROPIC, + LM_STUDIO; }