diff --git a/src/main.rs b/src/main.rs index 6fde96d..e24ad69 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,10 +12,10 @@ use matrix_sdk::{ use ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; use regex::Regex; use serde::Deserialize; -use std::fs::File; use std::io::Read; use std::path::PathBuf; use std::sync::Mutex; +use std::{collections::HashMap, fs::File}; use tokio::time::{sleep, Duration}; #[derive(Parser)] @@ -33,6 +33,19 @@ struct Config { password: String, /// Allow list of which accounts we will respond to allow_list: Option, + ollama: Option>, +} + +#[derive(Debug, Deserialize, Clone)] +struct OllamaConfig { + model: String, + endpoint: Option, +} + +#[derive(Debug, Deserialize, Clone)] +struct EndpointConfig { + host: String, + port: u16, } lazy_static! { @@ -299,12 +312,29 @@ async fn get_context(room: &Room) -> Result { // Send the current conversation to the configured ollama server async fn send_to_ollama_server(input: String) -> Result { - let ollama = Ollama::default(); + let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap(); + if config.ollama.is_none() { + return Err(()); + } + let ollama = config.ollama.unwrap(); + if ollama.is_empty() { + return Err(()); + } + + let server = ollama.values().next().unwrap(); + + // Just pull the first thing we see + let ollama_server = if let Some(endpoint) = &server.endpoint { + Ollama::new(endpoint.host.clone(), endpoint.port) + } else { + Ollama::default() + }; - let model = "llama2:latest".to_string(); let prompt = input; - let res = ollama.generate(GenerationRequest::new(model, prompt)).await; + let res = ollama_server + .generate(GenerationRequest::new(server.model.clone(), prompt)) + .await; if let Ok(res) = res { // Strip leading and trailing whitespace from res.response