Skip to content

Commit

Permalink
Merge pull request #16 from raff/main
Browse files Browse the repository at this point in the history
Add support for Ollama (via openai compatible APIs)
  • Loading branch information
baalimago authored Aug 19, 2024
2 parents 927aa81 + a2250e9 commit c61dd3f
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 1 deletion.
15 changes: 14 additions & 1 deletion internal/create_queriers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/baalimago/clai/internal/text"
"github.com/baalimago/clai/internal/vendors/anthropic"
"github.com/baalimago/clai/internal/vendors/mistral"
"github.com/baalimago/clai/internal/vendors/ollama"
"github.com/baalimago/clai/internal/vendors/openai"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/misc"
Expand Down Expand Up @@ -49,7 +50,19 @@ func CreateTextQuerier(conf text.Configurations) (models.Querier, error) {
q = &qTmp
}

if strings.Contains(conf.Model, "mistral") || strings.Contains(conf.Model, "mixtral") {
// process before mistral, in case we want to use mistral for ollama
if strings.HasPrefix(conf.Model, "ollama:") || conf.Model == "ollama" {
found = true
defaultCpy := ollama.OLLAMA_DEFAULT
if len(conf.Model) > 7 {
defaultCpy.Model = conf.Model[7:]
}
qTmp, err := text.NewQuerier(conf, &defaultCpy)
if err != nil {
return nil, fmt.Errorf("failed to create text querier: %w", err)
}
q = &qTmp
} else if strings.Contains(conf.Model, "mistral") || strings.Contains(conf.Model, "mixtral") {
found = true
defaultCpy := mistral.MINSTRAL_DEFAULT
defaultCpy.Model = conf.Model
Expand Down
7 changes: 7 additions & 0 deletions internal/text/querier_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ func vendorType(fromModel string) (string, string, string) {
if strings.Contains(fromModel, "claude") {
return "anthropic", "claude", fromModel
}
if strings.Contains(fromModel, "ollama") {
m := "llama3"
if strings.HasPrefix(fromModel, "ollama:") {
m = fromModel[7:]
}
return "ollama", m, fromModel
}
if strings.Contains(fromModel, "mistral") || strings.Contains(fromModel, "mixtral") {
return "mistral", "mistral", fromModel
}
Expand Down
56 changes: 56 additions & 0 deletions internal/vendors/ollama/models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package ollama

import "github.com/baalimago/clai/internal/tools"

// ChatCompletion models the response payload of an OpenAI-compatible
// /chat/completions call. Ollama exposes the same wire format in its
// OpenAI-compatible mode, so these types mirror the `openai` package.
type ChatCompletion struct {
	ID                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage"`
	SystemFingerprint string   `json:"system_fingerprint"`
}

// Choice is a single completion alternative within a ChatCompletion
// response. For streamed responses the incremental content arrives in Delta.
type Choice struct {
	Index        int         `json:"index"`
	Delta        Delta       `json:"delta"`
	Logprobs     interface{} `json:"logprobs"` // null or complex object, hence interface{}
	FinishReason string      `json:"finish_reason"`
}

// Usage reports token accounting for a completion: tokens consumed by the
// prompt, tokens generated in the completion, and their sum.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Delta carries one incremental chunk of a streamed chat completion.
type Delta struct {
	// Content is typed any because the field may be a string or null in
	// stream chunks — NOTE(review): confirm against the generic decoder.
	Content   any         `json:"content"`
	Role      string      `json:"role"`
	ToolCalls []ToolsCall `json:"tool_calls"`
}

// ToolsCall describes a tool invocation requested by the model inside a
// streamed Delta, in the OpenAI tool_calls format.
type ToolsCall struct {
	Function GptFunc `json:"function"`
	ID       string  `json:"id"`
	Index    int     `json:"index"`
	Type     string  `json:"type"`
}

// GptFunc is the function portion of a tool call: the function name and its
// arguments. Arguments is a string — presumably JSON-encoded by the model,
// as in the OpenAI format; verify where it is decoded.
type GptFunc struct {
	Arguments string `json:"arguments"`
	Name      string `json:"name"`
}

// GptTool declares a tool offered to the model: its name, a human-readable
// description, and the JSON schema of its input parameters.
type GptTool struct {
	Name        string            `json:"name"`
	Description string            `json:"description"`
	Inputs      tools.InputSchema `json:"parameters"`
}

// GptToolSuper wraps a GptTool with its type discriminator, matching the
// OpenAI tools array element shape ({"type": ..., "function": {...}}).
type GptToolSuper struct {
	Type     string  `json:"type"`
	Function GptTool `json:"function"`
}
23 changes: 23 additions & 0 deletions internal/vendors/ollama/ollama.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package ollama

import (
"github.com/baalimago/clai/internal/text/generic"
)

// OLLAMA_DEFAULT is the default Ollama querier configuration: the llama3
// model with neutral sampling (temperature and top_p of 1.0) against a
// locally running Ollama server. The ALL_CAPS name follows the convention of
// the other vendor defaults in this project (e.g. mistral.MINSTRAL_DEFAULT)
// rather than standard Go naming.
var OLLAMA_DEFAULT = Ollama{
	Model:       "llama3",
	Temperature: 1.0,
	TopP:        1.0,
	Url:         ChatURL,
}

// Ollama is a text querier for a local Ollama server, driven through the
// generic OpenAI-compatible StreamCompleter. The JSON tags allow the
// configuration to be (de)serialized from user config files.
type Ollama struct {
	generic.StreamCompleter
	Model            string  `json:"model"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
	PresencePenalty  float64 `json:"presence_penalty"`
	Temperature      float64 `json:"temperature"`
	TopP             float64 `json:"top_p"`
	Url              string  `json:"url"`
}
32 changes: 32 additions & 0 deletions internal/vendors/ollama/ollama_setup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package ollama

import (
"fmt"
"os"

"github.com/baalimago/clai/internal/tools"
)

// ChatURL is the OpenAI-compatible chat-completions endpoint of an Ollama
// server listening on its default local port (11434).
const ChatURL = "http://localhost:11434/v1/chat/completions"

// Setup prepares the embedded StreamCompleter for querying an Ollama server
// through its OpenAI-compatible API and copies the configured sampling
// parameters onto it. It must be called before the querier is used.
//
// Returns an error if the environment cannot be prepared or the underlying
// stream completer fails to set up.
func (g *Ollama) Setup() error {
	// Ollama itself requires no authentication, but the generic OpenAI-style
	// completer expects an API key environment variable — set a placeholder.
	if os.Getenv("OLLAMA_API_KEY") == "" {
		if err := os.Setenv("OLLAMA_API_KEY", "ollama"); err != nil {
			return fmt.Errorf("failed to set placeholder OLLAMA_API_KEY: %w", err)
		}
	}
	// Honor a user-configured endpoint. Previously the hard-coded ChatURL was
	// always passed, silently ignoring the configurable Url field.
	url := g.Url
	if url == "" {
		url = ChatURL
	}
	if err := g.StreamCompleter.Setup("OLLAMA_API_KEY", url, "OLLAMA_DEBUG"); err != nil {
		return fmt.Errorf("failed to setup stream completer: %w", err)
	}
	g.StreamCompleter.Model = g.Model
	g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty
	g.StreamCompleter.MaxTokens = g.MaxTokens
	g.StreamCompleter.Temperature = &g.Temperature
	g.StreamCompleter.TopP = &g.TopP
	// Let the model decide when to invoke registered tools.
	toolChoice := "auto"
	g.StreamCompleter.ToolChoice = &toolChoice
	return nil
}

// RegisterTool registers an AI tool with the embedded StreamCompleter so it
// can be offered to the model during chat completions.
func (g *Ollama) RegisterTool(tool tools.AiTool) {
	g.StreamCompleter.InternalRegisterTool(tool)
}

0 comments on commit c61dd3f

Please sign in to comment.