Skip to content

Commit

Permalink
LM Studio support
Browse files Browse the repository at this point in the history
  • Loading branch information
jShiwaniGupta committed Oct 17, 2024
1 parent 5fdeefb commit 6bc63f1
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 10 deletions.
7 changes: 7 additions & 0 deletions src/main/java/io/github/jeddict/ai/JeddictChatModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import static io.github.jeddict.ai.settings.GenAIProvider.ANTHROPIC;
import static io.github.jeddict.ai.settings.GenAIProvider.OLLAMA;
import io.github.jeddict.ai.models.LMStudioChatModel;
import io.github.jeddict.ai.settings.PreferencesManager;
import static io.github.jeddict.ai.util.MimeUtil.MIME_TYPE_DESCRIPTIONS;
import static io.github.jeddict.ai.util.StringUtil.removeCodeBlockMarkers;
Expand Down Expand Up @@ -66,6 +69,10 @@ public JeddictChatModel() {
.baseUrl(preferencesManager.getProviderLocation())
.modelName(preferencesManager.getModelName())
.build();
case LM_STUDIO -> model = LMStudioChatModel.builder()
.baseUrl(preferencesManager.getProviderLocation())
.modelName(preferencesManager.getModelName())
.build();
case ANTHROPIC -> model = AnthropicChatModel.builder()
.apiKey(preferencesManager.getApiKey())
.modelName(preferencesManager.getModelName())
Expand Down
194 changes: 194 additions & 0 deletions src/main/java/io/github/jeddict/ai/models/LMStudioChatModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.github.jeddict.ai.models;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import dev.langchain4j.model.chat.ChatLanguageModel;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.aiMessageFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toFunctions;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.time.Duration;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.singletonList;
import java.util.List;

public class LMStudioChatModel implements ChatLanguageModel {


public static final String LMSTUDIO_MODEL_URL = "http://localhost:1234/v1/";
private final OpenAiClient client;
private final String modelName;
private final Double temperature;
private final Double topP;
private final Integer maxTokens;
private final Integer maxRetries;

private LMStudioChatModel(String baseUrl,
String modelName,
Double temperature,
Double topP,
Integer maxTokens,
Duration timeout,
Integer maxRetries,
Boolean logRequests,
Boolean logResponses) {

temperature = temperature == null ? 0.7 : temperature;
timeout = timeout == null ? ofSeconds(60) : timeout;
maxRetries = maxRetries == null ? 3 : maxRetries;

this.client = OpenAiClient.builder()
.openAiApiKey("ignored")
.baseUrl(ensureNotBlank(baseUrl, "baseUrl"))
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout)
.logRequests(logRequests)
.logResponses(logResponses)
.build();
this.modelName = ensureNotBlank(modelName, "modelName");
this.temperature = temperature;
this.topP = topP;
this.maxTokens = maxTokens;
this.maxRetries = maxRetries;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {
private String baseUrl;
private String modelName;
private Double temperature;
private Double topP;
private Integer maxTokens;
private Duration timeout;
private Integer maxRetries;
private Boolean logRequests;
private Boolean logResponses;

public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}

public Builder modelName(String modelName) {
this.modelName = modelName;
return this;
}

public Builder temperature(Double temperature) {
this.temperature = temperature;
return this;
}

public Builder topP(Double topP) {
this.topP = topP;
return this;
}

public Builder maxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
return this;
}

public Builder timeout(Duration timeout) {
this.timeout = timeout;
return this;
}

public Builder maxRetries(Integer maxRetries) {
this.maxRetries = maxRetries;
return this;
}

public Builder logRequests(Boolean logRequests) {
this.logRequests = logRequests;
return this;
}

public Builder logResponses(Boolean logResponses) {
this.logResponses = logResponses;
return this;
}

public LMStudioChatModel build() {
return new LMStudioChatModel(baseUrl, modelName, temperature, topP, maxTokens, timeout, maxRetries, logRequests, logResponses);
}
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, null, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification);
}

private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
) {
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName)
.messages(toOpenAiMessages(messages))
.temperature(temperature)
.topP(topP)
.maxTokens(maxTokens);

if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
requestBuilder.functions(toFunctions(toolSpecifications));
}
if (toolThatMustBeExecuted != null) {
requestBuilder.functionCall(toolThatMustBeExecuted.name());
}

ChatCompletionRequest request = requestBuilder.build();

ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);

Usage usage = response.usage();

return Response.from(
aiMessageFrom(response),
new TokenUsage(usage.promptTokens(), usage.completionTokens()),
finishReasonFrom(response.choices().get(0).finishReason())
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.github.jeddict.ai.models;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import org.json.JSONArray;
import org.json.JSONObject;

/**
*
* @author Shiwani Gupta
*/
public class LMStudioModelFetcher {

private static final String API_URL = "http://localhost:1234/v1/models";

public String getAPIUrl() {
return API_URL;
}

public List<String> fetchModelNames(String apiUrl) {
List<String> modelIds = new ArrayList<>();

try {
URL url = new URL(apiUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("GET");
connection.setRequestProperty("Accept", "application/json");

int responseCode = connection.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream()));
StringBuilder response = new StringBuilder();
String line;

while ((line = reader.readLine()) != null) {
response.append(line);
}
reader.close();

JSONObject jsonResponse = new JSONObject(response.toString());
JSONArray models = jsonResponse.getJSONArray("data");

for (int i = 0; i < models.length(); i++) {
JSONObject model = models.getJSONObject(i);
String id = model.getString("id");
modelIds.add(id);
}
} else {
System.err.println("GET request failed. Response Code: " + responseCode);
}

} catch (Exception e) {
e.printStackTrace();
}

return modelIds;
}

public static void main(String[] args) {
LMStudioModelFetcher fetcher = new LMStudioModelFetcher();
List<String> ids = fetcher.fetchModelNames(fetcher.getAPIUrl());
System.out.println("Model IDs: " + ids);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
/*
* Click nbfs://nbhost/SystemFileSystem/Templates/Licenses/license-default.txt to change this license
* Click nbfs://nbhost/SystemFileSystem/Templates/Classes/Class.java to edit this template
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.github.jeddict.ai.settings;
package io.github.jeddict.ai.models;

import java.io.BufferedReader;
import java.io.InputStreamReader;
Expand Down
23 changes: 17 additions & 6 deletions src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package io.github.jeddict.ai.settings;

import io.github.jeddict.ai.models.OllamaModelFetcher;
import io.github.jeddict.ai.models.LMStudioModelFetcher;
import io.github.jeddict.ai.scanner.ProjectClassScanner;
import java.util.List;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -300,6 +302,12 @@ private void providerComboBoxActionPerformed(java.awt.event.ActionEvent evt) {//
providerLocationField.setText(fetcher.getAPIUrl());
providerLocationField.setVisible(true);
apiKeyField.setVisible(false);
} else if (selectedProvider == GenAIProvider.LM_STUDIO) {
providerKeyLabel.setText("Location:");
LMStudioModelFetcher fetcher = new LMStudioModelFetcher();
providerLocationField.setText(fetcher.getAPIUrl());
providerLocationField.setVisible(true);
apiKeyField.setVisible(false);
}
if (selectedProvider != null) {
updateModelComboBox(selectedProvider);
Expand All @@ -320,10 +328,11 @@ private List<String> getModelList(GenAIProvider selectedProvider) {
if (selectedProvider == GenAIProvider.OLLAMA
&& !providerLocationField.getText().isEmpty()) {
OllamaModelFetcher fetcher = new OllamaModelFetcher();
List<String> names = fetcher.fetchModelNames(providerLocationField.getText());
if (!names.isEmpty()) {
return names;
}
return fetcher.fetchModelNames(providerLocationField.getText());
} else if (selectedProvider == GenAIProvider.LM_STUDIO
&& !providerLocationField.getText().isEmpty()) {
LMStudioModelFetcher fetcher = new LMStudioModelFetcher();
return fetcher.fetchModelNames(providerLocationField.getText());
}
return MODELS.values().stream()
.filter(model -> model.getProvider().equals(selectedProvider))
Expand Down Expand Up @@ -359,7 +368,8 @@ void load() {
|| selectedProvider == GenAIProvider.OPEN_AI
|| selectedProvider == GenAIProvider.ANTHROPIC) {
apiKeyField.setText(preferencesManager.getApiKey(true));
} else if (selectedProvider == GenAIProvider.OLLAMA) {
} else if (selectedProvider == GenAIProvider.OLLAMA
|| selectedProvider == GenAIProvider.LM_STUDIO) {
providerLocationField.setText(preferencesManager.getProviderLocation());
}
}
Expand All @@ -378,7 +388,8 @@ void store() {
|| selectedProvider == GenAIProvider.OPEN_AI
|| selectedProvider == GenAIProvider.ANTHROPIC) {
preferencesManager.setApiKey(new String(apiKeyField.getPassword()));
} else if (selectedProvider == GenAIProvider.OLLAMA) {
} else if (selectedProvider == GenAIProvider.OLLAMA
|| selectedProvider == GenAIProvider.LM_STUDIO) {
preferencesManager.setProviderLocation(providerLocationField.getText());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public enum GenAIProvider {
OPEN_AI,
GOOGLE,
OLLAMA,
ANTHROPIC;
ANTHROPIC,
LM_STUDIO;

}

0 comments on commit 6bc63f1

Please sign in to comment.