// Shared plumbing for the LLM provider clients: imports, the provider/model
// catalogue statics, the `Client` trait body, and the request/patch types.
//
// - `MODELS_YAML` embeds `../../models.yaml` at compile time.
//   `ALL_PROVIDER_MODELS` uses `Config::local_models_override()` when it returns
//   `Ok`; otherwise it parses the embedded YAML (the `unwrap` panics if that
//   embedded YAML is malformed).
// - `EMBEDDING_MODEL_RE` is a name heuristic for embedding models. It matches
//   the prefixes `bge-`, `e5-`, `uae-`, `gte-`, `text-` at the start of the name
//   or after a `/`, and any name containing `embed`, `multilingual` or `minilm`.
// - `ESCAPE_SLASH_RE` is used by `patch_request_data`: its matches in a patch
//   key are replaced with `\/` before the key is compiled as an anchored regex.
//   NOTE(review): in this view the `ESCAPE_SLASH_RE` pattern literal and the
//   `Client` trait header (plus its first method signature, presumably
//   `fn global_config(&self) -> &GlobalConfig`) are not legible, and generic
//   arguments appear stripped (`Result>`, `LazyLock>`, `Option,`). Verify
//   against the real file before editing these lines.
// - `Client::build_client` applies the optional proxy and the configured user
//   agent. It sets a connect timeout of `connect_timeout` seconds (10 if unset).
// - `chat_completions` / `chat_completions_streaming` short-circuit in dry-run
//   mode by echoing the input messages. The streaming variant races the request
//   against the abort signal and calls `handler.done()` on both branches.
// - `embeddings_inner` / `rerank_inner` default to an error ("The client
//   doesn't support ... api") for clients that do not override them.
// - `patch_request_data` always applies the model's own patch first. It then
//   looks up a patch map from the env var named by
//   `get_env_name("patch_<client>_<api>")` (parsed as JSON). If that is unset or
//   invalid, it uses `patch_config()` instead. Only the first map key whose
//   anchored regex `^(key)$` matches the model name is applied; keys that fail
//   to compile as regexes are skipped silently.
// - `RequestData::apply_patch` replaces `url` and JSON-merges `body` (via
//   `json_patch::merge`). For headers, string values are set, `null` values
//   remove the header, and any other value type is ignored.
use super::*; use crate::{ config::{Config, GlobalConfig, Input}, function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls}, render::render_stream, utils::*, }; use crate::vault::Vault; use anyhow::{Context, Result, bail}; use fancy_regex::Regex; use indexmap::IndexMap; use inquire::{ MultiSelect, Select, Text, list_option::ListOption, required, validator::Validation, }; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{Value, json}; use std::sync::LazyLock; use std::time::Duration; use tokio::sync::mpsc::unbounded_channel; pub const MODELS_YAML: &str = include_str!("../../models.yaml"); pub static ALL_PROVIDER_MODELS: LazyLock> = LazyLock::new(|| { Config::local_models_override() .ok() .unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap()) }); static EMBEDDING_MODEL_RE: LazyLock = LazyLock::new(|| { Regex::new(r"((^|/)(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap() }); static ESCAPE_SLASH_RE: LazyLock = LazyLock::new(|| Regex::new(r"(? 
&GlobalConfig; fn extra_config(&self) -> Option<&ExtraConfig>; fn patch_config(&self) -> Option<&RequestPatch>; fn name(&self) -> &str; fn model(&self) -> &Model; fn build_client(&self) -> Result { let mut builder = ReqwestClient::builder(); let extra = self.extra_config(); let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10); if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) { builder = set_proxy(builder, proxy)?; } if let Some(user_agent) = self.global_config().read().user_agent.as_ref() { builder = builder.user_agent(user_agent); } let client = builder .connect_timeout(Duration::from_secs(timeout)) .build() .with_context(|| "Failed to build client")?; Ok(client) } async fn chat_completions(&self, input: Input) -> Result { if self.global_config().read().dry_run { let content = input.echo_messages(); return Ok(ChatCompletionsOutput::new(&content)); } let client = self.build_client()?; let data = input.prepare_completion_data(self.model(), false)?; self.chat_completions_inner(&client, data) .await .with_context(|| "Failed to call chat-completions api") } async fn chat_completions_streaming( &self, input: &Input, handler: &mut SseHandler, ) -> Result<()> { let abort_signal = handler.abort(); let input = input.clone(); tokio::select! 
{ ret = async { if self.global_config().read().dry_run { let content = input.echo_messages(); handler.text(&content)?; return Ok(()); } let client = self.build_client()?; let data = input.prepare_completion_data(self.model(), true)?; self.chat_completions_streaming_inner(&client, handler, data).await } => { handler.done(); ret.with_context(|| "Failed to call chat-completions api") } _ = wait_abort_signal(&abort_signal) => { handler.done(); Ok(()) }, } } async fn embeddings(&self, data: &EmbeddingsData) -> Result>> { let client = self.build_client()?; self.embeddings_inner(&client, data) .await .context("Failed to call embeddings api") } async fn rerank(&self, data: &RerankData) -> Result { let client = self.build_client()?; self.rerank_inner(&client, data) .await .context("Failed to call rerank api") } async fn chat_completions_inner( &self, client: &ReqwestClient, data: ChatCompletionsData, ) -> Result; async fn chat_completions_streaming_inner( &self, client: &ReqwestClient, handler: &mut SseHandler, data: ChatCompletionsData, ) -> Result<()>; async fn embeddings_inner( &self, _client: &ReqwestClient, _data: &EmbeddingsData, ) -> Result { bail!("The client doesn't support embeddings api") } async fn rerank_inner( &self, _client: &ReqwestClient, _data: &RerankData, ) -> Result { bail!("The client doesn't support rerank api") } fn request_builder( &self, client: &reqwest::Client, mut request_data: RequestData, ) -> RequestBuilder { self.patch_request_data(&mut request_data); request_data.into_builder(client) } fn patch_request_data(&self, request_data: &mut RequestData) { let model_type = self.model().model_type(); if let Some(patch) = self.model().patch() { request_data.apply_patch(patch.clone()); } let patch_map = std::env::var(get_env_name(&format!( "patch_{}_{}", self.model().client_name(), model_type.api_name(), ))) .ok() .and_then(|v| serde_json::from_str(&v).ok()) .or_else(|| { self.patch_config() .and_then(|v| model_type.extract_patch(v)) .cloned() }); let 
patch_map = match patch_map { Some(v) => v, _ => return, }; for (key, patch) in patch_map { let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/"); if let Ok(regex) = Regex::new(&format!("^({key})$")) && let Ok(true) = regex.is_match(self.model().name()) { request_data.apply_patch(patch); return; } } } } impl Default for ClientConfig { fn default() -> Self { Self::OpenAIConfig(OpenAIConfig::default()) } } #[derive(Debug, Clone, Deserialize, Default)] pub struct ExtraConfig { pub proxy: Option, pub connect_timeout: Option, } #[derive(Debug, Clone, Deserialize, Default)] pub struct RequestPatch { pub chat_completions: Option, pub embeddings: Option, pub rerank: Option, } pub type ApiPatch = IndexMap; pub struct RequestData { pub url: String, pub headers: IndexMap, pub body: Value, } impl RequestData { pub fn new(url: T, body: Value) -> Self where T: std::fmt::Display, { Self { url: url.to_string(), headers: Default::default(), body, } } pub fn bearer_auth(&mut self, auth: T) where T: std::fmt::Display, { self.headers .insert("authorization".into(), format!("Bearer {auth}")); } pub fn header(&mut self, key: K, value: V) where K: std::fmt::Display, V: std::fmt::Display, { self.headers.insert(key.to_string(), value.to_string()); } pub fn into_builder(self, client: &ReqwestClient) -> RequestBuilder { let RequestData { url, headers, body } = self; debug!("Request {url} {body}"); let mut builder = client.post(url); for (key, value) in headers { builder = builder.header(key, value); } builder = builder.json(&body); builder } pub fn apply_patch(&mut self, patch: Value) { if let Some(patch_url) = patch["url"].as_str() { self.url = patch_url.into(); } if let Some(patch_body) = patch.get("body") { json_patch::merge(&mut self.body, patch_body) } if let Some(patch_headers) = patch["headers"].as_object() { for (key, value) in patch_headers { if let Some(value) = value.as_str() { self.header(key, value) } else if value.is_null() { self.headers.swap_remove(key); } } } } } 
#[derive(Debug)] pub struct ChatCompletionsData { pub messages: Vec, pub temperature: Option, pub top_p: Option, pub functions: Option>, pub stream: bool, } #[derive(Debug, Clone, Default)] pub struct ChatCompletionsOutput { pub text: String, pub tool_calls: Vec, } impl ChatCompletionsOutput { pub fn new(text: &str) -> Self { Self { text: text.to_string(), ..Default::default() } } } #[derive(Debug)] pub struct EmbeddingsData { pub texts: Vec, pub query: bool, } impl EmbeddingsData { pub fn new(texts: Vec, query: bool) -> Self { Self { texts, query } } } pub type EmbeddingsOutput = Vec>; #[derive(Debug)] pub struct RerankData { pub query: String, pub documents: Vec, pub top_n: usize, } impl RerankData { pub fn new(query: String, documents: Vec, top_n: usize) -> Self { Self { query, documents, top_n, } } } pub type RerankOutput = Vec; #[derive(Debug, Deserialize)] pub struct RerankResult { pub index: usize, } pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool); pub async fn create_config( prompts: &[PromptAction<'static>], client: &str, vault: &Vault, ) -> Result<(String, Value)> { let mut config = json!({ "type": client, }); for (key, desc, help_message, is_secret) in prompts { let env_name = format!("{client}_{key}").to_ascii_uppercase(); let required = std::env::var(&env_name).is_err(); let value = if !is_secret { prompt_input_string(desc, required, *help_message)? 
} else { vault.add_secret(&env_name)?; format!("{{{{{}}}}}", env_name) }; if !value.is_empty() { config[key] = value.into(); } } let model = set_client_models_config(&mut config, client).await?; let clients = json!(vec![config]); Ok((model, clients)) } pub async fn create_openai_compatible_client_config( client: &str, ) -> Result> { let api_base = OPENAI_COMPATIBLE_PROVIDERS .into_iter() .find(|(name, _)| client == *name) .map(|(_, api_base)| api_base) .unwrap_or("http(s)://{API_ADDR}/v1"); let name = if client == OpenAICompatibleClient::NAME { let value = prompt_input_string("Provider Name", true, None)?; value.replace(' ', "-") } else { client.to_string() }; let mut config = json!({ "type": OpenAICompatibleClient::NAME, "name": &name, }); let api_base = if api_base.contains('{') { prompt_input_string("API Base", true, Some(&format!("e.g. {api_base}")))? } else { api_base.to_string() }; config["api_base"] = api_base.into(); let api_key = prompt_input_string("API Key", false, None)?; if !api_key.is_empty() { config["api_key"] = api_key.into(); } let model = set_client_models_config(&mut config, &name).await?; let clients = json!(vec![config]); Ok(Some((model, clients))) } pub async fn call_chat_completions( input: &Input, print: bool, extract_code: bool, client: &dyn Client, abort_signal: AbortSignal, ) -> Result<(String, Vec)> { let is_child_agent = client.global_config().read().current_depth > 0; let spinner_message = if is_child_agent { "" } else { "Generating" }; let ret = abortable_run_with_spinner( client.chat_completions(input.clone()), spinner_message, abort_signal, ) .await; match ret { Ok(ret) => { let ChatCompletionsOutput { mut text, tool_calls, .. 
} = ret; if !text.is_empty() { if extract_code { text = extract_code_block(&strip_think_tag(&text)).to_string(); } if print { client.global_config().read().print_markdown(&text)?; } } let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?; if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() { tool_results .iter() .for_each(|res| tracker.record_call(res.call.clone())); } Ok((text, tool_results)) } Err(err) => Err(err), } } pub async fn call_chat_completions_streaming( input: &Input, client: &dyn Client, abort_signal: AbortSignal, ) -> Result<(String, Vec)> { let (tx, rx) = unbounded_channel(); let mut handler = SseHandler::new(tx, abort_signal.clone()); let (send_ret, render_ret) = tokio::join!( client.chat_completions_streaming(input, &mut handler), render_stream(rx, client.global_config(), abort_signal.clone()), ); if handler.abort().aborted() { bail!("Aborted."); } render_ret?; let (text, tool_calls) = handler.take(); match send_ret { Ok(_) => { if !text.is_empty() && !text.ends_with('\n') { println!(); } let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?; if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() { tool_results .iter() .for_each(|res| tracker.record_call(res.call.clone())); } Ok((text, tool_results)) } Err(err) => { if !text.is_empty() { println!(); } Err(err) } } } pub fn noop_prepare_embeddings(_client: &T, _data: &EmbeddingsData) -> Result { bail!("The client doesn't support embeddings api") } pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result { bail!("The client doesn't support embeddings api") } pub fn noop_prepare_rerank(_client: &T, _data: &RerankData) -> Result { bail!("The client doesn't support rerank api") } pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result { bail!("The client doesn't support rerank api") } pub fn catch_error(data: &Value, status: u16) -> Result<()> { if 
(200..300).contains(&status) { return Ok(()); } debug!("Invalid response, status: {status}, data: {data}"); if let Some(error) = data["error"].as_object() { if let (Some(typ), Some(message)) = ( json_str_from_map(error, "type"), json_str_from_map(error, "message"), ) { bail!("{message} (type: {typ})"); } else if let (Some(typ), Some(message)) = ( json_str_from_map(error, "code"), json_str_from_map(error, "message"), ) { bail!("{message} (code: {typ})"); } } else if let Some(error) = data["errors"][0].as_object() { if let (Some(code), Some(message)) = ( error.get("code").and_then(|v| v.as_u64()), json_str_from_map(error, "message"), ) { bail!("{message} (status: {code})") } } else if let Some(error) = data[0]["error"].as_object() { if let (Some(status), Some(message)) = ( json_str_from_map(error, "status"), json_str_from_map(error, "message"), ) { bail!("{message} (status: {status})") } } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64()) { bail!("{detail} (status: {status})"); } else if let Some(error) = data["error"].as_str() { bail!("{error}"); } else if let Some(message) = data["message"].as_str() { bail!("{message}"); } bail!("Invalid response data: {data} (status: {status})"); } pub fn json_str_from_map<'a>( map: &'a serde_json::Map, field_name: &str, ) -> Option<&'a str> { map.get(field_name).and_then(|v| v.as_str()) } async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result { if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) { let models: Vec = provider .models .iter() .filter(|v| v.model_type == "chat") .map(|v| v.name.clone()) .collect(); let model_name = select_model(models)?; return Ok(format!("{client}:{model_name}")); } let mut model_names = vec![]; if let (Some(true), Some(api_base), api_key) = ( client_config["type"] .as_str() .map(|v| v == OpenAICompatibleClient::NAME), client_config["api_base"].as_str(), client_config["api_key"] .as_str() .map(|v| 
v.to_string()) .or_else(|| { let env_name = format!("{client}_api_key").to_ascii_uppercase(); std::env::var(&env_name).ok() }), ) { match abortable_run_with_spinner( fetch_models(api_base, api_key.as_deref()), "Fetching models", create_abort_signal(), ) .await { Ok(fetched_models) => { model_names = MultiSelect::new("LLMs to include (required):", fetched_models) .with_validator(|list: &[ListOption<&String>]| { if list.is_empty() { Ok(Validation::Invalid( "At least one item must be selected".into(), )) } else { Ok(Validation::Valid) } }) .prompt()?; } Err(err) => { eprintln!("✗ Fetch models failed: {err}"); } } } if model_names.is_empty() { model_names = prompt_input_string( "LLMs to add", true, Some("Separated by commas, e.g. llama3.3,qwen2.5"), )? .split(',') .filter_map(|v| { let v = v.trim(); if v.is_empty() { None } else { Some(v.to_string()) } }) .collect::>(); } if model_names.is_empty() { bail!("No models"); } let models: Vec = model_names .iter() .map(|v| { let l = v.to_lowercase(); if l.contains("rank") { json!({ "name": v, "type": "reranker", }) } else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) { json!({ "name": v, "type": "embedding", "default_chunk_size": 1000, "max_batch_size": 100 }) } else if v.contains("vision") { json!({ "name": v, "supports_vision": true }) } else { json!({ "name": v, }) } }) .collect(); client_config["models"] = models.into(); let model_name = select_model(model_names)?; Ok(format!("{client}:{model_name}")) } fn select_model(model_names: Vec) -> Result { if model_names.is_empty() { bail!("No models"); } let model = if model_names.len() == 1 { model_names[0].clone() } else { Select::new("Default Model (required):", model_names).prompt()? 
}; Ok(model) } fn prompt_input_string(desc: &str, required: bool, help_message: Option<&str>) -> Result { let desc = if required { format!("{desc} (required):") } else { format!("{desc} (optional):") }; let mut text = Text::new(&desc); if required { text = text.with_validator(required!("This field is required")) } if let Some(help_message) = help_message { text = text.with_help_message(help_message); } let text = text.prompt()?; Ok(text) }