Skip to content

Commit

Permalink
feat: add model configuration to config file
Browse files Browse the repository at this point in the history
  • Loading branch information
arcuru committed Mar 21, 2024
1 parent 07fc18c commit 65bb090
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use matrix_sdk::{
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
use regex::Regex;
use serde::Deserialize;
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
use std::sync::Mutex;
use std::{collections::HashMap, fs::File};
use tokio::time::{sleep, Duration};

#[derive(Parser)]
Expand All @@ -33,6 +33,19 @@ struct Config {
password: String,
/// Allow list of which accounts we will respond to
allow_list: Option<String>,
ollama: Option<HashMap<String, OllamaConfig>>,
}

#[derive(Debug, Deserialize, Clone)]
struct OllamaConfig {
model: String,
endpoint: Option<EndpointConfig>,
}

#[derive(Debug, Deserialize, Clone)]
struct EndpointConfig {
host: String,
port: u16,
}

lazy_static! {
Expand Down Expand Up @@ -299,12 +312,29 @@ async fn get_context(room: &Room) -> Result<String, ()> {

// Send the current conversation to the configured ollama server
async fn send_to_ollama_server(input: String) -> Result<String, ()> {
let ollama = Ollama::default();
let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap();
if config.ollama.is_none() {
return Err(());
}
let ollama = config.ollama.unwrap();
if ollama.is_empty() {
return Err(());
}

let server = ollama.values().next().unwrap();

// Just pull the first thing we see
let ollama_server = if let Some(endpoint) = &server.endpoint {
Ollama::new(endpoint.host.clone(), endpoint.port)
} else {
Ollama::default()
};

let model = "llama2:latest".to_string();
let prompt = input;

let res = ollama.generate(GenerationRequest::new(model, prompt)).await;
let res = ollama_server
.generate(GenerationRequest::new(server.model.clone(), prompt))
.await;

if let Ok(res) = res {
// Strip leading and trailing whitespace from res.response
Expand Down

0 comments on commit 65bb090

Please sign in to comment.