Skip to content

Commit

Permalink
Merge pull request #20 from jeddict/dev
Browse files Browse the repository at this point in the history
DeepInfra support
  • Loading branch information
jShiwaniGupta authored Oct 17, 2024
2 parents c5dd8be + f0a73c0 commit a480aa4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
17 changes: 15 additions & 2 deletions src/main/java/io/github/jeddict/ai/JeddictChatModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import static io.github.jeddict.ai.settings.GenAIProvider.OLLAMA;
import io.github.jeddict.ai.models.LMStudioChatModel;
import static io.github.jeddict.ai.settings.GenAIProvider.LM_STUDIO;
import static io.github.jeddict.ai.settings.GenAIProvider.OPEN_AI;
import io.github.jeddict.ai.settings.PreferencesManager;
import static io.github.jeddict.ai.util.MimeUtil.MIME_TYPE_DESCRIPTIONS;
import static io.github.jeddict.ai.util.StringUtil.removeCodeBlockMarkers;
Expand Down Expand Up @@ -68,6 +69,11 @@ public JeddictChatModel() {
.apiKey(preferencesManager.getApiKey())
.modelName(preferencesManager.getModelName())
.build();
case DEEPINFRA -> model = OpenAiChatModel.builder()
.baseUrl("https://api.deepinfra.com/v1/openai")
.apiKey(preferencesManager.getApiKey())
.modelName(preferencesManager.getModelName())
.build();
case MISTRAL -> model = MistralAiChatModel.builder()
.apiKey(preferencesManager.getApiKey())
.modelName(preferencesManager.getModelName())
Expand Down Expand Up @@ -564,8 +570,15 @@ public List<Snippet> parseJsonToSnippets(String jsonResponse) {
}
List<Snippet> snippets = new ArrayList<>();

// Parse the JSON response
JSONArray jsonArray = new JSONArray(removeCodeBlockMarkers(jsonResponse));
JSONArray jsonArray;
try {
// Parse the JSON response
jsonArray = new JSONArray(removeCodeBlockMarkers(jsonResponse));
} catch (org.json.JSONException jsone) {
JSONObject jsonObject = new JSONObject(removeCodeBlockMarkers(jsonResponse));
jsonArray = new JSONArray();
jsonArray.put(jsonObject);
}

// Loop through each element in the JSON array
for (int i = 0; i < jsonArray.length(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ private void providerComboBoxActionPerformed(java.awt.event.ActionEvent evt) {//
GenAIProvider selectedProvider = (GenAIProvider) providerComboBox.getSelectedItem();
if (selectedProvider == GenAIProvider.GOOGLE
|| selectedProvider == GenAIProvider.OPEN_AI
|| selectedProvider == GenAIProvider.DEEPINFRA
|| selectedProvider == GenAIProvider.MISTRAL
|| selectedProvider == GenAIProvider.ANTHROPIC) {
providerKeyLabel.setText("API Key:");
Expand Down Expand Up @@ -378,6 +379,7 @@ void load() {
GenAIProvider selectedProvider = (GenAIProvider) providerComboBox.getSelectedItem();
if (selectedProvider == GenAIProvider.GOOGLE
|| selectedProvider == GenAIProvider.OPEN_AI
|| selectedProvider == GenAIProvider.DEEPINFRA
|| selectedProvider == GenAIProvider.MISTRAL
|| selectedProvider == GenAIProvider.ANTHROPIC) {
apiKeyField.setText(preferencesManager.getApiKey(true));
Expand All @@ -400,6 +402,7 @@ void store() {
GenAIProvider selectedProvider = (GenAIProvider) providerComboBox.getSelectedItem();
if (selectedProvider == GenAIProvider.GOOGLE
|| selectedProvider == GenAIProvider.OPEN_AI
|| selectedProvider == GenAIProvider.DEEPINFRA
|| selectedProvider == GenAIProvider.MISTRAL
|| selectedProvider == GenAIProvider.ANTHROPIC) {
preferencesManager.setApiKey(new String(apiKeyField.getPassword()));
Expand Down
8 changes: 7 additions & 1 deletion src/main/java/io/github/jeddict/ai/settings/GenAIModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
package io.github.jeddict.ai.settings;

import static io.github.jeddict.ai.settings.GenAIProvider.ANTHROPIC;
import static io.github.jeddict.ai.settings.GenAIProvider.DEEPINFRA;
import static io.github.jeddict.ai.settings.GenAIProvider.GOOGLE;
import static io.github.jeddict.ai.settings.GenAIProvider.MISTRAL;
import static io.github.jeddict.ai.settings.GenAIProvider.OLLAMA;
import static io.github.jeddict.ai.settings.GenAIProvider.OPEN_AI;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -66,6 +66,12 @@ public class GenAIModel {
MODELS.put("mistral-embed", new GenAIModel(MISTRAL, "mistral-embed", "State-of-the-art semantic model for extracting text representations.", 0.10, 0.00)); // No output price provided
MODELS.put("ministral-3b-latest", new GenAIModel(MISTRAL, "ministral-3b-latest", "Most efficient edge model.", 0.04, 0.04));
MODELS.put("ministral-8b-latest", new GenAIModel(MISTRAL, "ministral-8b-latest", "Powerful model for on-device use cases.", 0.10, 0.10));

MODELS.put("meta-llama/Llama-3.2-3B-Instruct", new GenAIModel(DEEPINFRA, "meta-llama/Llama-3.2-3B-Instruct", "A 3B instruct model by Meta for instructional tasks.", 0.15, 0.45));
MODELS.put("Qwen/Qwen2.5-72B-Instruct", new GenAIModel(DEEPINFRA, "Qwen/Qwen2.5-72B-Instruct", "A large instruct model for various applications.", 0.20, 0.50));
MODELS.put("google/gemma-2-9b-it", new GenAIModel(DEEPINFRA, "google/gemma-2-9b-it", "Gemini model specialized for IT tasks, with a focus on performance.", 0.10, 0.30));
MODELS.put("microsoft/WizardLM-2-8x22B", new GenAIModel(DEEPINFRA, "microsoft/WizardLM-2-8x22B", "An 8x22B model designed for advanced conversational applications.", 0.25, 0.75));
MODELS.put("mistralai/Mistral-7B-Instruct-v0.3", new GenAIModel(DEEPINFRA, "mistralai/Mistral-7B-Instruct-v0.3", "A 7B instruct model optimized for general tasks.", 0.15, 0.45));
}

private final GenAIProvider provider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
public enum GenAIProvider {
OPEN_AI,
GOOGLE,
DEEPINFRA,
MISTRAL,
OLLAMA,
ANTHROPIC,
Expand Down

0 comments on commit a480aa4

Please sign in to comment.