From 0496d6741059bd445c87259e8fe324edbf96c7c1 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Fri, 18 Oct 2024 01:41:26 +0530 Subject: [PATCH] Groq support --- .../github/jeddict/ai/JeddictChatModel.java | 5 ++ .../jeddict/ai/models/GroqModelFetcher.java | 80 +++++++++++++++++++ .../ai/settings/AIAssistancePanel.form | 3 + .../ai/settings/AIAssistancePanel.java | 26 +++++- .../jeddict/ai/settings/GenAIProvider.java | 1 + 5 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 src/main/java/io/github/jeddict/ai/models/GroqModelFetcher.java diff --git a/src/main/java/io/github/jeddict/ai/JeddictChatModel.java b/src/main/java/io/github/jeddict/ai/JeddictChatModel.java index 0e9b1f6..1cffe77 100644 --- a/src/main/java/io/github/jeddict/ai/JeddictChatModel.java +++ b/src/main/java/io/github/jeddict/ai/JeddictChatModel.java @@ -75,6 +75,11 @@ public JeddictChatModel() { .apiKey(preferencesManager.getApiKey()) .modelName(preferencesManager.getModelName()) .build(); + case GROQ -> model = OpenAiChatModel.builder() + .baseUrl(preferencesManager.getProviderLocation()) + .apiKey(preferencesManager.getApiKey()) + .modelName(preferencesManager.getModelName()) + .build(); case CUSTOM_OPEN_AI -> model = OpenAiChatModel.builder() .baseUrl(preferencesManager.getProviderLocation()) .apiKey(preferencesManager.getApiKey()) diff --git a/src/main/java/io/github/jeddict/ai/models/GroqModelFetcher.java b/src/main/java/io/github/jeddict/ai/models/GroqModelFetcher.java new file mode 100644 index 0000000..6979fb7 --- /dev/null +++ b/src/main/java/io/github/jeddict/ai/models/GroqModelFetcher.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package io.github.jeddict.ai.models; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.json.JSONArray; +import org.json.JSONObject; +import org.openide.util.Exceptions; + +/** + * + * @author Shiwani Gupta + */ +public class GroqModelFetcher { + + public static final String API_URL = "https://api.groq.com/openai/v1"; + + public String getAPIUrl() { + return API_URL; + } + + public static List fetchModels(String baseUrl, String token) { + if (baseUrl == null || baseUrl.isEmpty() || token == null || token.isEmpty()) { + return Collections.EMPTY_LIST; + } + try { + HttpClient client = HttpClient.newHttpClient(); + + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + "/models")) + .header("Authorization", "Bearer " + token) + .header("Content-Type", "application/json") + .GET() + .build(); + + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + if (response.statusCode() == 200) { + JSONObject jsonResponse = new JSONObject(response.body()); + JSONArray modelsArray = jsonResponse.getJSONArray("data"); // Fetch data array + + List modelList = new ArrayList<>(); + for (int i = 0; i < modelsArray.length(); i++) { + JSONObject modelObject = modelsArray.getJSONObject(i); + String modelId = modelObject.getString("id"); // Get the "id" of the model + modelList.add(modelId); + } + return modelList; + } + } catch (IOException ex) { + Exceptions.printStackTrace(ex); + } catch (InterruptedException ex) { + Exceptions.printStackTrace(ex); + } + return Collections.EMPTY_LIST; + } +} diff --git a/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.form b/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.form index 1409dae..b08ba24 100644 --- a/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.form +++ b/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.form @@ -218,6 +218,9 @@ + + + diff --git a/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java b/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java index 0932f53..78f5e76 100644 --- a/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java +++ b/src/main/java/io/github/jeddict/ai/settings/AIAssistancePanel.java @@ -20,6 +20,7 @@ import static io.github.jeddict.ai.models.Constant.DEEPINFRA_URL; import io.github.jeddict.ai.models.GPT4AllModelFetcher; +import io.github.jeddict.ai.models.GroqModelFetcher; import io.github.jeddict.ai.models.OllamaModelFetcher; import io.github.jeddict.ai.models.LMStudioModelFetcher; import io.github.jeddict.ai.scanner.ProjectClassScanner; @@ -157,6 +158,11 @@ public void actionPerformed(java.awt.event.ActionEvent evt) { apiKeyPane.add(jLayeredPane2); apiKeyField.setText(org.openide.util.NbBundle.getMessage(AIAssistancePanel.class, "AIAssistancePanel.apiKeyField.text")); // NOI18N + apiKeyField.addFocusListener(new java.awt.event.FocusAdapter() { + public void focusLost(java.awt.event.FocusEvent evt) { + apiKeyFieldFocusLost(evt); + } + }); apiKeyPane.add(apiKeyField); providerPane.add(apiKeyPane); @@ -304,6 +310,7 @@ private void showDescriptionCheckBoxActionPerformed(java.awt.event.ActionEvent e private void providerComboBoxActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_providerComboBoxActionPerformed GenAIProvider selectedProvider = (GenAIProvider) providerComboBox.getSelectedItem(); if (selectedProvider == GenAIProvider.DEEPINFRA + || selectedProvider == GenAIProvider.GROQ || selectedProvider == GenAIProvider.CUSTOM_OPEN_AI) { apiKeyLabel.setVisible(true); apiKeyField.setVisible(true); @@ -311,6 +318,8 @@ private void providerComboBoxActionPerformed(java.awt.event.ActionEvent evt) {// providerLocationLabel.setVisible(true); if (selectedProvider == GenAIProvider.DEEPINFRA) { providerLocationField.setText(DEEPINFRA_URL); + } else if (selectedProvider == GenAIProvider.GROQ) { + providerLocationField.setText(new GroqModelFetcher().getAPIUrl()); } else { providerLocationField.setText(""); } @@ -404,6 +413,13 @@ public void mouseExited(MouseEvent e) { } }//GEN-LAST:event_providerComboBoxActionPerformed + private void apiKeyFieldFocusLost(java.awt.event.FocusEvent evt) {//GEN-FIRST:event_apiKeyFieldFocusLost + GenAIProvider selectedProvider = (GenAIProvider) providerComboBox.getSelectedItem(); + if (selectedProvider != null) { + updateModelComboBox(selectedProvider); + } + }//GEN-LAST:event_apiKeyFieldFocusLost + private void updateModelComboBox(GenAIProvider selectedProvider) { modelComboBox.removeAllItems(); for (String model : getModelList(selectedProvider)) { @@ -427,6 +443,10 @@ private List getModelList(GenAIProvider selectedProvider) { && !providerLocationField.getText().isEmpty()) { GPT4AllModelFetcher fetcher = new GPT4AllModelFetcher(); return fetcher.fetchModelNames(providerLocationField.getText()); + }else if (selectedProvider == GenAIProvider.GROQ + && !providerLocationField.getText().isEmpty()) { + GroqModelFetcher fetcher = new GroqModelFetcher(); + return fetcher.fetchModels(providerLocationField.getText(), new String(apiKeyField.getPassword())); } return MODELS.values().stream() .filter(model -> model.getProvider().equals(selectedProvider)) @@ -459,7 +479,8 @@ void load() { showDescriptionCheckBox.setSelected(preferencesManager.isDescriptionEnabled()); GenAIProvider selectedProvider = (GenAIProvider) providerComboBox.getSelectedItem(); if (selectedProvider == GenAIProvider.CUSTOM_OPEN_AI - || selectedProvider == GenAIProvider.DEEPINFRA) { + || selectedProvider == GenAIProvider.DEEPINFRA + || selectedProvider == GenAIProvider.GROQ) { apiKeyField.setText(preferencesManager.getApiKey(true)); providerLocationField.setText(preferencesManager.getProviderLocation()); } else if (selectedProvider == GenAIProvider.GOOGLE @@ -485,7 +506,8 @@ void store() { GenAIProvider selectedProvider = (GenAIProvider) providerComboBox.getSelectedItem(); if (selectedProvider == GenAIProvider.CUSTOM_OPEN_AI - || selectedProvider == GenAIProvider.DEEPINFRA) { + || selectedProvider == GenAIProvider.DEEPINFRA + || selectedProvider == GenAIProvider.GROQ) { preferencesManager.setApiKey(new String(apiKeyField.getPassword())); preferencesManager.setProviderLocation(providerLocationField.getText()); } else if (selectedProvider == GenAIProvider.GOOGLE diff --git a/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java b/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java index cf6d977..ee0859a 100644 --- a/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java +++ b/src/main/java/io/github/jeddict/ai/settings/GenAIProvider.java @@ -10,6 +10,7 @@ public enum GenAIProvider { CUSTOM_OPEN_AI("", ""), GOOGLE("https://ai.google.dev/gemini-api/docs/models/gemini", "https://console.cloud.google.com/apis/credentials"), DEEPINFRA("https://deepinfra.com/models", "https://deepinfra.com/dash/api_keys"), + GROQ("https://console.groq.com/docs/models", "https://console.groq.com/keys"), MISTRAL("https://docs.mistral.ai/getting-started/models/models_overview/", "https://console.mistral.ai/api-keys/"), OLLAMA("https://ollama.com/models", ""), ANTHROPIC("https://docs.anthropic.com/en/docs/about-claude/models", "https://console.anthropic.com/settings/keys"),