From 650dbd92e0c48f446e66b49a1025c33c8e25cf1b Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Tue, 7 Oct 2025 10:45:42 -0600 Subject: [PATCH] Baseline project --- src/cli.rs | 219 +++ src/client/access_token.rs | 32 + src/client/azure_openai.rs | 82 + src/client/bedrock.rs | 643 +++++++ src/client/claude.rs | 353 ++++ src/client/cohere.rs | 255 +++ src/client/common.rs | 678 +++++++ src/client/gemini.rs | 136 ++ src/client/macros.rs | 245 +++ src/client/message.rs | 235 +++ src/client/mod.rs | 62 + src/client/model.rs | 407 +++++ src/client/openai.rs | 408 +++++ src/client/openai_compatible.rs | 162 ++ src/client/stream.rs | 296 +++ src/client/vertexai.rs | 537 ++++++ src/config/agent.rs | 570 ++++++ src/config/input.rs | 545 ++++++ src/config/mod.rs | 3034 +++++++++++++++++++++++++++++++ src/config/role.rs | 416 +++++ src/config/session.rs | 659 +++++++ src/function.rs | 825 +++++++++ src/main.rs | 496 +++++ src/mcp/mod.rs | 290 +++ src/parsers/bash.rs | 149 ++ src/parsers/mod.rs | 2 + src/parsers/python.rs | 420 +++++ src/rag/mod.rs | 1013 +++++++++++ src/rag/serde_vectors.rs | 66 + src/rag/splitter/language.rs | 235 +++ src/rag/splitter/mod.rs | 475 +++++ src/render/markdown.rs | 393 ++++ src/render/mod.rs | 30 + src/render/stream.rs | 217 +++ src/repl/completer.rs | 159 ++ src/repl/highlighter.rs | 49 + src/repl/mod.rs | 1014 +++++++++++ src/repl/prompt.rs | 51 + src/serve.rs | 935 ++++++++++ src/utils/abort_signal.rs | 88 + src/utils/clipboard.rs | 49 + src/utils/command.rs | 242 +++ src/utils/crypto.rs | 35 + src/utils/html_to_md.rs | 18 + src/utils/input.rs | 47 + src/utils/loader.rs | 125 ++ src/utils/logs.rs | 63 + src/utils/mod.rs | 252 +++ src/utils/native.rs | 46 + src/utils/path.rs | 356 ++++ src/utils/render_prompt.rs | 155 ++ src/utils/request.rs | 464 +++++ src/utils/spinner.rs | 217 +++ src/utils/variables.rs | 32 + 54 files changed, 18982 insertions(+) create mode 100644 src/cli.rs create mode 100644 src/client/access_token.rs create mode 100644 
src/client/azure_openai.rs create mode 100644 src/client/bedrock.rs create mode 100644 src/client/claude.rs create mode 100644 src/client/cohere.rs create mode 100644 src/client/common.rs create mode 100644 src/client/gemini.rs create mode 100644 src/client/macros.rs create mode 100644 src/client/message.rs create mode 100644 src/client/mod.rs create mode 100644 src/client/model.rs create mode 100644 src/client/openai.rs create mode 100644 src/client/openai_compatible.rs create mode 100644 src/client/stream.rs create mode 100644 src/client/vertexai.rs create mode 100644 src/config/agent.rs create mode 100644 src/config/input.rs create mode 100644 src/config/mod.rs create mode 100644 src/config/role.rs create mode 100644 src/config/session.rs create mode 100644 src/function.rs create mode 100644 src/main.rs create mode 100644 src/mcp/mod.rs create mode 100644 src/parsers/bash.rs create mode 100644 src/parsers/mod.rs create mode 100644 src/parsers/python.rs create mode 100644 src/rag/mod.rs create mode 100644 src/rag/serde_vectors.rs create mode 100644 src/rag/splitter/language.rs create mode 100644 src/rag/splitter/mod.rs create mode 100644 src/render/markdown.rs create mode 100644 src/render/mod.rs create mode 100644 src/render/stream.rs create mode 100644 src/repl/completer.rs create mode 100644 src/repl/highlighter.rs create mode 100644 src/repl/mod.rs create mode 100644 src/repl/prompt.rs create mode 100644 src/serve.rs create mode 100644 src/utils/abort_signal.rs create mode 100644 src/utils/clipboard.rs create mode 100644 src/utils/command.rs create mode 100644 src/utils/crypto.rs create mode 100644 src/utils/html_to_md.rs create mode 100644 src/utils/input.rs create mode 100644 src/utils/loader.rs create mode 100644 src/utils/logs.rs create mode 100644 src/utils/mod.rs create mode 100644 src/utils/native.rs create mode 100644 src/utils/path.rs create mode 100644 src/utils/render_prompt.rs create mode 100644 src/utils/request.rs create mode 100644 
src/utils/spinner.rs create mode 100644 src/utils/variables.rs diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..299d742 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,219 @@ +use crate::client::{list_models, ModelType}; +use crate::config::{list_agents, Config}; +use anyhow::{Context, Result}; +use clap::ValueHint; +use clap::{crate_authors, crate_description, crate_name, crate_version, Parser}; +use clap_complete::ArgValueCompleter; +use clap_complete::CompletionCandidate; +use is_terminal::IsTerminal; +use std::ffi::OsStr; +use std::io::{stdin, Read}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +#[command( + name = crate_name!(), + author = crate_authors!(), + version = crate_version!(), + about = crate_description!(), + help_template = "\ +{before-help}{name} {version} +{author-with-newline} +{about-with-newline} +{usage-heading} {usage} + +{all-args}{after-help} +" +)] +pub struct Cli { + /// Select a LLM model + #[arg(short, long, add = ArgValueCompleter::new(model_completer))] + pub model: Option, + /// Use the system prompt + #[arg(long)] + pub prompt: Option, + /// Select a role + #[arg(short, long, add = ArgValueCompleter::new(role_completer))] + pub role: Option, + /// Start or join a session + #[arg(short = 's', long, add = ArgValueCompleter::new(session_completer))] + pub session: Option>, + /// Ensure the session is empty + #[arg(long)] + pub empty_session: bool, + /// Ensure the new conversation is saved to the session + #[arg(long)] + pub save_session: bool, + /// Start an agent + #[arg(short = 'a', long, add = ArgValueCompleter::new(agent_completer))] + pub agent: Option, + /// Set agent variables + #[arg(long, value_names = ["NAME", "VALUE"], num_args = 2)] + pub agent_variable: Vec, + /// Start a RAG + #[arg(long, add = ArgValueCompleter::new(rag_completer))] + pub rag: Option, + /// Rebuild the RAG to sync document changes + #[arg(long)] + pub rebuild_rag: bool, + /// Execute a macro + 
#[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))] + pub macro_name: Option, + /// Serve the LLM API and WebAPP + #[arg(long, value_name = "PORT|IP|IP:PORT")] + pub serve: Option>, + /// Execute commands in natural language + #[arg(short = 'e', long)] + pub execute: bool, + /// Output code only + #[arg(short = 'c', long)] + pub code: bool, + /// Include files, directories, or URLs + #[arg(short = 'f', long, value_name = "FILE|URL", value_hint = ValueHint::AnyPath)] + pub file: Vec, + /// Turn off stream mode + #[arg(short = 'S', long)] + pub no_stream: bool, + /// Display the message without sending it + #[arg(long)] + pub dry_run: bool, + /// Display information + #[arg(long)] + pub info: bool, + /// Build all configured Bash tool scripts + #[arg(long)] + pub build_tools: bool, + /// Sync models updates + #[arg(long)] + pub sync_models: bool, + /// List all available chat models + #[arg(long)] + pub list_models: bool, + /// List all roles + #[arg(long)] + pub list_roles: bool, + /// List all sessions + #[arg(long)] + pub list_sessions: bool, + /// List all agents + #[arg(long)] + pub list_agents: bool, + /// List all RAGs + #[arg(long)] + pub list_rags: bool, + /// List all macros + #[arg(long)] + pub list_macros: bool, + /// Input text + #[arg(trailing_var_arg = true)] + text: Vec, + /// Tail logs + #[arg(long)] + pub tail_logs: bool, + /// Disable colored log output + #[arg(long, requires = "tail_logs")] + pub disable_log_colors: bool, +} + +impl Cli { + pub fn text(&self) -> Result> { + let mut stdin_text = String::new(); + if !stdin().is_terminal() { + let _ = stdin() + .read_to_string(&mut stdin_text) + .context("Invalid stdin pipe")?; + }; + match self.text.is_empty() { + true => { + if stdin_text.is_empty() { + Ok(None) + } else { + Ok(Some(stdin_text)) + } + } + false => { + if self.macro_name.is_some() { + let text = self + .text + .iter() + .map(|v| shell_words::quote(v)) + .collect::>() + .join(" "); + if 
stdin_text.is_empty() { + Ok(Some(text)) + } else { + Ok(Some(format!("{text} -- {stdin_text}"))) + } + } else { + let text = self.text.join(" "); + if stdin_text.is_empty() { + Ok(Some(text)) + } else { + Ok(Some(format!("{text}\n{stdin_text}"))) + } + } + } + } + } +} + +fn model_completer(current: &OsStr) -> Vec { + let cur = current.to_string_lossy(); + match Config::init_bare() { + Ok(config) => list_models(&config, ModelType::Chat) + .into_iter() + .filter(|&m| m.id().starts_with(&*cur)) + .map(|m| CompletionCandidate::new(m.id())) + .collect(), + Err(_) => vec![], + } +} + +fn role_completer(current: &OsStr) -> Vec { + let cur = current.to_string_lossy(); + Config::list_roles(true) + .into_iter() + .filter(|r| r.starts_with(&*cur)) + .map(CompletionCandidate::new) + .collect() +} + +fn agent_completer(current: &OsStr) -> Vec { + let cur = current.to_string_lossy(); + list_agents() + .into_iter() + .filter(|a| a.starts_with(&*cur)) + .map(CompletionCandidate::new) + .collect() +} + +fn rag_completer(current: &OsStr) -> Vec { + let cur = current.to_string_lossy(); + Config::list_rags() + .into_iter() + .filter(|r| r.starts_with(&*cur)) + .map(CompletionCandidate::new) + .collect() +} + +fn macro_completer(current: &OsStr) -> Vec { + let cur = current.to_string_lossy(); + Config::list_macros() + .into_iter() + .filter(|m| m.starts_with(&*cur)) + .map(CompletionCandidate::new) + .collect() +} + +fn session_completer(current: &OsStr) -> Vec { + let cur = current.to_string_lossy(); + match Config::init_bare() { + Ok(config) => config + .list_sessions() + .into_iter() + .filter(|s| s.starts_with(&*cur)) + .map(CompletionCandidate::new) + .collect(), + Err(_) => vec![], + } +} diff --git a/src/client/access_token.rs b/src/client/access_token.rs new file mode 100644 index 0000000..e09e02c --- /dev/null +++ b/src/client/access_token.rs @@ -0,0 +1,32 @@ +use anyhow::{anyhow, Result}; +use chrono::Utc; +use indexmap::IndexMap; +use parking_lot::RwLock; +use 
std::sync::LazyLock; + +static ACCESS_TOKENS: LazyLock>> = + LazyLock::new(|| RwLock::new(IndexMap::new())); + +pub fn get_access_token(client_name: &str) -> Result { + ACCESS_TOKENS + .read() + .get(client_name) + .map(|(token, _)| token.clone()) + .ok_or_else(|| anyhow!("Invalid access token")) +} + +pub fn is_valid_access_token(client_name: &str) -> bool { + let access_tokens = ACCESS_TOKENS.read(); + let (token, expires_at) = match access_tokens.get(client_name) { + Some(v) => v, + None => return false, + }; + !token.is_empty() && Utc::now().timestamp() < *expires_at +} + +pub fn set_access_token(client_name: &str, token: String, expires_at: i64) { + let mut access_tokens = ACCESS_TOKENS.write(); + let entry = access_tokens.entry(client_name.to_string()).or_default(); + entry.0 = token; + entry.1 = expires_at; +} diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs new file mode 100644 index 0000000..a46b818 --- /dev/null +++ b/src/client/azure_openai.rs @@ -0,0 +1,82 @@ +use super::openai::*; +use super::*; + +use anyhow::Result; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize)] +pub struct AzureOpenAIConfig { + pub name: Option, + pub api_base: Option, + pub api_key: Option, + #[serde(default)] + pub models: Vec, + pub patch: Option, + pub extra: Option, +} + +impl AzureOpenAIClient { + config_get_fn!(api_base, get_api_base); + config_get_fn!(api_key, get_api_key); + + pub const PROMPTS: [PromptAction<'static>; 2] = [ + ( + "api_base", + "API Base", + Some("e.g. 
https://{RESOURCE}.openai.azure.com"), + ), + ("api_key", "API Key", None), + ]; +} + +impl_client_trait!( + AzureOpenAIClient, + ( + prepare_chat_completions, + openai_chat_completions, + openai_chat_completions_streaming + ), + (prepare_embeddings, openai_embeddings), + (noop_prepare_rerank, noop_rerank), +); + +fn prepare_chat_completions( + self_: &AzureOpenAIClient, + data: ChatCompletionsData, +) -> Result { + let api_base = self_.get_api_base()?; + let api_key = self_.get_api_key()?; + + let url = format!( + "{}/openai/deployments/{}/chat/completions?api-version=2024-12-01-preview", + &api_base, + self_.model.real_name() + ); + + let body = openai_build_chat_completions_body(data, &self_.model); + + let mut request_data = RequestData::new(url, body); + + request_data.header("api-key", api_key); + + Ok(request_data) +} + +fn prepare_embeddings(self_: &AzureOpenAIClient, data: &EmbeddingsData) -> Result { + let api_base = self_.get_api_base()?; + let api_key = self_.get_api_key()?; + + let url = format!( + "{}/openai/deployments/{}/embeddings?api-version=2024-10-21", + &api_base, + self_.model.real_name() + ); + + let body = openai_build_embeddings_body(data, &self_.model); + + let mut request_data = RequestData::new(url, body); + + request_data.header("api-key", api_key); + + Ok(request_data) +} diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs new file mode 100644 index 0000000..88d57d9 --- /dev/null +++ b/src/client/bedrock.rs @@ -0,0 +1,643 @@ +use super::*; + +use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256, strip_think_tag}; + +use anyhow::{bail, Context, Result}; +use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder}; +use aws_smithy_eventstream::smithy::parse_response_headers; +use bytes::BytesMut; +use chrono::{DateTime, Utc}; +use futures_util::StreamExt; +use indexmap::IndexMap; +use reqwest::{Client as ReqwestClient, Method, RequestBuilder}; +use serde::Deserialize; +use serde_json::{json, 
Value}; + +#[derive(Debug, Clone, Deserialize)] +pub struct BedrockConfig { + pub name: Option, + pub access_key_id: Option, + pub secret_access_key: Option, + pub region: Option, + pub session_token: Option, + #[serde(default)] + pub models: Vec, + pub patch: Option, + pub extra: Option, +} + +impl BedrockClient { + config_get_fn!(access_key_id, get_access_key_id); + config_get_fn!(secret_access_key, get_secret_access_key); + config_get_fn!(region, get_region); + config_get_fn!(session_token, get_session_token); + + pub const PROMPTS: [PromptAction<'static>; 3] = [ + ("access_key_id", "AWS Access Key ID", None), + ("secret_access_key", "AWS Secret Access Key", None), + ("region", "AWS Region", None), + ]; + + fn chat_completions_builder( + &self, + client: &ReqwestClient, + data: ChatCompletionsData, + ) -> Result { + let access_key_id = self.get_access_key_id()?; + let secret_access_key = self.get_secret_access_key()?; + let region = self.get_region()?; + let session_token = self.get_session_token().ok(); + let host = format!("bedrock-runtime.{region}.amazonaws.com"); + + let model_name = &self.model.real_name(); + + let uri = if data.stream { + format!("/model/{model_name}/converse-stream") + } else { + format!("/model/{model_name}/converse") + }; + + let body = build_chat_completions_body(data, &self.model)?; + + let mut request_data = RequestData::new("", body); + self.patch_request_data(&mut request_data); + let RequestData { + url: _, + headers, + body, + } = request_data; + + let builder = aws_fetch( + client, + &AwsCredentials { + access_key_id, + secret_access_key, + region, + session_token, + }, + AwsRequest { + method: Method::POST, + host, + service: "bedrock".into(), + uri, + querystring: "".into(), + headers, + body: body.to_string(), + }, + )?; + + Ok(builder) + } + + fn embeddings_builder( + &self, + client: &ReqwestClient, + data: &EmbeddingsData, + ) -> Result { + let access_key_id = self.get_access_key_id()?; + let secret_access_key = 
self.get_secret_access_key()?; + let region = self.get_region()?; + let session_token = self.get_session_token().ok(); + let host = format!("bedrock-runtime.{region}.amazonaws.com"); + + let uri = format!("/model/{}/invoke", self.model.real_name()); + + let input_type = match data.query { + true => "search_query", + false => "search_document", + }; + + let body = json!({ + "texts": data.texts, + "input_type": input_type, + }); + + let mut request_data = RequestData::new("", body); + self.patch_request_data(&mut request_data); + let RequestData { + url: _, + headers, + body, + } = request_data; + + let builder = aws_fetch( + client, + &AwsCredentials { + access_key_id, + secret_access_key, + region, + session_token, + }, + AwsRequest { + method: Method::POST, + host, + service: "bedrock".into(), + uri, + querystring: "".into(), + headers, + body: body.to_string(), + }, + )?; + + Ok(builder) + } +} + +#[async_trait::async_trait] +impl Client for BedrockClient { + client_common_fns!(); + + async fn chat_completions_inner( + &self, + client: &ReqwestClient, + data: ChatCompletionsData, + ) -> Result { + let builder = self.chat_completions_builder(client, data)?; + chat_completions(builder).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut SseHandler, + data: ChatCompletionsData, + ) -> Result<()> { + let builder = self.chat_completions_builder(client, data)?; + chat_completions_streaming(builder, handler).await + } + + async fn embeddings_inner( + &self, + client: &ReqwestClient, + data: &EmbeddingsData, + ) -> Result { + let builder = self.embeddings_builder(client, data)?; + embeddings(builder).await + } +} + +async fn chat_completions(builder: RequestBuilder) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + + debug!("non-stream-data: {data}"); + 
extract_chat_completions(&data) +} + +async fn chat_completions_streaming( + builder: RequestBuilder, + handler: &mut SseHandler, +) -> Result<()> { + let res = builder.send().await?; + let status = res.status(); + if !status.is_success() { + let data: Value = res.json().await?; + catch_error(&data, status.as_u16())?; + bail!("Invalid response data: {data}"); + } + + let mut function_name = String::new(); + let mut function_arguments = String::new(); + let mut function_id = String::new(); + let mut reasoning_state = 0; + + let mut stream = res.bytes_stream(); + let mut buffer = BytesMut::new(); + let mut decoder = MessageFrameDecoder::new(); + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + buffer.extend_from_slice(&chunk); + while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? { + let response_headers = parse_response_headers(&message)?; + let message_type = response_headers.message_type.as_str(); + let smithy_type = response_headers.smithy_type.as_str(); + match (message_type, smithy_type) { + ("event", _) => { + let data: Value = serde_json::from_slice(message.payload())?; + debug!("stream-data: {smithy_type} {data}"); + match smithy_type { + "contentBlockStart" => { + if let Some(tool_use) = data["start"]["toolUse"].as_object() { + if let (Some(id), Some(name)) = ( + json_str_from_map(tool_use, "toolUseId"), + json_str_from_map(tool_use, "name"), + ) { + if !function_name.is_empty() { + if function_arguments.is_empty() { + function_arguments = String::from("{}"); + } + let arguments: Value = + function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") + })?; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + Some(function_id.clone()), + ))?; + } + function_arguments.clear(); + function_name = name.into(); + function_id = id.into(); + } + } + } + "contentBlockDelta" => { + if let Some(text) = 
data["delta"]["text"].as_str() { + handler.text(text)?; + } else if let Some(text) = + data["delta"]["reasoningContent"]["text"].as_str() + { + if reasoning_state == 0 { + handler.text("\n")?; + reasoning_state = 1; + } + handler.text(text)?; + } else if let Some(input) = data["delta"]["toolUse"]["input"].as_str() { + function_arguments.push_str(input); + } + } + "contentBlockStop" => { + if reasoning_state == 1 { + handler.text("\n\n\n")?; + reasoning_state = 0; + } + if !function_name.is_empty() { + if function_arguments.is_empty() { + function_arguments = String::from("{}"); + } + let arguments: Value = function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") + })?; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + Some(function_id.clone()), + ))?; + } + } + _ => {} + } + } + ("exception", _) => { + let payload = base64_decode(message.payload())?; + let data = String::from_utf8_lossy(&payload); + + bail!("Invalid response data: {data} (smithy_type: {smithy_type})") + } + _ => { + bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",); + } + } + } + } + Ok(()) +} + +async fn embeddings(builder: RequestBuilder) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + Ok(res_body.embeddings) +} + +#[derive(Deserialize)] +struct EmbeddingsResBody { + embeddings: Vec>, +} + +fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { + let ChatCompletionsData { + mut messages, + temperature, + top_p, + functions, + stream: _, + } = data; + + let system_message = extract_system_message(&mut messages); + + let mut network_image_urls = vec![]; + + let messages_len = 
messages.len(); + let messages: Vec = messages + .into_iter() + .enumerate() + .flat_map(|(i, message)| { + let Message { role, content } = message; + match content { + MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => { + vec![json!({ "role": role, "content": [ { "text": strip_think_tag(&text) } ] })] + } + MessageContent::Text(text) => vec![json!({ + "role": role, + "content": [ + { + "text": text, + } + ], + })], + MessageContent::Array(list) => { + let content: Vec<_> = list + .into_iter() + .map(|item| match item { + MessageContentPart::Text { text } => { + json!({"text": text}) + } + MessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + } => { + if let Some((mime_type, data)) = url + .strip_prefix("data:") + .and_then(|v| v.split_once(";base64,")) + { + json!({ + "image": { + "format": mime_type.replace("image/", ""), + "source": { + "bytes": data, + } + } + }) + } else { + network_image_urls.push(url.clone()); + json!({ "url": url }) + } + } + }) + .collect(); + vec![json!({ + "role": role, + "content": content, + })] + } + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. 
+ }) => { + let mut assistant_parts = vec![]; + let mut user_parts = vec![]; + if !text.is_empty() { + assistant_parts.push(json!({ + "text": text, + })) + } + for tool_result in tool_results { + assistant_parts.push(json!({ + "toolUse": { + "toolUseId": tool_result.call.id, + "name": tool_result.call.name, + "input": tool_result.call.arguments, + } + })); + user_parts.push(json!({ + "toolResult": { + "toolUseId": tool_result.call.id, + "content": [ + { + "json": tool_result.output, + } + ] + } + })); + } + vec![ + json!({ + "role": "assistant", + "content": assistant_parts, + }), + json!({ + "role": "user", + "content": user_parts, + }), + ] + } + } + }) + .collect(); + + if !network_image_urls.is_empty() { + bail!( + "The model does not support network images: {:?}", + network_image_urls + ); + } + + let mut body = json!({ + "inferenceConfig": {}, + "messages": messages, + }); + if let Some(v) = system_message { + body["system"] = json!([ + { + "text": v, + } + ]) + } + + if let Some(v) = model.max_tokens_param() { + body["inferenceConfig"]["maxTokens"] = v.into(); + } + if let Some(v) = temperature { + body["inferenceConfig"]["temperature"] = v.into(); + } + if let Some(v) = top_p { + body["inferenceConfig"]["topP"] = v.into(); + } + if let Some(functions) = functions { + let tools: Vec<_> = functions + .iter() + .map(|v| { + json!({ + "toolSpec": { + "name": v.name, + "description": v.description, + "inputSchema": { + "json": v.parameters, + }, + } + }) + }) + .collect(); + body["toolConfig"] = json!({ + "tools": tools, + }) + } + Ok(body) +} + +fn extract_chat_completions(data: &Value) -> Result { + let mut text = String::new(); + let mut reasoning = None; + let mut tool_calls = vec![]; + if let Some(array) = data["output"]["message"]["content"].as_array() { + for item in array { + if let Some(v) = item["text"].as_str() { + if !text.is_empty() { + text.push_str("\n\n"); + } + text.push_str(v); + } else if let Some(reasoning_text) = + 
item["reasoningContent"]["reasoningText"].as_object() + { + if let Some(text) = json_str_from_map(reasoning_text, "text") { + reasoning = Some(text.to_string()); + } + } else if let Some(tool_use) = item["toolUse"].as_object() { + if let (Some(id), Some(name), Some(input)) = ( + json_str_from_map(tool_use, "toolUseId"), + json_str_from_map(tool_use, "name"), + tool_use.get("input"), + ) { + tool_calls.push(ToolCall::new( + name.to_string(), + input.clone(), + Some(id.to_string()), + )) + } + } + } + } + + if let Some(reasoning) = reasoning { + text = format!("\n{reasoning}\n\n\n{text}") + } + + if text.is_empty() && tool_calls.is_empty() { + bail!("Invalid response data: {data}"); + } + + let output = ChatCompletionsOutput { + text, + tool_calls, + id: None, + input_tokens: data["usage"]["inputTokens"].as_u64(), + output_tokens: data["usage"]["outputTokens"].as_u64(), + }; + Ok(output) +} + +#[derive(Debug)] +struct AwsCredentials { + access_key_id: String, + secret_access_key: String, + region: String, + session_token: Option, +} + +#[derive(Debug)] +struct AwsRequest { + method: Method, + host: String, + service: String, + uri: String, + querystring: String, + headers: IndexMap, + body: String, +} + +fn aws_fetch( + client: &ReqwestClient, + credentials: &AwsCredentials, + request: AwsRequest, +) -> Result { + let AwsRequest { + method, + host, + service, + uri, + querystring, + mut headers, + body, + } = request; + let region = &credentials.region; + + let endpoint = format!("https://{host}{uri}"); + + let now: DateTime = Utc::now(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + let date_stamp = amz_date[0..8].to_string(); + headers.insert("host".into(), host.clone()); + headers.insert("x-amz-date".into(), amz_date.clone()); + if let Some(token) = credentials.session_token.clone() { + headers.insert("x-amz-security-token".into(), token); + } + + let canonical_headers = headers + .iter() + .map(|(key, value)| format!("{key}:{value}\n")) + 
.collect::>() + .join(""); + + let signed_headers = headers + .iter() + .map(|(key, _)| key.as_str()) + .collect::>() + .join(";"); + + let payload_hash = sha256(&body); + + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + method, + encode_uri(&uri), + querystring, + canonical_headers, + signed_headers, + payload_hash + ); + + let algorithm = "AWS4-HMAC-SHA256"; + let credential_scope = format!("{date_stamp}/{region}/{service}/aws4_request"); + let string_to_sign = format!( + "{}\n{}\n{}\n{}", + algorithm, + amz_date, + credential_scope, + sha256(&canonical_request) + ); + + let signing_key = gen_signing_key( + &credentials.secret_access_key, + &date_stamp, + region, + &service, + ); + let signature = hmac_sha256(&signing_key, &string_to_sign); + let signature = hex_encode(&signature); + + let authorization_header = format!( + "{} Credential={}/{}, SignedHeaders={}, Signature={}", + algorithm, credentials.access_key_id, credential_scope, signed_headers, signature + ); + + headers.insert("authorization".into(), authorization_header); + + debug!("Request {endpoint} {body}"); + + let mut request_builder = client.request(method, endpoint).body(body); + + for (key, value) in &headers { + request_builder = request_builder.header(key, value); + } + + Ok(request_builder) +} + +fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec { + let k_date = hmac_sha256(format!("AWS4{key}").as_bytes(), date_stamp); + let k_region = hmac_sha256(&k_date, region); + let k_service = hmac_sha256(&k_region, service); + hmac_sha256(&k_service, "aws4_request") +} diff --git a/src/client/claude.rs b/src/client/claude.rs new file mode 100644 index 0000000..29aa6b0 --- /dev/null +++ b/src/client/claude.rs @@ -0,0 +1,353 @@ +use super::*; + +use crate::utils::strip_think_tag; + +use anyhow::{bail, Context, Result}; +use reqwest::RequestBuilder; +use serde::Deserialize; +use serde_json::{json, Value}; + +const API_BASE: &str = 
"https://api.anthropic.com/v1"; + +#[derive(Debug, Clone, Deserialize)] +pub struct ClaudeConfig { + pub name: Option, + pub api_key: Option, + pub api_base: Option, + #[serde(default)] + pub models: Vec, + pub patch: Option, + pub extra: Option, +} + +impl ClaudeClient { + config_get_fn!(api_key, get_api_key); + config_get_fn!(api_base, get_api_base); + + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; +} + +impl_client_trait!( + ClaudeClient, + ( + prepare_chat_completions, + claude_chat_completions, + claude_chat_completions_streaming + ), + (noop_prepare_embeddings, noop_embeddings), + (noop_prepare_rerank, noop_rerank), +); + +fn prepare_chat_completions( + self_: &ClaudeClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let url = format!("{}/messages", api_base.trim_end_matches('/')); + let body = claude_build_chat_completions_body(data, &self_.model)?; + + let mut request_data = RequestData::new(url, body); + + request_data.header("anthropic-version", "2023-06-01"); + request_data.header("x-api-key", api_key); + + Ok(request_data) +} + +pub async fn claude_chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + debug!("non-stream-data: {data}"); + claude_extract_chat_completions(&data) +} + +pub async fn claude_chat_completions_streaming( + builder: RequestBuilder, + handler: &mut SseHandler, + _model: &Model, +) -> Result<()> { + let mut function_name = String::new(); + let mut function_arguments = String::new(); + let mut function_id = String::new(); + let mut reasoning_state = 0; + let handle = |message: SseMessage| -> Result { + let data: Value = serde_json::from_str(&message.data)?; + 
debug!("stream-data: {data}"); + if let Some(typ) = data["type"].as_str() { + match typ { + "content_block_start" => { + if let (Some("tool_use"), Some(name), Some(id)) = ( + data["content_block"]["type"].as_str(), + data["content_block"]["name"].as_str(), + data["content_block"]["id"].as_str(), + ) { + if !function_name.is_empty() { + let arguments: Value = + function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") + })?; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + Some(function_id.clone()), + ))?; + } + function_name = name.into(); + function_arguments.clear(); + function_id = id.into(); + } + } + "content_block_delta" => { + if let Some(text) = data["delta"]["text"].as_str() { + handler.text(text)?; + } else if let Some(text) = data["delta"]["thinking"].as_str() { + if reasoning_state == 0 { + handler.text("\n")?; + reasoning_state = 1; + } + handler.text(text)?; + } else if let (true, Some(partial_json)) = ( + !function_name.is_empty(), + data["delta"]["partial_json"].as_str(), + ) { + function_arguments.push_str(partial_json); + } + } + "content_block_stop" => { + if reasoning_state == 1 { + handler.text("\n\n\n")?; + reasoning_state = 0; + } + if !function_name.is_empty() { + let arguments: Value = if function_arguments.is_empty() { + json!({}) + } else { + function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") + })? 
+ }; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + Some(function_id.clone()), + ))?; + } + } + _ => {} + } + } + Ok(false) + }; + + sse_stream(builder, handle).await +} + +pub fn claude_build_chat_completions_body( + data: ChatCompletionsData, + model: &Model, +) -> Result { + let ChatCompletionsData { + mut messages, + temperature, + top_p, + functions, + stream, + } = data; + + let system_message = extract_system_message(&mut messages); + + let mut network_image_urls = vec![]; + + let messages_len = messages.len(); + let messages: Vec = messages + .into_iter() + .enumerate() + .flat_map(|(i, message)| { + let Message { role, content } = message; + match content { + MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => { + vec![json!({ "role": role, "content": strip_think_tag(&text) })] + } + MessageContent::Text(text) => vec![json!({ + "role": role, + "content": text, + })], + MessageContent::Array(list) => { + let content: Vec<_> = list + .into_iter() + .map(|item| match item { + MessageContentPart::Text { text } => { + json!({"type": "text", "text": text}) + } + MessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + } => { + if let Some((mime_type, data)) = url + .strip_prefix("data:") + .and_then(|v| v.split_once(";base64,")) + { + json!({ + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": data, + } + }) + } else { + network_image_urls.push(url.clone()); + json!({ "url": url }) + } + } + }) + .collect(); + vec![json!({ + "role": role, + "content": content, + })] + } + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. 
+ }) => { + let mut assistant_parts = vec![]; + let mut user_parts = vec![]; + if !text.is_empty() { + assistant_parts.push(json!({ + "type": "text", + "text": text, + })) + } + for tool_result in tool_results { + assistant_parts.push(json!({ + "type": "tool_use", + "id": tool_result.call.id, + "name": tool_result.call.name, + "input": tool_result.call.arguments, + })); + user_parts.push(json!({ + "type": "tool_result", + "tool_use_id": tool_result.call.id, + "content": tool_result.output.to_string(), + })); + } + vec![ + json!({ + "role": "assistant", + "content": assistant_parts, + }), + json!({ + "role": "user", + "content": user_parts, + }), + ] + } + } + }) + .collect(); + + if !network_image_urls.is_empty() { + bail!( + "The model does not support network images: {:?}", + network_image_urls + ); + } + + let mut body = json!({ + "model": model.real_name(), + "messages": messages, + }); + if let Some(v) = system_message { + body["system"] = v.into(); + } + if let Some(v) = model.max_tokens_param() { + body["max_tokens"] = v.into(); + } + if let Some(v) = temperature { + body["temperature"] = v.into(); + } + if let Some(v) = top_p { + body["top_p"] = v.into(); + } + if stream { + body["stream"] = true.into(); + } + if let Some(functions) = functions { + body["tools"] = functions + .iter() + .map(|v| { + json!({ + "name": v.name, + "description": v.description, + "input_schema": v.parameters, + }) + }) + .collect(); + } + Ok(body) +} + +pub fn claude_extract_chat_completions(data: &Value) -> Result { + let mut text = String::new(); + let mut reasoning = None; + let mut tool_calls = vec![]; + if let Some(list) = data["content"].as_array() { + for item in list { + match item["type"].as_str() { + Some("thinking") => { + if let Some(v) = item["thinking"].as_str() { + reasoning = Some(v.to_string()); + } + } + Some("text") => { + if let Some(v) = item["text"].as_str() { + if !text.is_empty() { + text.push_str("\n\n"); + } + text.push_str(v); + } + } + Some("tool_use") 
=> { + if let (Some(name), Some(input), Some(id)) = ( + item["name"].as_str(), + item.get("input"), + item["id"].as_str(), + ) { + tool_calls.push(ToolCall::new( + name.to_string(), + input.clone(), + Some(id.to_string()), + )); + } + } + _ => {} + } + } + } + if let Some(reasoning) = reasoning { + text = format!("\n{reasoning}\n\n\n{text}") + } + + if text.is_empty() && tool_calls.is_empty() { + bail!("Invalid response data: {data}"); + } + + let output = ChatCompletionsOutput { + text: text.to_string(), + tool_calls, + id: data["id"].as_str().map(|v| v.to_string()), + input_tokens: data["usage"]["input_tokens"].as_u64(), + output_tokens: data["usage"]["output_tokens"].as_u64(), + }; + Ok(output) +} diff --git a/src/client/cohere.rs b/src/client/cohere.rs new file mode 100644 index 0000000..9dc6da9 --- /dev/null +++ b/src/client/cohere.rs @@ -0,0 +1,255 @@ +use super::openai::*; +use super::openai_compatible::*; +use super::*; + +use anyhow::{bail, Context, Result}; +use reqwest::RequestBuilder; +use serde::Deserialize; +use serde_json::{json, Value}; + +const API_BASE: &str = "https://api.cohere.ai/v2"; + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct CohereConfig { + pub name: Option, + pub api_key: Option, + pub api_base: Option, + #[serde(default)] + pub models: Vec, + pub patch: Option, + pub extra: Option, +} + +impl CohereClient { + config_get_fn!(api_key, get_api_key); + config_get_fn!(api_base, get_api_base); + + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; +} + +impl_client_trait!( + CohereClient, + ( + prepare_chat_completions, + chat_completions, + chat_completions_streaming + ), + (prepare_embeddings, embeddings), + (prepare_rerank, generic_rerank), +); + +fn prepare_chat_completions( + self_: &CohereClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let url = format!("{}/chat", 
api_base.trim_end_matches('/')); + let mut body = openai_build_chat_completions_body(data, &self_.model); + if let Some(obj) = body.as_object_mut() { + if let Some(top_p) = obj.remove("top_p") { + obj.insert("p".to_string(), top_p); + } + } + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(api_key); + + Ok(request_data) +} + +fn prepare_embeddings(self_: &CohereClient, data: &EmbeddingsData) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let url = format!("{}/embed", api_base.trim_end_matches('/')); + + let input_type = match data.query { + true => "search_query", + false => "search_document", + }; + + let body = json!({ + "model": self_.model.real_name(), + "texts": data.texts, + "input_type": input_type, + "embedding_types": ["float"], + }); + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(api_key); + + Ok(request_data) +} + +fn prepare_rerank(self_: &CohereClient, data: &RerankData) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let url = format!("{}/rerank", api_base.trim_end_matches('/')); + let body = generic_build_rerank_body(data, &self_.model); + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(api_key); + + Ok(request_data) +} + +async fn chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + + debug!("non-stream-data: {data}"); + extract_chat_completions(&data) +} + +async fn chat_completions_streaming( + builder: RequestBuilder, + handler: &mut SseHandler, + _model: &Model, +) -> Result<()> { + let mut function_name = String::new(); + let mut function_arguments = 
String::new(); + let mut function_id = String::new(); + let handle = |message: SseMessage| -> Result { + if message.data == "[DONE]" { + return Ok(true); + } + let data: Value = serde_json::from_str(&message.data)?; + debug!("stream-data: {data}"); + if let Some(typ) = data["type"].as_str() { + match typ { + "content-delta" => { + if let Some(text) = data["delta"]["message"]["content"]["text"].as_str() { + handler.text(text)?; + } + } + "tool-plan-delta" => { + if let Some(text) = data["delta"]["message"]["tool_plan"].as_str() { + handler.text(text)?; + } + } + "tool-call-start" => { + if let (Some(function), Some(id)) = ( + data["delta"]["message"]["tool_calls"]["function"].as_object(), + data["delta"]["message"]["tool_calls"]["id"].as_str(), + ) { + if let Some(name) = function.get("name").and_then(|v| v.as_str()) { + function_name = name.to_string(); + } + function_id = id.to_string(); + } + } + "tool-call-delta" => { + if let Some(text) = + data["delta"]["message"]["tool_calls"]["function"]["arguments"].as_str() + { + function_arguments.push_str(text); + } + } + "tool-call-end" => { + if !function_name.is_empty() { + let arguments: Value = function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") + })?; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + Some(function_id.clone()), + ))?; + } + function_name.clear(); + function_arguments.clear(); + function_id.clear(); + } + _ => {} + } + } + Ok(false) + }; + + sse_stream(builder, handle).await +} + +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + Ok(res_body.embeddings.float) +} + +#[derive(Deserialize)] +struct 
EmbeddingsResBody { + embeddings: EmbeddingsResBodyEmbeddings, +} + +#[derive(Deserialize)] +struct EmbeddingsResBodyEmbeddings { + float: Vec>, +} + +fn extract_chat_completions(data: &Value) -> Result { + let mut text = data["message"]["content"][0]["text"] + .as_str() + .unwrap_or_default() + .to_string(); + + let mut tool_calls = vec![]; + if let Some(calls) = data["message"]["tool_calls"].as_array() { + if text.is_empty() { + if let Some(tool_plain) = data["message"]["tool_plan"].as_str() { + text = tool_plain.to_string(); + } + } + for call in calls { + if let (Some(name), Some(arguments), Some(id)) = ( + call["function"]["name"].as_str(), + call["function"]["arguments"].as_str(), + call["id"].as_str(), + ) { + let arguments: Value = arguments.parse().with_context(|| { + format!("Tool call '{name}' have non-JSON arguments '{arguments}'") + })?; + tool_calls.push(ToolCall::new( + name.to_string(), + arguments, + Some(id.to_string()), + )); + } + } + } + + if text.is_empty() && tool_calls.is_empty() { + bail!("Invalid response data: {data}"); + } + let output = ChatCompletionsOutput { + text, + tool_calls, + id: data["id"].as_str().map(|v| v.to_string()), + input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(), + output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(), + }; + Ok(output) +} diff --git a/src/client/common.rs b/src/client/common.rs new file mode 100644 index 0000000..c7ac96f --- /dev/null +++ b/src/client/common.rs @@ -0,0 +1,678 @@ +use super::*; + +use crate::{ + config::{Config, GlobalConfig, Input}, + function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult}, + render::render_stream, + utils::*, +}; + +use anyhow::{bail, Context, Result}; +use fancy_regex::Regex; +use indexmap::IndexMap; +use inquire::{ + list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text, +}; +use reqwest::{Client as ReqwestClient, RequestBuilder}; +use serde::Deserialize; +use serde_json::{json, 
Value}; +use std::sync::LazyLock; +use std::time::Duration; +use tokio::sync::mpsc::unbounded_channel; + +const MODELS_YAML: &str = include_str!("../../models.yaml"); + +pub static ALL_PROVIDER_MODELS: LazyLock> = LazyLock::new(|| { + Config::local_models_override() + .ok() + .unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap()) +}); + +static EMBEDDING_MODEL_RE: LazyLock = LazyLock::new(|| { + Regex::new(r"((^|/)(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap() +}); + +static ESCAPE_SLASH_RE: LazyLock = LazyLock::new(|| Regex::new(r"(? &GlobalConfig; + + fn extra_config(&self) -> Option<&ExtraConfig>; + + fn patch_config(&self) -> Option<&RequestPatch>; + + fn name(&self) -> &str; + + fn model(&self) -> &Model; + + fn model_mut(&mut self) -> &mut Model; + + fn build_client(&self) -> Result { + let mut builder = ReqwestClient::builder(); + let extra = self.extra_config(); + let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10); + if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) { + builder = set_proxy(builder, proxy)?; + } + if let Some(user_agent) = self.global_config().read().user_agent.as_ref() { + builder = builder.user_agent(user_agent); + } + let client = builder + .connect_timeout(Duration::from_secs(timeout)) + .build() + .with_context(|| "Failed to build client")?; + Ok(client) + } + + async fn chat_completions(&self, input: Input) -> Result { + if self.global_config().read().dry_run { + let content = input.echo_messages(); + return Ok(ChatCompletionsOutput::new(&content)); + } + let client = self.build_client()?; + let data = input.prepare_completion_data(self.model(), false)?; + self.chat_completions_inner(&client, data) + .await + .with_context(|| "Failed to call chat-completions api") + } + + async fn chat_completions_streaming( + &self, + input: &Input, + handler: &mut SseHandler, + ) -> Result<()> { + let abort_signal = handler.abort(); + let input = input.clone(); + tokio::select! 
{ + ret = async { + if self.global_config().read().dry_run { + let content = input.echo_messages(); + handler.text(&content)?; + return Ok(()); + } + let client = self.build_client()?; + let data = input.prepare_completion_data(self.model(), true)?; + self.chat_completions_streaming_inner(&client, handler, data).await + } => { + handler.done(); + ret.with_context(|| "Failed to call chat-completions api") + } + _ = wait_abort_signal(&abort_signal) => { + handler.done(); + Ok(()) + }, + } + } + + async fn embeddings(&self, data: &EmbeddingsData) -> Result>> { + let client = self.build_client()?; + self.embeddings_inner(&client, data) + .await + .context("Failed to call embeddings api") + } + + async fn rerank(&self, data: &RerankData) -> Result { + let client = self.build_client()?; + self.rerank_inner(&client, data) + .await + .context("Failed to call rerank api") + } + + async fn chat_completions_inner( + &self, + client: &ReqwestClient, + data: ChatCompletionsData, + ) -> Result; + + async fn chat_completions_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut SseHandler, + data: ChatCompletionsData, + ) -> Result<()>; + + async fn embeddings_inner( + &self, + _client: &ReqwestClient, + _data: &EmbeddingsData, + ) -> Result { + bail!("The client doesn't support embeddings api") + } + + async fn rerank_inner( + &self, + _client: &ReqwestClient, + _data: &RerankData, + ) -> Result { + bail!("The client doesn't support rerank api") + } + + fn request_builder( + &self, + client: &reqwest::Client, + mut request_data: RequestData, + ) -> RequestBuilder { + self.patch_request_data(&mut request_data); + request_data.into_builder(client) + } + + fn patch_request_data(&self, request_data: &mut RequestData) { + let model_type = self.model().model_type(); + if let Some(patch) = self.model().patch() { + request_data.apply_patch(patch.clone()); + } + + let patch_map = std::env::var(get_env_name(&format!( + "patch_{}_{}", + self.model().client_name(), + 
model_type.api_name(), + ))) + .ok() + .and_then(|v| serde_json::from_str(&v).ok()) + .or_else(|| { + self.patch_config() + .and_then(|v| model_type.extract_patch(v)) + .cloned() + }); + let patch_map = match patch_map { + Some(v) => v, + _ => return, + }; + for (key, patch) in patch_map { + let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/"); + if let Ok(regex) = Regex::new(&format!("^({key})$")) { + if let Ok(true) = regex.is_match(self.model().name()) { + request_data.apply_patch(patch); + return; + } + } + } + } +} + +impl Default for ClientConfig { + fn default() -> Self { + Self::OpenAIConfig(OpenAIConfig::default()) + } +} + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct ExtraConfig { + pub proxy: Option, + pub connect_timeout: Option, +} + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct RequestPatch { + pub chat_completions: Option, + pub embeddings: Option, + pub rerank: Option, +} + +pub type ApiPatch = IndexMap; + +pub struct RequestData { + pub url: String, + pub headers: IndexMap, + pub body: Value, +} + +impl RequestData { + pub fn new(url: T, body: Value) -> Self + where + T: std::fmt::Display, + { + Self { + url: url.to_string(), + headers: Default::default(), + body, + } + } + + pub fn bearer_auth(&mut self, auth: T) + where + T: std::fmt::Display, + { + self.headers + .insert("authorization".into(), format!("Bearer {auth}")); + } + + pub fn header(&mut self, key: K, value: V) + where + K: std::fmt::Display, + V: std::fmt::Display, + { + self.headers.insert(key.to_string(), value.to_string()); + } + + pub fn into_builder(self, client: &ReqwestClient) -> RequestBuilder { + let RequestData { url, headers, body } = self; + debug!("Request {url} {body}"); + + let mut builder = client.post(url); + for (key, value) in headers { + builder = builder.header(key, value); + } + builder = builder.json(&body); + builder + } + + pub fn apply_patch(&mut self, patch: Value) { + if let Some(patch_url) = patch["url"].as_str() { + self.url = 
patch_url.into(); + } + if let Some(patch_body) = patch.get("body") { + json_patch::merge(&mut self.body, patch_body) + } + if let Some(patch_headers) = patch["headers"].as_object() { + for (key, value) in patch_headers { + if let Some(value) = value.as_str() { + self.header(key, value) + } else if value.is_null() { + self.headers.swap_remove(key); + } + } + } + } +} + +#[derive(Debug)] +pub struct ChatCompletionsData { + pub messages: Vec, + pub temperature: Option, + pub top_p: Option, + pub functions: Option>, + pub stream: bool, +} + +#[derive(Debug, Clone, Default)] +pub struct ChatCompletionsOutput { + pub text: String, + pub tool_calls: Vec, + pub id: Option, + pub input_tokens: Option, + pub output_tokens: Option, +} + +impl ChatCompletionsOutput { + pub fn new(text: &str) -> Self { + Self { + text: text.to_string(), + ..Default::default() + } + } +} + +#[derive(Debug)] +pub struct EmbeddingsData { + pub texts: Vec, + pub query: bool, +} + +impl EmbeddingsData { + pub fn new(texts: Vec, query: bool) -> Self { + Self { texts, query } + } +} + +pub type EmbeddingsOutput = Vec>; + +#[derive(Debug)] +pub struct RerankData { + pub query: String, + pub documents: Vec, + pub top_n: usize, +} + +impl RerankData { + pub fn new(query: String, documents: Vec, top_n: usize) -> Self { + Self { + query, + documents, + top_n, + } + } +} + +pub type RerankOutput = Vec; + +#[derive(Debug, Deserialize)] +pub struct RerankResult { + pub index: usize, + pub relevance_score: f64, +} + +pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>); + +pub async fn create_config( + prompts: &[PromptAction<'static>], + client: &str, +) -> Result<(String, Value)> { + let mut config = json!({ + "type": client, + }); + for (key, desc, help_message) in prompts { + let env_name = format!("{client}_{key}").to_ascii_uppercase(); + let required = std::env::var(&env_name).is_err(); + let value = prompt_input_string(desc, required, *help_message)?; + if !value.is_empty() { + config[key] = 
value.into(); + } + } + let model = set_client_models_config(&mut config, client).await?; + let clients = json!(vec![config]); + Ok((model, clients)) +} + +pub async fn create_openai_compatible_client_config( + client: &str, +) -> Result> { + let api_base = OPENAI_COMPATIBLE_PROVIDERS + .into_iter() + .find(|(name, _)| client == *name) + .map(|(_, api_base)| api_base) + .unwrap_or("http(s)://{API_ADDR}/v1"); + + let name = if client == OpenAICompatibleClient::NAME { + let value = prompt_input_string("Provider Name", true, None)?; + value.replace(' ', "-") + } else { + client.to_string() + }; + + let mut config = json!({ + "type": OpenAICompatibleClient::NAME, + "name": &name, + }); + + let api_base = if api_base.contains('{') { + prompt_input_string("API Base", true, Some(&format!("e.g. {api_base}")))? + } else { + api_base.to_string() + }; + config["api_base"] = api_base.into(); + + let api_key = prompt_input_string("API Key", false, None)?; + if !api_key.is_empty() { + config["api_key"] = api_key.into(); + } + + let model = set_client_models_config(&mut config, &name).await?; + let clients = json!(vec![config]); + Ok(Some((model, clients))) +} + +pub async fn call_chat_completions( + input: &Input, + print: bool, + extract_code: bool, + client: &dyn Client, + abort_signal: AbortSignal, +) -> Result<(String, Vec)> { + let ret = abortable_run_with_spinner( + client.chat_completions(input.clone()), + "Generating", + abort_signal, + ) + .await; + + match ret { + Ok(ret) => { + let ChatCompletionsOutput { + mut text, + tool_calls, + .. 
+ } = ret; + if !text.is_empty() { + if extract_code { + text = extract_code_block(&strip_think_tag(&text)).to_string(); + } + if print { + client.global_config().read().print_markdown(&text)?; + } + } + Ok(( + text, + eval_tool_calls(client.global_config(), tool_calls).await?, + )) + } + Err(err) => Err(err), + } +} + +pub async fn call_chat_completions_streaming( + input: &Input, + client: &dyn Client, + abort_signal: AbortSignal, +) -> Result<(String, Vec)> { + let (tx, rx) = unbounded_channel(); + let mut handler = SseHandler::new(tx, abort_signal.clone()); + + let (send_ret, render_ret) = tokio::join!( + client.chat_completions_streaming(input, &mut handler), + render_stream(rx, client.global_config(), abort_signal.clone()), + ); + + if handler.abort().aborted() { + bail!("Aborted."); + } + + render_ret?; + + let (text, tool_calls) = handler.take(); + match send_ret { + Ok(_) => { + if !text.is_empty() && !text.ends_with('\n') { + println!(); + } + Ok(( + text, + eval_tool_calls(client.global_config(), tool_calls).await?, + )) + } + Err(err) => { + if !text.is_empty() { + println!(); + } + Err(err) + } + } +} + +pub fn noop_prepare_embeddings(_client: &T, _data: &EmbeddingsData) -> Result { + bail!("The client doesn't support embeddings api") +} + +pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result { + bail!("The client doesn't support embeddings api") +} + +pub fn noop_prepare_rerank(_client: &T, _data: &RerankData) -> Result { + bail!("The client doesn't support rerank api") +} + +pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result { + bail!("The client doesn't support rerank api") +} + +pub fn catch_error(data: &Value, status: u16) -> Result<()> { + if (200..300).contains(&status) { + return Ok(()); + } + debug!("Invalid response, status: {status}, data: {data}"); + if let Some(error) = data["error"].as_object() { + if let (Some(typ), Some(message)) = ( + json_str_from_map(error, "type"), + 
json_str_from_map(error, "message"), + ) { + bail!("{message} (type: {typ})"); + } else if let (Some(typ), Some(message)) = ( + json_str_from_map(error, "code"), + json_str_from_map(error, "message"), + ) { + bail!("{message} (code: {typ})"); + } + } else if let Some(error) = data["errors"][0].as_object() { + if let (Some(code), Some(message)) = ( + error.get("code").and_then(|v| v.as_u64()), + json_str_from_map(error, "message"), + ) { + bail!("{message} (status: {code})") + } + } else if let Some(error) = data[0]["error"].as_object() { + if let (Some(status), Some(message)) = ( + json_str_from_map(error, "status"), + json_str_from_map(error, "message"), + ) { + bail!("{message} (status: {status})") + } + } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64()) + { + bail!("{detail} (status: {status})"); + } else if let Some(error) = data["error"].as_str() { + bail!("{error}"); + } else if let Some(message) = data["message"].as_str() { + bail!("{message}"); + } + bail!("Invalid response data: {data} (status: {status})"); +} + +pub fn json_str_from_map<'a>( + map: &'a serde_json::Map, + field_name: &str, +) -> Option<&'a str> { + map.get(field_name).and_then(|v| v.as_str()) +} + +async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result { + if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) { + let models: Vec = provider + .models + .iter() + .filter(|v| v.model_type == "chat") + .map(|v| v.name.clone()) + .collect(); + let model_name = select_model(models)?; + return Ok(format!("{client}:{model_name}")); + } + let mut model_names = vec![]; + if let (Some(true), Some(api_base), api_key) = ( + client_config["type"] + .as_str() + .map(|v| v == OpenAICompatibleClient::NAME), + client_config["api_base"].as_str(), + client_config["api_key"] + .as_str() + .map(|v| v.to_string()) + .or_else(|| { + let env_name = format!("{client}_api_key").to_ascii_uppercase(); + 
std::env::var(&env_name).ok() + }), + ) { + match abortable_run_with_spinner( + fetch_models(api_base, api_key.as_deref()), + "Fetching models", + create_abort_signal(), + ) + .await + { + Ok(fetched_models) => { + model_names = MultiSelect::new("LLMs to include (required):", fetched_models) + .with_validator(|list: &[ListOption<&String>]| { + if list.is_empty() { + Ok(Validation::Invalid( + "At least one item must be selected".into(), + )) + } else { + Ok(Validation::Valid) + } + }) + .prompt()?; + } + Err(err) => { + eprintln!("✗ Fetch models failed: {err}"); + } + } + } + if model_names.is_empty() { + model_names = prompt_input_string( + "LLMs to add", + true, + Some("Separated by commas, e.g. llama3.3,qwen2.5"), + )? + .split(',') + .filter_map(|v| { + let v = v.trim(); + if v.is_empty() { + None + } else { + Some(v.to_string()) + } + }) + .collect::>(); + } + if model_names.is_empty() { + bail!("No models"); + } + let models: Vec = model_names + .iter() + .map(|v| { + let l = v.to_lowercase(); + if l.contains("rank") { + json!({ + "name": v, + "type": "reranker", + }) + } else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) { + json!({ + "name": v, + "type": "embedding", + "default_chunk_size": 1000, + "max_batch_size": 100 + }) + } else if v.contains("vision") { + json!({ + "name": v, + "supports_vision": true + }) + } else { + json!({ + "name": v, + }) + } + }) + .collect(); + client_config["models"] = models.into(); + let model_name = select_model(model_names)?; + Ok(format!("{client}:{model_name}")) +} + +fn select_model(model_names: Vec) -> Result { + if model_names.is_empty() { + bail!("No models"); + } + let model = if model_names.len() == 1 { + model_names[0].clone() + } else { + Select::new("Default Model (required):", model_names).prompt()? 
+ }; + Ok(model) +} + +fn prompt_input_string(desc: &str, required: bool, help_message: Option<&str>) -> Result { + let desc = if required { + format!("{desc} (required):") + } else { + format!("{desc} (optional):") + }; + let mut text = Text::new(&desc); + if required { + text = text.with_validator(required!("This field is required")) + } + if let Some(help_message) = help_message { + text = text.with_help_message(help_message); + } + let text = text.prompt()?; + Ok(text) +} diff --git a/src/client/gemini.rs b/src/client/gemini.rs new file mode 100644 index 0000000..3c8778d --- /dev/null +++ b/src/client/gemini.rs @@ -0,0 +1,136 @@ +use super::vertexai::*; +use super::*; + +use anyhow::{Context, Result}; +use reqwest::RequestBuilder; +use serde::Deserialize; +use serde_json::{json, Value}; + +const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta"; + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct GeminiConfig { + pub name: Option, + pub api_key: Option, + pub api_base: Option, + #[serde(default)] + pub models: Vec, + pub patch: Option, + pub extra: Option, +} + +impl GeminiClient { + config_get_fn!(api_key, get_api_key); + config_get_fn!(api_base, get_api_base); + + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; +} + +impl_client_trait!( + GeminiClient, + ( + prepare_chat_completions, + gemini_chat_completions, + gemini_chat_completions_streaming + ), + (prepare_embeddings, embeddings), + (noop_prepare_rerank, noop_rerank), +); + +fn prepare_chat_completions( + self_: &GeminiClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let func = match data.stream { + true => "streamGenerateContent", + false => "generateContent", + }; + + let url = format!( + "{}/models/{}:{}", + api_base.trim_end_matches('/'), + self_.model.real_name(), + func + ); + + let body = 
gemini_build_chat_completions_body(data, &self_.model)?; + + let mut request_data = RequestData::new(url, body); + + request_data.header("x-goog-api-key", api_key); + + Ok(request_data) +} + +fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let url = format!( + "{}/models/{}:batchEmbedContents?key={}", + api_base.trim_end_matches('/'), + self_.model.real_name(), + api_key + ); + + let model_id = format!("models/{}", self_.model.real_name()); + + let requests: Vec<_> = data + .texts + .iter() + .map(|text| { + json!({ + "model": model_id, + "content": { + "parts": [ + { + "text": text + } + ] + }, + }) + }) + .collect(); + + let body = json!({ + "requests": requests, + }); + + let request_data = RequestData::new(url, body); + + Ok(request_data) +} + +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + let output = res_body + .embeddings + .into_iter() + .map(|embedding| embedding.values) + .collect(); + Ok(output) +} + +#[derive(Deserialize)] +struct EmbeddingsResBody { + embeddings: Vec, +} + +#[derive(Deserialize)] +struct EmbeddingsResBodyEmbedding { + values: Vec, +} diff --git a/src/client/macros.rs b/src/client/macros.rs new file mode 100644 index 0000000..a3730bd --- /dev/null +++ b/src/client/macros.rs @@ -0,0 +1,245 @@ +#[macro_export] +macro_rules! 
register_client { + ( + $(($module:ident, $name:literal, $config:ident, $client:ident),)+ + ) => { + $( + mod $module; + )+ + $( + use self::$module::$config; + )+ + + #[derive(Debug, Clone, serde::Deserialize)] + #[serde(tag = "type")] + pub enum ClientConfig { + $( + #[serde(rename = $name)] + $config($config), + )+ + #[serde(other)] + Unknown, + } + + $( + #[derive(Debug)] + pub struct $client { + global_config: $crate::config::GlobalConfig, + config: $config, + model: $crate::client::Model, + } + + impl $client { + pub const NAME: &'static str = $name; + + pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option> { + let config = global_config.read().clients.iter().find_map(|client_config| { + if let ClientConfig::$config(c) = client_config { + if Self::name(c) == model.client_name() { + return Some(c.clone()) + } + } + None + })?; + + Some(Box::new(Self { + global_config: global_config.clone(), + config, + model: model.clone(), + })) + } + + pub fn list_models(local_config: &$config) -> Vec { + let client_name = Self::name(local_config); + if local_config.models.is_empty() { + if let Some(v) = $crate::client::ALL_PROVIDER_MODELS.iter().find(|v| { + v.provider == $name || + ($name == OpenAICompatibleClient::NAME + && local_config.name.as_ref().map(|name| name.starts_with(&v.provider)).unwrap_or_default()) + }) { + return Model::from_config(client_name, &v.models); + } + vec![] + } else { + Model::from_config(client_name, &local_config.models) + } + } + + pub fn name(local_config: &$config) -> &str { + local_config.name.as_deref().unwrap_or(Self::NAME) + } + } + + )+ + + pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result> { + let model = model.unwrap_or_else(|| config.read().model.clone()); + None + $(.or_else(|| $client::init(config, &model)))+ + .ok_or_else(|| { + anyhow::anyhow!("Invalid model '{}'", model.id()) + }) + } + + pub fn list_client_types() -> 
Vec<&'static str> { + let mut client_types: Vec<_> = vec![$($client::NAME,)+]; + client_types.extend($crate::client::OPENAI_COMPATIBLE_PROVIDERS.iter().map(|(name, _)| *name)); + client_types + } + + pub async fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> { + $( + if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME { + return create_config(&$client::PROMPTS, $client::NAME).await + } + )+ + if let Some(ret) = create_openai_compatible_client_config(client).await? { + return Ok(ret); + } + anyhow::bail!("Unknown client '{}'", client) + } + + static ALL_CLIENT_NAMES: std::sync::OnceLock> = std::sync::OnceLock::new(); + + pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> { + let names = ALL_CLIENT_NAMES.get_or_init(|| { + config + .clients + .iter() + .flat_map(|v| match v { + $(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+ + ClientConfig::Unknown => vec![], + }) + .collect() + }); + names.iter().collect() + } + + static ALL_MODELS: std::sync::OnceLock> = std::sync::OnceLock::new(); + + pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { + let models = ALL_MODELS.get_or_init(|| { + config + .clients + .iter() + .flat_map(|v| match v { + $(ClientConfig::$config(c) => $client::list_models(c),)+ + ClientConfig::Unknown => vec![], + }) + .collect() + }); + models.iter().collect() + } + + pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> { + list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect() + } + }; +} + +#[macro_export] +macro_rules! 
client_common_fns { + () => { + fn global_config(&self) -> &$crate::config::GlobalConfig { + &self.global_config + } + + fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> { + self.config.extra.as_ref() + } + + fn patch_config(&self) -> Option<&$crate::client::RequestPatch> { + self.config.patch.as_ref() + } + + fn name(&self) -> &str { + Self::name(&self.config) + } + + fn model(&self) -> &Model { + &self.model + } + + fn model_mut(&mut self) -> &mut Model { + &mut self.model + } + }; +} + +#[macro_export] +macro_rules! impl_client_trait { + ( + $client:ident, + ($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path), + ($prepare_embeddings:path, $embeddings:path), + ($prepare_rerank:path, $rerank:path), + ) => { + #[async_trait::async_trait] + impl $crate::client::Client for $crate::client::$client { + client_common_fns!(); + + async fn chat_completions_inner( + &self, + client: &reqwest::Client, + data: $crate::client::ChatCompletionsData, + ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { + let request_data = $prepare_chat_completions(self, data)?; + let builder = self.request_builder(client, request_data); + $chat_completions(builder, self.model()).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &reqwest::Client, + handler: &mut $crate::client::SseHandler, + data: $crate::client::ChatCompletionsData, + ) -> Result<()> { + let request_data = $prepare_chat_completions(self, data)?; + let builder = self.request_builder(client, request_data); + $chat_completions_streaming(builder, handler, self.model()).await + } + + async fn embeddings_inner( + &self, + client: &reqwest::Client, + data: &$crate::client::EmbeddingsData, + ) -> Result<$crate::client::EmbeddingsOutput> { + let request_data = $prepare_embeddings(self, data)?; + let builder = self.request_builder(client, request_data); + $embeddings(builder, self.model()).await + } + + async fn rerank_inner( + &self, + client: 
&reqwest::Client, + data: &$crate::client::RerankData, + ) -> Result<$crate::client::RerankOutput> { + let request_data = $prepare_rerank(self, data)?; + let builder = self.request_builder(client, request_data); + $rerank(builder, self.model()).await + } + } + }; +} + +#[macro_export] +macro_rules! config_get_fn { + ($field_name:ident, $fn_name:ident) => { + fn $fn_name(&self) -> anyhow::Result { + let env_prefix = Self::name(&self.config); + let env_name = + format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase(); + std::env::var(&env_name) + .ok() + .or_else(|| self.config.$field_name.clone()) + .ok_or_else(|| anyhow::anyhow!("Miss '{}'", stringify!($field_name))) + } + }; +} + +#[macro_export] +macro_rules! unsupported_model { + ($name:expr) => { + anyhow::bail!("Unsupported model '{}'", $name) + }; +} diff --git a/src/client/message.rs b/src/client/message.rs new file mode 100644 index 0000000..5ca7f78 --- /dev/null +++ b/src/client/message.rs @@ -0,0 +1,235 @@ +use super::Model; + +use crate::{function::ToolResult, multiline_text, utils::dimmed_text}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Message { + pub role: MessageRole, + pub content: MessageContent, +} + +impl Default for Message { + fn default() -> Self { + Self { + role: MessageRole::User, + content: MessageContent::Text(String::new()), + } + } +} + +impl Message { + pub fn new(role: MessageRole, content: MessageContent) -> Self { + Self { role, content } + } + + pub fn merge_system(&mut self, system: MessageContent) { + match (&mut self.content, system) { + (MessageContent::Text(text), MessageContent::Text(system_text)) => { + self.content = MessageContent::Array(vec![ + MessageContentPart::Text { text: system_text }, + MessageContentPart::Text { + text: text.to_string(), + }, + ]) + } + (MessageContent::Array(list), MessageContent::Text(system_text)) => { + list.insert(0, MessageContentPart::Text { text: system_text }) 
+ } + (MessageContent::Text(text), MessageContent::Array(mut system_list)) => { + system_list.push(MessageContentPart::Text { + text: text.to_string(), + }); + self.content = MessageContent::Array(system_list); + } + (MessageContent::Array(list), MessageContent::Array(mut system_list)) => { + system_list.append(list); + self.content = MessageContent::Array(system_list); + } + _ => {} + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum MessageRole { + System, + Assistant, + User, + Tool, +} + +#[allow(dead_code)] +impl MessageRole { + pub fn is_system(&self) -> bool { + matches!(self, MessageRole::System) + } + + pub fn is_user(&self) -> bool { + matches!(self, MessageRole::User) + } + + pub fn is_assistant(&self) -> bool { + matches!(self, MessageRole::Assistant) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Array(Vec), + // Note: This type is primarily for convenience and does not exist in OpenAI's API. + ToolCalls(MessageContentToolCalls), +} + +impl MessageContent { + pub fn render_input( + &self, + resolve_url_fn: impl Fn(&str) -> String, + agent_info: &Option<(String, Vec)>, + ) -> String { + match self { + MessageContent::Text(text) => multiline_text(text), + MessageContent::Array(list) => { + let (mut concated_text, mut files) = (String::new(), vec![]); + for item in list { + match item { + MessageContentPart::Text { text } => { + concated_text = format!("{concated_text} {text}") + } + MessageContentPart::ImageUrl { image_url } => { + files.push(resolve_url_fn(&image_url.url)) + } + } + } + if !concated_text.is_empty() { + concated_text = format!(" -- {}", multiline_text(&concated_text)) + } + format!(".file {}{}", files.join(" "), concated_text) + } + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. 
+ }) => { + let mut lines = vec![]; + if !text.is_empty() { + lines.push(text.clone()) + } + for tool_result in tool_results { + let mut parts = vec!["Call".to_string()]; + if let Some((agent_name, functions)) = agent_info { + if functions.contains(&tool_result.call.name) { + parts.push(agent_name.clone()) + } + } + parts.push(tool_result.call.name.clone()); + parts.push(tool_result.call.arguments.to_string()); + lines.push(dimmed_text(&parts.join(" "))); + } + lines.join("\n") + } + } + } + + pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) { + match self { + MessageContent::Text(text) => *text = replace_fn(text), + MessageContent::Array(list) => { + if list.is_empty() { + list.push(MessageContentPart::Text { + text: replace_fn(""), + }) + } else if let Some(MessageContentPart::Text { text }) = list.get_mut(0) { + *text = replace_fn(text) + } + } + MessageContent::ToolCalls(_) => {} + } + } + + pub fn to_text(&self) -> String { + match self { + MessageContent::Text(text) => text.to_string(), + MessageContent::Array(list) => { + let mut parts = vec![]; + for item in list { + if let MessageContentPart::Text { text } = item { + parts.push(text.clone()) + } + } + parts.join("\n\n") + } + MessageContent::ToolCalls(_) => String::new(), + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum MessageContentPart { + Text { text: String }, + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct MessageContentToolCalls { + pub tool_results: Vec, + pub text: String, + pub sequence: bool, +} + +impl MessageContentToolCalls { + pub fn new(tool_results: Vec, text: String) -> Self { + Self { + tool_results, + text, + sequence: false, + } + } + + pub fn merge(&mut self, tool_results: Vec, _text: String) { + self.tool_results.extend(tool_results); + 
self.text.clear();
        self.sequence = true;
    }
}

/// Patch `messages` in place to satisfy per-model quirks:
/// - prepend (or merge into an existing leading system message) the model's
///   configured `system_prompt_prefix`;
/// - when the model rejects system messages, fold the leading system message
///   into the following message via `merge_system`.
pub fn patch_messages(messages: &mut Vec<Message>, model: &Model) {
    if messages.is_empty() {
        return;
    }
    if let Some(prefix) = model.system_prompt_prefix() {
        if messages[0].role.is_system() {
            messages[0].merge_system(MessageContent::Text(prefix.to_string()));
        } else {
            messages.insert(
                0,
                Message {
                    role: MessageRole::System,
                    content: MessageContent::Text(prefix.to_string()),
                },
            );
        }
    }
    if model.no_system_message() && messages[0].role.is_system() {
        let system_message = messages.remove(0);
        if let (Some(message), system) = (messages.get_mut(0), system_message.content) {
            message.merge_system(system);
        }
    }
}

/// Remove and return the text of a leading system message, if any.
///
/// Uses `first()?` so an empty `messages` yields `None` instead of the
/// out-of-bounds panic the unchecked `messages[0]` indexing allowed.
pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
    if messages.first()?.role.is_system() {
        let system_message = messages.remove(0);
        return Some(system_message.content.to_text());
    }
    None
}
diff --git a/src/client/mod.rs b/src/client/mod.rs
new file mode 100644
index 0000000..d0a4574
--- /dev/null
+++ b/src/client/mod.rs
@@ -0,0 +1,62 @@
mod access_token;
mod common;
mod message;
#[macro_use]
mod macros;
mod model;
mod stream;

pub use crate::function::ToolCall;
pub use common::*;
pub use message::*;
pub use model::*;
pub use stream::*;

register_client!(
    (openai, "openai", OpenAIConfig, OpenAIClient),
    (
        openai_compatible,
        "openai-compatible",
        OpenAICompatibleConfig,
        OpenAICompatibleClient
    ),
    (gemini, "gemini", GeminiConfig, GeminiClient),
    (claude, "claude", ClaudeConfig, ClaudeClient),
    (cohere, "cohere", CohereConfig, CohereClient),
    (
        azure_openai,
        "azure-openai",
        AzureOpenAIConfig,
        AzureOpenAIClient
    ),
    (vertexai, "vertexai", VertexAIConfig, VertexAIClient),
    (bedrock, "bedrock", BedrockConfig, BedrockClient),
);

/// Built-in OpenAI-compatible providers as `(name, api_base)` pairs.
pub const OPENAI_COMPATIBLE_PROVIDERS: [(&str, &str); 18] = [
    ("ai21", "https://api.ai21.com/studio/v1"),
    (
        "cloudflare",
        "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/v1",
    ),
+ ("deepinfra", "https://api.deepinfra.com/v1/openai"), + ("deepseek", "https://api.deepseek.com"), + ("ernie", "https://qianfan.baidubce.com/v2"), + ("github", "https://models.inference.ai.azure.com"), + ("groq", "https://api.groq.com/openai/v1"), + ("hunyuan", "https://api.hunyuan.cloud.tencent.com/v1"), + ("minimax", "https://api.minimax.chat/v1"), + ("mistral", "https://api.mistral.ai/v1"), + ("moonshot", "https://api.moonshot.cn/v1"), + ("openrouter", "https://openrouter.ai/api/v1"), + ("perplexity", "https://api.perplexity.ai"), + ( + "qianwen", + "https://dashscope.aliyuncs.com/compatible-mode/v1", + ), + ("xai", "https://api.x.ai/v1"), + ("zhipuai", "https://open.bigmodel.cn/api/paas/v4"), + // RAG-dedicated + ("jina", "https://api.jina.ai/v1"), + ("voyageai", "https://api.voyageai.com/v1"), +]; diff --git a/src/client/model.rs b/src/client/model.rs new file mode 100644 index 0000000..925e9c9 --- /dev/null +++ b/src/client/model.rs @@ -0,0 +1,407 @@ +use super::{ + list_all_models, list_client_names, + message::{Message, MessageContent, MessageContentPart}, + ApiPatch, MessageContentToolCalls, RequestPatch, +}; + +use crate::config::Config; +use crate::utils::{estimate_token_length, strip_think_tag}; + +use anyhow::{bail, Result}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::fmt::Display; + +const PER_MESSAGES_TOKENS: usize = 5; +const BASIS_TOKENS: usize = 2; + +#[derive(Debug, Clone)] +pub struct Model { + client_name: String, + data: ModelData, +} + +impl Default for Model { + fn default() -> Self { + Model::new("", "") + } +} + +impl Model { + pub fn new(client_name: &str, name: &str) -> Self { + Self { + client_name: client_name.into(), + data: ModelData::new(name), + } + } + + pub fn from_config(client_name: &str, models: &[ModelData]) -> Vec { + models + .iter() + .map(|v| Model { + client_name: client_name.to_string(), + data: v.clone(), + }) + .collect() + } + + pub fn retrieve_model(config: &Config, model_id: &str, 
model_type: ModelType) -> Result { + let models = list_all_models(config); + let (client_name, model_name) = match model_id.split_once(':') { + Some((client_name, model_name)) => { + if model_name.is_empty() { + (client_name, None) + } else { + (client_name, Some(model_name)) + } + } + None => (model_id, None), + }; + match model_name { + Some(model_name) => { + if let Some(model) = models.iter().find(|v| v.id() == model_id) { + if model.model_type() == model_type { + return Ok((*model).clone()); + } else { + bail!("Model '{model_id}' is not a {model_type} model") + } + } + if list_client_names(config) + .into_iter() + .any(|v| *v == client_name) + && model_type.can_create_from_name() + { + let mut new_model = Self::new(client_name, model_name); + new_model.data.model_type = model_type.to_string(); + return Ok(new_model); + } + } + None => { + if let Some(found) = models + .iter() + .find(|v| v.client_name == client_name && v.model_type() == model_type) + { + return Ok((*found).clone()); + } + } + }; + bail!("Unknown {model_type} model '{model_id}'") + } + + pub fn id(&self) -> String { + if self.data.name.is_empty() { + self.client_name.to_string() + } else { + format!("{}:{}", self.client_name, self.data.name) + } + } + + pub fn client_name(&self) -> &str { + &self.client_name + } + + pub fn name(&self) -> &str { + &self.data.name + } + + pub fn real_name(&self) -> &str { + self.data.real_name.as_deref().unwrap_or(&self.data.name) + } + + pub fn model_type(&self) -> ModelType { + if self.data.model_type.starts_with("embed") { + ModelType::Embedding + } else if self.data.model_type.starts_with("rerank") { + ModelType::Reranker + } else { + ModelType::Chat + } + } + + pub fn data(&self) -> &ModelData { + &self.data + } + + pub fn data_mut(&mut self) -> &mut ModelData { + &mut self.data + } + + pub fn description(&self) -> String { + match self.model_type() { + ModelType::Chat => { + let ModelData { + max_input_tokens, + max_output_tokens, + input_price, + 
output_price, + supports_vision, + supports_function_calling, + .. + } = &self.data; + let max_input_tokens = stringify_option_value(max_input_tokens); + let max_output_tokens = stringify_option_value(max_output_tokens); + let input_price = stringify_option_value(input_price); + let output_price = stringify_option_value(output_price); + let mut capabilities = vec![]; + if *supports_vision { + capabilities.push('👁'); + }; + if *supports_function_calling { + capabilities.push('⚒'); + }; + let capabilities: String = capabilities + .into_iter() + .map(|v| format!("{v} ")) + .collect::>() + .join(""); + format!( + "{max_input_tokens:>8} / {max_output_tokens:>8} | {input_price:>6} / {output_price:>6} {capabilities:>6}" + ) + } + ModelType::Embedding => { + let ModelData { + input_price, + max_tokens_per_chunk, + max_batch_size, + .. + } = &self.data; + let max_tokens = stringify_option_value(max_tokens_per_chunk); + let max_batch = stringify_option_value(max_batch_size); + let price = stringify_option_value(input_price); + format!("max-tokens:{max_tokens};max-batch:{max_batch};price:{price}") + } + ModelType::Reranker => String::new(), + } + } + + pub fn patch(&self) -> Option<&Value> { + self.data.patch.as_ref() + } + + pub fn max_input_tokens(&self) -> Option { + self.data.max_input_tokens + } + + pub fn max_output_tokens(&self) -> Option { + self.data.max_output_tokens + } + + pub fn no_stream(&self) -> bool { + self.data.no_stream + } + + pub fn no_system_message(&self) -> bool { + self.data.no_system_message + } + + pub fn system_prompt_prefix(&self) -> Option<&str> { + self.data.system_prompt_prefix.as_deref() + } + + pub fn max_tokens_per_chunk(&self) -> Option { + self.data.max_tokens_per_chunk + } + + pub fn default_chunk_size(&self) -> usize { + self.data.default_chunk_size.unwrap_or(1000) + } + + pub fn max_batch_size(&self) -> Option { + self.data.max_batch_size + } + + pub fn max_tokens_param(&self) -> Option { + if self.data.require_max_tokens { + 
self.data.max_output_tokens + } else { + None + } + } + + pub fn set_max_tokens( + &mut self, + max_output_tokens: Option, + require_max_tokens: bool, + ) -> &mut Self { + match max_output_tokens { + None | Some(0) => self.data.max_output_tokens = None, + _ => self.data.max_output_tokens = max_output_tokens, + } + self.data.require_max_tokens = require_max_tokens; + self + } + + pub fn messages_tokens(&self, messages: &[Message]) -> usize { + let messages_len = messages.len(); + messages + .iter() + .enumerate() + .map(|(i, v)| match &v.content { + MessageContent::Text(text) => { + if v.role.is_assistant() && i != messages_len - 1 { + estimate_token_length(&strip_think_tag(text)) + } else { + estimate_token_length(text) + } + } + MessageContent::Array(list) => list + .iter() + .map(|v| match v { + MessageContentPart::Text { text } => estimate_token_length(text), + MessageContentPart::ImageUrl { .. } => 0, + }) + .sum(), + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. 
+ }) => { + estimate_token_length(text) + + tool_results + .iter() + .map(|v| { + serde_json::to_string(v) + .map(|v| estimate_token_length(&v)) + .unwrap_or_default() + }) + .sum::() + } + }) + .sum() + } + + pub fn total_tokens(&self, messages: &[Message]) -> usize { + if messages.is_empty() { + return 0; + } + let num_messages = messages.len(); + let message_tokens = self.messages_tokens(messages); + if messages[num_messages - 1].role.is_user() { + num_messages * PER_MESSAGES_TOKENS + message_tokens + } else { + (num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens + } + } + + pub fn guard_max_input_tokens(&self, messages: &[Message]) -> Result<()> { + let total_tokens = self.total_tokens(messages) + BASIS_TOKENS; + if let Some(max_input_tokens) = self.data.max_input_tokens { + if total_tokens >= max_input_tokens { + bail!("Exceed max_input_tokens limit") + } + } + Ok(()) + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ModelData { + pub name: String, + #[serde(default = "default_model_type", rename = "type")] + pub model_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub real_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_input_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_price: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_price: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub patch: Option, + + // chat-only properties + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub require_max_tokens: bool, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub supports_vision: bool, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub supports_function_calling: bool, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + no_stream: bool, + #[serde(default, 
skip_serializing_if = "std::ops::Not::not")] + no_system_message: bool, + #[serde(skip_serializing_if = "Option::is_none")] + system_prompt_prefix: Option, + + // embedding-only properties + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens_per_chunk: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub default_chunk_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_batch_size: Option, +} + +impl ModelData { + pub fn new(name: &str) -> Self { + Self { + name: name.to_string(), + model_type: default_model_type(), + ..Default::default() + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderModels { + pub provider: String, + pub models: Vec, +} + +fn default_model_type() -> String { + "chat".into() +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ModelType { + Chat, + Embedding, + Reranker, +} + +impl Display for ModelType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ModelType::Chat => write!(f, "chat"), + ModelType::Embedding => write!(f, "embedding"), + ModelType::Reranker => write!(f, "reranker"), + } + } +} + +impl ModelType { + pub fn can_create_from_name(self) -> bool { + match self { + ModelType::Chat => true, + ModelType::Embedding => false, + ModelType::Reranker => true, + } + } + + pub fn api_name(self) -> &'static str { + match self { + ModelType::Chat => "chat_completions", + ModelType::Embedding => "embeddings", + ModelType::Reranker => "rerank", + } + } + + pub fn extract_patch(self, patch: &RequestPatch) -> Option<&ApiPatch> { + match self { + ModelType::Chat => patch.chat_completions.as_ref(), + ModelType::Embedding => patch.embeddings.as_ref(), + ModelType::Reranker => patch.rerank.as_ref(), + } + } +} + +fn stringify_option_value(value: &Option) -> String +where + T: Display, +{ + match value { + Some(value) => value.to_string(), + None => "-".to_string(), + } +} diff --git a/src/client/openai.rs 
b/src/client/openai.rs new file mode 100644 index 0000000..59496ef --- /dev/null +++ b/src/client/openai.rs @@ -0,0 +1,408 @@ +use super::*; + +use crate::utils::strip_think_tag; + +use anyhow::{bail, Context, Result}; +use reqwest::RequestBuilder; +use serde::Deserialize; +use serde_json::{json, Value}; + +const API_BASE: &str = "https://api.openai.com/v1"; + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct OpenAIConfig { + pub name: Option, + pub api_key: Option, + pub api_base: Option, + pub organization_id: Option, + #[serde(default)] + pub models: Vec, + pub patch: Option, + pub extra: Option, +} + +impl OpenAIClient { + config_get_fn!(api_key, get_api_key); + config_get_fn!(api_base, get_api_base); + + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; +} + +impl_client_trait!( + OpenAIClient, + ( + prepare_chat_completions, + openai_chat_completions, + openai_chat_completions_streaming + ), + (prepare_embeddings, openai_embeddings), + (noop_prepare_rerank, noop_rerank), +); + +fn prepare_chat_completions( + self_: &OpenAIClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let url = format!("{}/chat/completions", api_base.trim_end_matches('/')); + + let body = openai_build_chat_completions_body(data, &self_.model); + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(api_key); + if let Some(organization_id) = &self_.config.organization_id { + request_data.header("OpenAI-Organization", organization_id); + } + + Ok(request_data) +} + +fn prepare_embeddings(self_: &OpenAIClient, data: &EmbeddingsData) -> Result { + let api_key = self_.get_api_key()?; + let api_base = self_ + .get_api_base() + .unwrap_or_else(|_| API_BASE.to_string()); + + let url = format!("{api_base}/embeddings"); + + let body = openai_build_embeddings_body(data, &self_.model); + + let mut 
request_data = RequestData::new(url, body); + + request_data.bearer_auth(api_key); + if let Some(organization_id) = &self_.config.organization_id { + request_data.header("OpenAI-Organization", organization_id); + } + + Ok(request_data) +} + +pub async fn openai_chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + + debug!("non-stream-data: {data}"); + openai_extract_chat_completions(&data) +} + +pub async fn openai_chat_completions_streaming( + builder: RequestBuilder, + handler: &mut SseHandler, + _model: &Model, +) -> Result<()> { + let mut call_id = String::new(); + let mut function_name = String::new(); + let mut function_arguments = String::new(); + let mut function_id = String::new(); + let mut reasoning_state = 0; + let handle = |message: SseMessage| -> Result { + if message.data == "[DONE]" { + if !function_name.is_empty() { + if function_arguments.is_empty() { + function_arguments = String::from("{}"); + } + let arguments: Value = function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'") + })?; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + normalize_function_id(&function_id), + ))?; + } + return Ok(true); + } + let data: Value = serde_json::from_str(&message.data)?; + debug!("stream-data: {data}"); + if let Some(text) = data["choices"][0]["delta"]["content"] + .as_str() + .filter(|v| !v.is_empty()) + { + if reasoning_state == 1 { + handler.text("\n\n\n")?; + reasoning_state = 0; + } + handler.text(text)?; + } else if let Some(text) = data["choices"][0]["delta"]["reasoning_content"] + .as_str() + .or_else(|| data["choices"][0]["delta"]["reasoning"].as_str()) + .filter(|v| !v.is_empty()) + { + if reasoning_state == 0 { + handler.text("\n")?; + reasoning_state 
= 1; + } + handler.text(text)?; + } + if let (Some(function), index, id) = ( + data["choices"][0]["delta"]["tool_calls"][0]["function"].as_object(), + data["choices"][0]["delta"]["tool_calls"][0]["index"].as_u64(), + data["choices"][0]["delta"]["tool_calls"][0]["id"] + .as_str() + .filter(|v| !v.is_empty()), + ) { + if reasoning_state == 1 { + handler.text("\n\n\n")?; + reasoning_state = 0; + } + let maybe_call_id = format!("{}/{}", id.unwrap_or_default(), index.unwrap_or_default()); + if maybe_call_id != call_id && maybe_call_id.len() >= call_id.len() { + if !function_name.is_empty() { + if function_arguments.is_empty() { + function_arguments = String::from("{}"); + } + let arguments: Value = function_arguments.parse().with_context(|| { + format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'") + })?; + handler.tool_call(ToolCall::new( + function_name.clone(), + arguments, + normalize_function_id(&function_id), + ))?; + } + function_name.clear(); + function_arguments.clear(); + function_id.clear(); + call_id = maybe_call_id; + } + if let Some(name) = function.get("name").and_then(|v| v.as_str()) { + if name.starts_with(&function_name) { + function_name = name.to_string(); + } else { + function_name.push_str(name); + } + } + if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) { + function_arguments.push_str(arguments); + } + if let Some(id) = id { + function_id = id.to_string(); + } + } + Ok(false) + }; + + sse_stream(builder, handle).await +} + +pub async fn openai_embeddings( + builder: RequestBuilder, + _model: &Model, +) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + let output = res_body.data.into_iter().map(|v| v.embedding).collect(); + Ok(output) +} + 
+#[derive(Deserialize)] +struct EmbeddingsResBody { + data: Vec, +} + +#[derive(Deserialize)] +struct EmbeddingsResBodyEmbedding { + embedding: Vec, +} + +pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value { + let ChatCompletionsData { + messages, + temperature, + top_p, + functions, + stream, + } = data; + + let messages_len = messages.len(); + let messages: Vec = messages + .into_iter() + .enumerate() + .flat_map(|(i, message)| { + let Message { role, content } = message; + match content { + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, + text: _, + sequence, + }) => { + if !sequence { + let tool_calls: Vec<_> = tool_results + .iter() + .map(|tool_result| { + json!({ + "id": tool_result.call.id, + "type": "function", + "function": { + "name": tool_result.call.name, + "arguments": tool_result.call.arguments.to_string(), + }, + }) + }) + .collect(); + let mut messages = vec![ + json!({ "role": MessageRole::Assistant, "tool_calls": tool_calls }), + ]; + for tool_result in tool_results { + messages.push(json!({ + "role": "tool", + "content": tool_result.output.to_string(), + "tool_call_id": tool_result.call.id, + })); + } + messages + } else { + tool_results.into_iter().flat_map(|tool_result| { + vec![ + json!({ + "role": MessageRole::Assistant, + "tool_calls": [ + { + "id": tool_result.call.id, + "type": "function", + "function": { + "name": tool_result.call.name, + "arguments": tool_result.call.arguments.to_string(), + }, + } + ] + }), + json!({ + "role": "tool", + "content": tool_result.output.to_string(), + "tool_call_id": tool_result.call.id, + }) + ] + + }).collect() + } + } + MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => { + vec![json!({ "role": role, "content": strip_think_tag(&text) } + )] + } + _ => vec![json!({ "role": role, "content": content })], + } + }) + .collect(); + + let mut body = json!({ + "model": &model.real_name(), + "messages": messages, + }); + + 
if let Some(v) = model.max_tokens_param() { + if model + .patch() + .and_then(|v| v.get("body").and_then(|v| v.get("max_tokens"))) + == Some(&Value::Null) + { + body["max_completion_tokens"] = v.into(); + } else { + body["max_tokens"] = v.into(); + } + } + if let Some(v) = temperature { + body["temperature"] = v.into(); + } + if let Some(v) = top_p { + body["top_p"] = v.into(); + } + if stream { + body["stream"] = true.into(); + } + if let Some(functions) = functions { + body["tools"] = functions + .iter() + .map(|v| { + json!({ + "type": "function", + "function": v, + }) + }) + .collect(); + } + body +} + +pub fn openai_build_embeddings_body(data: &EmbeddingsData, model: &Model) -> Value { + json!({ + "input": data.texts, + "model": model.real_name() + }) +} + +pub fn openai_extract_chat_completions(data: &Value) -> Result { + let text = data["choices"][0]["message"]["content"] + .as_str() + .unwrap_or_default(); + + let reasoning = data["choices"][0]["message"]["reasoning_content"] + .as_str() + .or_else(|| data["choices"][0]["message"]["reasoning"].as_str()) + .unwrap_or_default() + .trim(); + + let mut tool_calls = vec![]; + if let Some(calls) = data["choices"][0]["message"]["tool_calls"].as_array() { + for call in calls { + if let (Some(name), Some(arguments), Some(id)) = ( + call["function"]["name"].as_str(), + call["function"]["arguments"].as_str(), + call["id"].as_str(), + ) { + let arguments: Value = arguments.parse().with_context(|| { + format!("Tool call '{name}' have non-JSON arguments '{arguments}'") + })?; + tool_calls.push(ToolCall::new( + name.to_string(), + arguments, + Some(id.to_string()), + )); + } + } + }; + + if text.is_empty() && tool_calls.is_empty() { + bail!("Invalid response data: {data}"); + } + let text = if !reasoning.is_empty() { + format!("\n{reasoning}\n\n\n{text}") + } else { + text.to_string() + }; + let output = ChatCompletionsOutput { + text, + tool_calls, + id: data["id"].as_str().map(|v| v.to_string()), + input_tokens: 
data["usage"]["prompt_tokens"].as_u64(), + output_tokens: data["usage"]["completion_tokens"].as_u64(), + }; + Ok(output) +} + +fn normalize_function_id(value: &str) -> Option { + if value.is_empty() { + None + } else { + Some(value.to_string()) + } +} diff --git a/src/client/openai_compatible.rs b/src/client/openai_compatible.rs new file mode 100644 index 0000000..4f77e88 --- /dev/null +++ b/src/client/openai_compatible.rs @@ -0,0 +1,162 @@ +use super::openai::*; +use super::*; + +use anyhow::{Context, Result}; +use reqwest::RequestBuilder; +use serde::Deserialize; +use serde_json::{json, Value}; + +#[derive(Debug, Clone, Deserialize)] +pub struct OpenAICompatibleConfig { + pub name: Option, + pub api_base: Option, + pub api_key: Option, + #[serde(default)] + pub models: Vec, + pub patch: Option, + pub extra: Option, +} + +impl OpenAICompatibleClient { + config_get_fn!(api_base, get_api_base); + config_get_fn!(api_key, get_api_key); + + pub const PROMPTS: [PromptAction<'static>; 0] = []; +} + +impl_client_trait!( + OpenAICompatibleClient, + ( + prepare_chat_completions, + openai_chat_completions, + openai_chat_completions_streaming + ), + (prepare_embeddings, openai_embeddings), + (prepare_rerank, generic_rerank), +); + +fn prepare_chat_completions( + self_: &OpenAICompatibleClient, + data: ChatCompletionsData, +) -> Result { + let api_key = self_.get_api_key().ok(); + let api_base = get_api_base_ext(self_)?; + + let url = format!("{api_base}/chat/completions"); + + let body = openai_build_chat_completions_body(data, &self_.model); + + let mut request_data = RequestData::new(url, body); + + if let Some(api_key) = api_key { + request_data.bearer_auth(api_key); + } + + Ok(request_data) +} + +fn prepare_embeddings( + self_: &OpenAICompatibleClient, + data: &EmbeddingsData, +) -> Result { + let api_key = self_.get_api_key().ok(); + let api_base = get_api_base_ext(self_)?; + + let url = format!("{api_base}/embeddings"); + + let body = openai_build_embeddings_body(data, 
&self_.model);

    let mut request_data = RequestData::new(url, body);

    if let Some(api_key) = api_key {
        request_data.bearer_auth(api_key);
    }

    Ok(request_data)
}

/// Build the rerank request; ernie-flavored clients use the `/rerankers`
/// path, everything else the conventional `/rerank`.
fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result<RequestData> {
    let api_key = self_.get_api_key().ok();
    let api_base = get_api_base_ext(self_)?;

    let url = if self_.name().starts_with("ernie") {
        format!("{api_base}/rerankers")
    } else {
        format!("{api_base}/rerank")
    };

    let body = generic_build_rerank_body(data, &self_.model);

    let mut request_data = RequestData::new(url, body);

    if let Some(api_key) = api_key {
        request_data.bearer_auth(api_key);
    }

    Ok(request_data)
}

/// Resolve the API base: explicit config first, falling back to the builtin
/// OPENAI_COMPATIBLE_PROVIDERS table; a trailing '/' is stripped so callers
/// can append paths safely.
fn get_api_base_ext(self_: &OpenAICompatibleClient) -> Result<String> {
    let api_base = match self_.get_api_base() {
        Ok(v) => v,
        Err(err) => {
            match OPENAI_COMPATIBLE_PROVIDERS
                .into_iter()
                .find_map(|(name, api_base)| {
                    if name == self_.model.client_name() {
                        Some(api_base.to_string())
                    } else {
                        None
                    }
                }) {
                Some(v) => v,
                // No config value and no builtin entry: surface the original error.
                None => return Err(err),
            }
        }
    };
    Ok(api_base.trim_end_matches('/').to_string())
}

/// Send a rerank request and deserialize the response. Providers that return
/// their hits under `data` instead of `results` are normalized first.
pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let mut data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    if data.get("results").is_none() && data.get("data").is_some() {
        if let Some(data_obj) = data.as_object_mut() {
            if let Some(value) = data_obj.remove("data") {
                data_obj.insert("results".to_string(), value);
            }
        }
    }
    let res_body: GenericRerankResBody =
        serde_json::from_value(data).context("Invalid rerank data")?;
    Ok(res_body.results)
}

#[derive(Deserialize)]
pub struct GenericRerankResBody {
    pub results: RerankOutput,
}

/// Build a rerank request body; voyageai names the result cutoff `top_k`
/// where everyone else uses `top_n`.
pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value {
    let RerankData {
        query,
        documents,
        top_n,
    } = data;

    let mut
body = json!({ + "model": model.real_name(), + "query": query, + "documents": documents, + }); + if model.client_name().starts_with("voyageai") { + body["top_k"] = (*top_n).into() + } else { + body["top_n"] = (*top_n).into() + } + body +} diff --git a/src/client/stream.rs b/src/client/stream.rs new file mode 100644 index 0000000..c7d7081 --- /dev/null +++ b/src/client/stream.rs @@ -0,0 +1,296 @@ +use super::{catch_error, ToolCall}; +use crate::utils::AbortSignal; + +use anyhow::{anyhow, bail, Context, Result}; +use futures_util::{Stream, StreamExt}; +use reqwest::RequestBuilder; +use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; +use serde_json::Value; +use tokio::sync::mpsc::UnboundedSender; + +pub struct SseHandler { + sender: UnboundedSender, + abort_signal: AbortSignal, + buffer: String, + tool_calls: Vec, +} + +impl SseHandler { + pub fn new(sender: UnboundedSender, abort_signal: AbortSignal) -> Self { + Self { + sender, + abort_signal, + buffer: String::new(), + tool_calls: Vec::new(), + } + } + + pub fn text(&mut self, text: &str) -> Result<()> { + // debug!("HandleText: {}", text); + if text.is_empty() { + return Ok(()); + } + self.buffer.push_str(text); + let ret = self + .sender + .send(SseEvent::Text(text.to_string())) + .with_context(|| "Failed to send SseEvent:Text"); + if let Err(err) = ret { + if self.abort_signal.aborted() { + return Ok(()); + } + return Err(err); + } + Ok(()) + } + + pub fn done(&mut self) { + // debug!("HandleDone"); + let ret = self.sender.send(SseEvent::Done); + if ret.is_err() { + if self.abort_signal.aborted() { + return; + } + warn!("Failed to send SseEvent:Done"); + } + } + + pub fn tool_call(&mut self, call: ToolCall) -> Result<()> { + // debug!("HandleCall: {:?}", call); + self.tool_calls.push(call); + Ok(()) + } + + pub fn abort(&self) -> AbortSignal { + self.abort_signal.clone() + } + + pub fn tool_calls(&self) -> &[ToolCall] { + &self.tool_calls + } + + pub fn take(self) -> (String, Vec) { 
+ let Self { + buffer, tool_calls, .. + } = self; + (buffer, tool_calls) + } +} + +#[derive(Debug)] +pub enum SseEvent { + Text(String), + Done, +} + +#[derive(Debug)] +pub struct SseMessage { + #[allow(unused)] + pub event: String, + pub data: String, +} + +pub async fn sse_stream(builder: RequestBuilder, mut handle: F) -> Result<()> +where + F: FnMut(SseMessage) -> Result, +{ + let mut es = builder.eventsource()?; + while let Some(event) = es.next().await { + match event { + Ok(Event::Open) => {} + Ok(Event::Message(message)) => { + let message = SseMessage { + event: message.event, + data: message.data, + }; + if handle(message)? { + break; + } + } + Err(err) => { + match err { + EventSourceError::StreamEnded => {} + EventSourceError::InvalidStatusCode(status, res) => { + let text = res.text().await?; + let data: Value = match text.parse() { + Ok(data) => data, + Err(_) => { + bail!( + "Invalid response data: {text} (status: {})", + status.as_u16() + ); + } + }; + catch_error(&data, status.as_u16())?; + } + EventSourceError::InvalidContentType(header_value, res) => { + let text = res.text().await?; + bail!( + "Invalid response event-stream. 
content-type: {}, data: {text}", + header_value.to_str().unwrap_or_default() + ); + } + _ => { + bail!("{}", err); + } + } + es.close(); + } + } + } + Ok(()) +} + +pub async fn json_stream(mut stream: S, mut handle: F) -> Result<()> +where + S: Stream> + Unpin, + F: FnMut(&str) -> Result<()>, + E: std::error::Error, +{ + let mut parser = JsonStreamParser::default(); + let mut unparsed_bytes = vec![]; + while let Some(chunk_bytes) = stream.next().await { + let chunk_bytes = + chunk_bytes.map_err(|err| anyhow!("Failed to read json stream, {err}"))?; + unparsed_bytes.extend(chunk_bytes); + match std::str::from_utf8(&unparsed_bytes) { + Ok(text) => { + parser.process(text, &mut handle)?; + unparsed_bytes.clear(); + } + Err(_) => { + continue; + } + } + } + if !unparsed_bytes.is_empty() { + let text = std::str::from_utf8(&unparsed_bytes)?; + parser.process(text, &mut handle)?; + } + + Ok(()) +} + +#[derive(Debug, Default)] +struct JsonStreamParser { + buffer: Vec, + cursor: usize, + start: Option, + balances: Vec, + quoting: bool, + escape: bool, +} + +impl JsonStreamParser { + fn process(&mut self, text: &str, handle: &mut F) -> Result<()> + where + F: FnMut(&str) -> Result<()>, + { + self.buffer.extend(text.chars()); + + for i in self.cursor..self.buffer.len() { + let ch = self.buffer[i]; + if self.quoting { + if ch == '\\' { + self.escape = !self.escape; + } else { + if !self.escape && ch == '"' { + self.quoting = false; + } + self.escape = false; + } + continue; + } + match ch { + '"' => { + self.quoting = true; + self.escape = false; + } + '{' => { + if self.balances.is_empty() { + self.start = Some(i); + } + self.balances.push(ch); + } + '[' => { + if self.start.is_some() { + self.balances.push(ch); + } + } + '}' => { + self.balances.pop(); + if self.balances.is_empty() { + if let Some(start) = self.start.take() { + let value: String = self.buffer[start..=i].iter().collect(); + handle(&value)?; + } + } + } + ']' => { + self.balances.pop(); + } + _ => {} + } + } + 
self.cursor = self.buffer.len(); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use bytes::Bytes; + use futures_util::stream; + use rand::Rng; + + fn split_chunks(text: &str) -> Vec> { + let mut rng = rand::rng(); + let len = text.len(); + let cut1 = rng.random_range(1..len - 1); + let cut2 = rng.random_range(cut1 + 1..len); + let chunk1 = text.as_bytes()[..cut1].to_vec(); + let chunk2 = text.as_bytes()[cut1..cut2].to_vec(); + let chunk3 = text.as_bytes()[cut2..].to_vec(); + vec![chunk1, chunk2, chunk3] + } + + macro_rules! assert_json_stream { + ($input:expr, $output:expr) => { + let chunks: Vec<_> = split_chunks($input) + .into_iter() + .map(|chunk| Ok::<_, std::convert::Infallible>(Bytes::from(chunk))) + .collect(); + let stream = stream::iter(chunks); + let mut output = vec![]; + let ret = json_stream(stream, |data| { + output.push(data.to_string()); + Ok(()) + }) + .await; + assert!(ret.is_ok()); + assert_eq!($output.replace("\r\n", "\n"), output.join("\n")) + }; + } + + #[tokio::test] + async fn test_json_stream_ndjson() { + let data = r#"{"key": "value"} +{"key": "value2"} +{"key": "value3"}"#; + assert_json_stream!(data, data); + } + + #[tokio::test] + async fn test_json_stream_array() { + let input = r#"[ +{"key": "value"}, +{"key": "value2"}, +{"key": "value3"},"#; + let output = r#"{"key": "value"} +{"key": "value2"} +{"key": "value3"}"#; + assert_json_stream!(input, output); + } +} diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs new file mode 100644 index 0000000..16b55d1 --- /dev/null +++ b/src/client/vertexai.rs @@ -0,0 +1,537 @@ +use super::access_token::*; +use super::claude::*; +use super::openai::*; +use super::*; + +use anyhow::{anyhow, bail, Context, Result}; +use chrono::{Duration, Utc}; +use reqwest::{Client as ReqwestClient, RequestBuilder}; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::{path::PathBuf, str::FromStr}; + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct VertexAIConfig 
{
    pub name: Option<String>,
    pub project_id: Option<String>,
    pub location: Option<String>,
    // Path to an application_default_credentials.json; defaults to the
    // gcloud config dir when absent.
    pub adc_file: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}

impl VertexAIClient {
    config_get_fn!(project_id, get_project_id);
    config_get_fn!(location, get_location);

    pub const PROMPTS: [PromptAction<'static>; 2] = [
        ("project_id", "Project ID", None),
        ("location", "Location", None),
    ];
}

#[async_trait::async_trait]
impl Client for VertexAIClient {
    client_common_fns!();

    /// Non-streaming chat completion. Vertex AI fronts three model families
    /// (Gemini / Claude / Mistral); after refreshing the gcloud access token
    /// the request is dispatched to the matching protocol handler.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => gemini_chat_completions(builder, model).await,
            ModelCategory::Claude => claude_chat_completions(builder, model).await,
            ModelCategory::Mistral => openai_chat_completions(builder, model).await,
        }
    }

    /// Streaming variant of `chat_completions_inner`; same dispatch.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => {
                gemini_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Claude => {
                claude_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Mistral => {
                openai_chat_completions_streaming(builder, handler, model).await
            }
        }
    }

    /// Text embeddings via the `:predict` publisher endpoint.
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<Vec<Vec<f32>>> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let request_data = prepare_embeddings(self, data)?;
        let builder = self.request_builder(client, request_data);
        embeddings(builder, self.model()).await
    }
}

/// Compose the publisher endpoint URL and protocol-specific body for a chat
/// request against Vertex AI.
fn prepare_chat_completions(
    self_: &VertexAIClient,
    data: ChatCompletionsData,
    model_category: &ModelCategory,
) -> Result<RequestData> {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    let access_token = get_access_token(self_.name())?;

    // The "global" location uses the bare API host; regional locations use a
    // region-prefixed host.
    let base_url = if location == "global" {
        format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
    } else {
        format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
    };

    let model_name = self_.model.real_name();

    let url = match model_category {
        ModelCategory::Gemini => {
            let func = match data.stream {
                true => "streamGenerateContent",
                false => "generateContent",
            };
            format!("{base_url}/google/models/{model_name}:{func}")
        }
        ModelCategory::Claude => {
            format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
        }
        ModelCategory::Mistral => {
            let func = match data.stream {
                true => "streamRawPredict",
                false => "rawPredict",
            };
            format!("{base_url}/mistralai/models/{model_name}:{func}")
        }
    };

    let body = match model_category {
        ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self_.model)?,
        ModelCategory::Claude => {
            let mut body = claude_build_chat_completions_body(data, &self_.model)?;
            // Vertex encodes the model in the URL and requires the
            // anthropic_version marker instead of a body-level model field.
            if let Some(body_obj) = body.as_object_mut() {
                body_obj.remove("model");
            }
            body["anthropic_version"] = "vertex-2023-10-16".into();
            body
        }
        ModelCategory::Mistral => {
            let mut body = openai_build_chat_completions_body(data, &self_.model);
            if let Some(body_obj) =
body.as_object_mut() { + body_obj["model"] = strip_model_version(self_.model.real_name()).into(); + } + body + } + }; + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(access_token); + + Ok(request_data) +} + +fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result { + let project_id = self_.get_project_id()?; + let location = self_.get_location()?; + let access_token = get_access_token(self_.name())?; + + let base_url = if location == "global" { + format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers") + } else { + format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers") + }; + let url = format!( + "{base_url}/google/models/{}:predict", + self_.model.real_name() + ); + + let instances: Vec<_> = data.texts.iter().map(|v| json!({"content": v})).collect(); + + let body = json!({ + "instances": instances, + }); + + let mut request_data = RequestData::new(url, body); + + request_data.bearer_auth(access_token); + + Ok(request_data) +} + +pub async fn gemini_chat_completions( + builder: RequestBuilder, + _model: &Model, +) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + debug!("non-stream-data: {data}"); + gemini_extract_chat_completions_text(&data) +} + +pub async fn gemini_chat_completions_streaming( + builder: RequestBuilder, + handler: &mut SseHandler, + _model: &Model, +) -> Result<()> { + let res = builder.send().await?; + let status = res.status(); + if !status.is_success() { + let data: Value = res.json().await?; + catch_error(&data, status.as_u16())?; + } else { + let handle = |value: &str| -> Result<()> { + let data: Value = serde_json::from_str(value)?; + debug!("stream-data: {data}"); + if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() { + for 
(i, part) in parts.iter().enumerate() { + if let Some(text) = part["text"].as_str() { + if i > 0 { + handler.text("\n\n")?; + } + handler.text(text)?; + } else if let (Some(name), Some(args)) = ( + part["functionCall"]["name"].as_str(), + part["functionCall"]["args"].as_object(), + ) { + handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?; + } + } + } else if let Some("SAFETY") = data["promptFeedback"]["blockReason"] + .as_str() + .or_else(|| data["candidates"][0]["finishReason"].as_str()) + { + bail!("Blocked due to safety") + } + + Ok(()) + }; + json_stream(res.bytes_stream(), handle).await?; + } + Ok(()) +} + +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + let output = res_body + .predictions + .into_iter() + .map(|v| v.embeddings.values) + .collect(); + Ok(output) +} + +#[derive(Deserialize)] +struct EmbeddingsResBody { + predictions: Vec, +} + +#[derive(Deserialize)] +struct EmbeddingsResBodyPrediction { + embeddings: EmbeddingsResBodyPredictionEmbeddings, +} + +#[derive(Deserialize)] +struct EmbeddingsResBodyPredictionEmbeddings { + values: Vec, +} + +fn gemini_extract_chat_completions_text(data: &Value) -> Result { + let mut text_parts = vec![]; + let mut tool_calls = vec![]; + if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() { + for part in parts { + if let Some(text) = part["text"].as_str() { + text_parts.push(text); + } + if let (Some(name), Some(args)) = ( + part["functionCall"]["name"].as_str(), + part["functionCall"]["args"].as_object(), + ) { + tool_calls.push(ToolCall::new(name.to_string(), json!(args), None)); + } + } + } + + let text = text_parts.join("\n\n"); + if text.is_empty() && 
tool_calls.is_empty() { + if let Some("SAFETY") = data["promptFeedback"]["blockReason"] + .as_str() + .or_else(|| data["candidates"][0]["finishReason"].as_str()) + { + bail!("Blocked due to safety") + } else { + bail!("Invalid response data: {data}"); + } + } + let output = ChatCompletionsOutput { + text, + tool_calls, + id: None, + input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(), + output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(), + }; + Ok(output) +} + +pub fn gemini_build_chat_completions_body( + data: ChatCompletionsData, + model: &Model, +) -> Result { + let ChatCompletionsData { + mut messages, + temperature, + top_p, + functions, + stream: _, + } = data; + + let system_message = extract_system_message(&mut messages); + + let mut network_image_urls = vec![]; + let contents: Vec = messages + .into_iter() + .flat_map(|message| { + let Message { role, content } = message; + let role = match role { + MessageRole::User => "user", + _ => "model", + }; + match content { + MessageContent::Text(text) => vec![json!({ + "role": role, + "parts": [{ "text": text }] + })], + MessageContent::Array(list) => { + let parts: Vec = list + .into_iter() + .map(|item| match item { + MessageContentPart::Text { text } => json!({"text": text}), + MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => { + if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) { + json!({ "inline_data": { "mime_type": mime_type, "data": data } }) + } else { + network_image_urls.push(url.clone()); + json!({ "url": url }) + } + }, + }) + .collect(); + vec![json!({ "role": role, "parts": parts })] + }, + MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. 
}) => { + let model_parts: Vec = tool_results.iter().map(|tool_result| { + json!({ + "functionCall": { + "name": tool_result.call.name, + "args": tool_result.call.arguments, + } + }) + }).collect(); + let function_parts: Vec = tool_results.into_iter().map(|tool_result| { + json!({ + "functionResponse": { + "name": tool_result.call.name, + "response": { + "name": tool_result.call.name, + "content": tool_result.output, + } + } + }) + }).collect(); + vec![ + json!({ "role": "model", "parts": model_parts }), + json!({ "role": "function", "parts": function_parts }), + ] + } + } + }) + .collect(); + + if !network_image_urls.is_empty() { + bail!( + "The model does not support network images: {:?}", + network_image_urls + ); + } + + let mut body = json!({ "contents": contents, "generationConfig": {} }); + + if let Some(v) = system_message { + body["systemInstruction"] = json!({ "parts": [{"text": v }] }); + } + + if let Some(v) = model.max_tokens_param() { + body["generationConfig"]["maxOutputTokens"] = v.into(); + } + if let Some(v) = temperature { + body["generationConfig"]["temperature"] = v.into(); + } + if let Some(v) = top_p { + body["generationConfig"]["topP"] = v.into(); + } + + if let Some(functions) = functions { + // Gemini doesn't support functions with parameters that have empty properties, so we need to patch it. 
+ let function_declarations: Vec<_> = functions + .into_iter() + .map(|function| { + if function.parameters.is_empty_properties() { + json!({ + "name": function.name, + "description": function.description, + }) + } else { + json!(function) + } + }) + .collect(); + body["tools"] = json!([{ "functionDeclarations": function_declarations }]); + } + + Ok(body) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelCategory { + Gemini, + Claude, + Mistral, +} + +impl FromStr for ModelCategory { + type Err = anyhow::Error; + + fn from_str(s: &str) -> std::result::Result { + if s.starts_with("gemini") { + Ok(ModelCategory::Gemini) + } else if s.starts_with("claude") { + Ok(ModelCategory::Claude) + } else if s.starts_with("mistral") || s.starts_with("codestral") { + Ok(ModelCategory::Mistral) + } else { + unsupported_model!(s) + } + } +} + +pub async fn prepare_gcloud_access_token( + client: &reqwest::Client, + client_name: &str, + adc_file: &Option, +) -> Result<()> { + if !is_valid_access_token(client_name) { + let (token, expires_in) = fetch_access_token(client, adc_file) + .await + .with_context(|| "Failed to fetch access token")?; + let expires_at = Utc::now() + + Duration::try_seconds(expires_in) + .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?; + set_access_token(client_name, token, expires_at.timestamp()) + } + Ok(()) +} + +async fn fetch_access_token( + client: &reqwest::Client, + file: &Option, +) -> Result<(String, i64)> { + let credentials = load_adc(file).await?; + let value: Value = client + .post("https://oauth2.googleapis.com/token") + .json(&credentials) + .send() + .await? 
+ .json() + .await?; + + if let (Some(access_token), Some(expires_in)) = + (value["access_token"].as_str(), value["expires_in"].as_i64()) + { + Ok((access_token.to_string(), expires_in)) + } else if let Some(err_msg) = value["error_description"].as_str() { + bail!("{err_msg}") + } else { + bail!("Invalid response data: {value}") + } +} + +async fn load_adc(file: &Option) -> Result { + let adc_file = file + .as_ref() + .map(PathBuf::from) + .or_else(default_adc_file) + .ok_or_else(|| anyhow!("No application_default_credentials.json"))?; + let data = tokio::fs::read_to_string(adc_file).await?; + let data: Value = serde_json::from_str(&data)?; + if let (Some(client_id), Some(client_secret), Some(refresh_token)) = ( + data["client_id"].as_str(), + data["client_secret"].as_str(), + data["refresh_token"].as_str(), + ) { + Ok(json!({ + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + "grant_type": "refresh_token", + })) + } else { + bail!("Invalid application_default_credentials.json") + } +} + +#[cfg(not(windows))] +fn default_adc_file() -> Option { + let mut path = dirs::home_dir()?; + path.push(".config"); + path.push("gcloud"); + path.push("application_default_credentials.json"); + Some(path) +} + +#[cfg(windows)] +fn default_adc_file() -> Option { + let mut path = dirs::config_dir()?; + path.push("gcloud"); + path.push("application_default_credentials.json"); + Some(path) +} + +fn strip_model_version(name: &str) -> &str { + match name.split_once('@') { + Some((v, _)) => v, + None => name, + } +} diff --git a/src/config/agent.rs b/src/config/agent.rs new file mode 100644 index 0000000..3b35f1d --- /dev/null +++ b/src/config/agent.rs @@ -0,0 +1,570 @@ +use super::*; + +use crate::{ + client::Model, + function::{run_llm_function, Functions}, +}; + +use anyhow::{Context, Result}; +use inquire::{validator::Validation, Text}; +use std::{fs::read_to_string, path::Path}; + +use serde::{Deserialize, Serialize}; + +const 
DEFAULT_AGENT_NAME: &str = "rag"; + +pub type AgentVariables = IndexMap; + +#[derive(Debug, Clone)] +pub struct Agent { + name: String, + config: AgentConfig, + shared_variables: AgentVariables, + session_variables: Option, + shared_dynamic_instructions: Option, + session_dynamic_instructions: Option, + functions: Functions, + rag: Option>, + model: Model, +} + +impl Agent { + pub async fn init( + config: &GlobalConfig, + name: &str, + abort_signal: AbortSignal, + ) -> Result { + let agent_data_dir = Config::agent_data_dir(name); + let loaders = config.read().document_loaders.clone(); + let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME); + let config_path = Config::agent_config_file(name); + let mut agent_config = if config_path.exists() { + AgentConfig::load(&config_path)? + } else { + bail!("Agent config file not found at '{}'", config_path.display()) + }; + let mut functions = Functions::init_agent(name, &agent_config.global_tools)?; + + config.write().functions.clear_mcp_meta_functions(); + let mcp_servers = + (!agent_config.mcp_servers.is_empty()).then(|| agent_config.mcp_servers.join(",")); + let registry = config + .write() + .mcp_registry + .take() + .expect("MCP registry should be initialized"); + let new_mcp_registry = + McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?; + + if !new_mcp_registry.is_empty() { + functions.append_mcp_meta_functions(new_mcp_registry.list_servers()); + } + + config.write().mcp_registry = Some(new_mcp_registry); + agent_config.replace_tools_placeholder(&functions); + + agent_config.load_envs(&config.read()); + + let model = { + let config = config.read(); + match agent_config.model_id.as_ref() { + Some(model_id) => Model::retrieve_model(&config, model_id, ModelType::Chat)?, + None => { + if agent_config.temperature.is_none() { + agent_config.temperature = config.temperature; + } + if agent_config.top_p.is_none() { + agent_config.top_p = config.top_p; + } + config.current_model().clone() + } + 
} + }; + + let rag = if rag_path.exists() { + Some(Arc::new(Rag::load(config, DEFAULT_AGENT_NAME, &rag_path)?)) + } else if !agent_config.documents.is_empty() && !config.read().info_flag { + let mut ans = false; + if *IS_STDOUT_TERMINAL { + ans = Confirm::new("The agent has documents attached, init RAG?") + .with_default(true) + .prompt()?; + } + if ans { + let mut document_paths = vec![]; + for path in &agent_config.documents { + if is_url(path) { + document_paths.push(path.to_string()); + } else if is_loader_protocol(&loaders, path) { + let (protocol, document_path) = path + .split_once(':') + .with_context(|| "Invalid loader protocol path")?; + let resolved_path = resolve_home_dir(document_path); + let new_path = if Path::new(&resolved_path).is_relative() { + safe_join_path(&agent_data_dir, resolved_path) + .ok_or_else(|| anyhow!("Invalid document path: '{path}'"))? + } else { + PathBuf::from(&resolved_path) + }; + document_paths.push(format!("{}:{}", protocol, new_path.display())); + } else if Path::new(&resolve_home_dir(path)).is_relative() { + let new_path = safe_join_path(&agent_data_dir, path) + .ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?; + document_paths.push(new_path.display().to_string()) + } else { + document_paths.push(path.to_string()) + } + } + let rag = + Rag::init(config, "rag", &rag_path, &document_paths, abort_signal).await?; + Some(Arc::new(rag)) + } else { + None + } + } else { + None + }; + + Ok(Self { + name: name.to_string(), + config: agent_config, + shared_variables: Default::default(), + session_variables: None, + shared_dynamic_instructions: None, + session_dynamic_instructions: None, + functions, + rag, + model, + }) + } + + pub fn init_agent_variables( + agent_variables: &[AgentVariable], + no_interaction: bool, + ) -> Result { + let mut output = IndexMap::new(); + if agent_variables.is_empty() { + return Ok(output); + } + let mut printed = false; + let mut unset_variables = vec![]; + for agent_variable in 
agent_variables { + let key = agent_variable.name.clone(); + if let Some(value) = agent_variable.default.clone() { + output.insert(key, value); + continue; + } + if no_interaction { + continue; + } + if *IS_STDOUT_TERMINAL { + if !printed { + println!("⚙ Init agent variables..."); + printed = true; + } + let value = Text::new(&format!( + "{} ({}):", + agent_variable.name, agent_variable.description + )) + .with_validator(|input: &str| { + if input.trim().is_empty() { + Ok(Validation::Invalid("This field is required".into())) + } else { + Ok(Validation::Valid) + } + }) + .prompt()?; + output.insert(key, value); + } else { + unset_variables.push(agent_variable) + } + } + if !unset_variables.is_empty() { + bail!( + "The following agent variables are required:\n{}", + unset_variables + .iter() + .map(|v| format!(" - {}: {}", v.name, v.description)) + .collect::>() + .join("\n") + ) + } + Ok(output) + } + + pub fn export(&self) -> Result { + let mut value = json!({}); + value["name"] = json!(self.name()); + let variables = self.variables(); + if !variables.is_empty() { + value["variables"] = serde_json::to_value(variables)?; + } + value["config"] = json!(self.config); + let mut config = self.config.clone(); + config.instructions = self.interpolated_instructions(); + value["definition"] = json!(config); + value["data_dir"] = Config::agent_data_dir(&self.name) + .display() + .to_string() + .into(); + value["config_file"] = Config::agent_config_file(&self.name) + .display() + .to_string() + .into(); + let data = serde_yaml::to_string(&value)?; + Ok(data) + } + + pub fn banner(&self) -> String { + self.config.banner() + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn functions(&self) -> &Functions { + &self.functions + } + + pub fn rag(&self) -> Option> { + self.rag.clone() + } + + pub fn conversation_starters(&self) -> &[String] { + &self.config.conversation_starters + } + + pub fn interpolated_instructions(&self) -> String { + let mut output = self + 
.session_dynamic_instructions + .clone() + .or_else(|| self.shared_dynamic_instructions.clone()) + .unwrap_or_else(|| self.config.instructions.clone()); + for (k, v) in self.variables() { + output = output.replace(&format!("{{{{{k}}}}}"), v) + } + interpolate_variables(&mut output); + output + } + + pub fn agent_prelude(&self) -> Option<&str> { + self.config.agent_prelude.as_deref() + } + + pub fn variables(&self) -> &AgentVariables { + match &self.session_variables { + Some(variables) => variables, + None => &self.shared_variables, + } + } + + pub fn variable_envs(&self) -> HashMap { + self.variables() + .iter() + .map(|(k, v)| { + ( + format!("LLM_AGENT_VAR_{}", normalize_env_name(k)), + v.clone(), + ) + }) + .collect() + } + + pub fn shared_variables(&self) -> &AgentVariables { + &self.shared_variables + } + + pub fn set_shared_variables(&mut self, shared_variables: AgentVariables) { + self.shared_variables = shared_variables; + } + + pub fn set_session_variables(&mut self, session_variables: AgentVariables) { + self.session_variables = Some(session_variables); + } + + pub fn defined_variables(&self) -> &[AgentVariable] { + &self.config.variables + } + + pub fn exit_session(&mut self) { + self.session_variables = None; + self.session_dynamic_instructions = None; + } + + pub fn is_dynamic_instructions(&self) -> bool { + self.config.dynamic_instructions + } + + pub fn update_shared_dynamic_instructions(&mut self, force: bool) -> Result<()> { + if self.is_dynamic_instructions() && (force || self.shared_dynamic_instructions.is_none()) { + self.shared_dynamic_instructions = Some(self.run_instructions_fn()?); + } + Ok(()) + } + + pub fn update_session_dynamic_instructions(&mut self, value: Option) -> Result<()> { + if self.is_dynamic_instructions() { + let value = match value { + Some(v) => v, + None => self.run_instructions_fn()?, + }; + self.session_dynamic_instructions = Some(value); + } + Ok(()) + } + + fn run_instructions_fn(&self) -> Result { + let value = 
run_llm_function( + self.name().to_string(), + vec!["_instructions".into(), "{}".into()], + self.variable_envs(), + )?; + match value { + Some(v) => Ok(v), + _ => bail!("No return value from '_instructions' function"), + } + } +} + +impl RoleLike for Agent { + fn to_role(&self) -> Role { + let prompt = self.interpolated_instructions(); + let mut role = Role::new("", &prompt); + role.sync(self); + role + } + + fn model(&self) -> &Model { + &self.model + } + + fn temperature(&self) -> Option { + self.config.temperature + } + + fn top_p(&self) -> Option { + self.config.top_p + } + + fn use_tools(&self) -> Option { + self.config.global_tools.clone().join(",").into() + } + + fn use_mcp_servers(&self) -> Option { + self.config.mcp_servers.clone().join(",").into() + } + + fn set_model(&mut self, model: Model) { + self.config.model_id = Some(model.id()); + self.model = model; + } + + fn set_temperature(&mut self, value: Option) { + self.config.temperature = value; + } + + fn set_top_p(&mut self, value: Option) { + self.config.top_p = value; + } + + fn set_use_tools(&mut self, value: Option) { + match value { + Some(tools) => { + let tools = tools + .split(',') + .map(|v| v.trim().to_string()) + .filter(|v| !v.is_empty()) + .collect::>(); + self.config.global_tools = tools; + } + None => { + self.config.global_tools.clear(); + } + } + } + + fn set_use_mcp_servers(&mut self, value: Option) { + match value { + Some(servers) => { + let servers = servers + .split(',') + .map(|v| v.trim().to_string()) + .filter(|v| !v.is_empty()) + .collect::>(); + self.config.mcp_servers = servers; + } + None => { + self.config.mcp_servers.clear(); + } + } + } +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AgentConfig { + pub name: String, + #[serde(rename(serialize = "model", deserialize = "model"))] + pub model_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: 
Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_prelude: Option, + #[serde(default)] + pub description: String, + #[serde(default)] + pub version: String, + #[serde(default)] + pub mcp_servers: Vec, + #[serde(default)] + pub global_tools: Vec, + #[serde(default)] + pub instructions: String, + #[serde(default)] + pub dynamic_instructions: bool, + #[serde(default)] + pub variables: Vec, + #[serde(default)] + pub conversation_starters: Vec, + #[serde(default)] + pub documents: Vec, +} + +impl AgentConfig { + pub fn load(path: &Path) -> Result { + let contents = read_to_string(path) + .with_context(|| format!("Failed to read agent config file at '{}'", path.display()))?; + let agent_config: Self = serde_yaml::from_str(&contents) + .with_context(|| format!("Failed to load agent config at '{}'", path.display()))?; + + Ok(agent_config) + } + + fn load_envs(&mut self, config: &Config) { + let name = &self.name; + let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}")); + + if self.agent_prelude.is_none() { + self.agent_prelude = config.agent_prelude.clone(); + } + + if let Some(v) = read_env_value::(&with_prefix("model")) { + self.model_id = v; + } + if let Some(v) = read_env_value::(&with_prefix("temperature")) { + self.temperature = v; + } + if let Some(v) = read_env_value::(&with_prefix("top_p")) { + self.top_p = v; + } + if let Some(v) = read_env_value::(&with_prefix("agent_prelude")) { + self.agent_prelude = v; + } + if let Ok(v) = env::var(with_prefix("variables")) { + if let Ok(v) = serde_json::from_str(&v) { + self.variables = v; + } + } + } + + fn banner(&self) -> String { + let AgentConfig { + name, + description, + version, + conversation_starters, + .. 
+ } = self; + let starters = if conversation_starters.is_empty() { + String::new() + } else { + let starters = conversation_starters + .iter() + .map(|v| format!("- {v}")) + .collect::>() + .join("\n"); + format!( + r#" + +## Conversation Starters +{starters}"# + ) + }; + format!( + r#"# {name} {version} +{description}{starters}"# + ) + } + + fn replace_tools_placeholder(&mut self, functions: &Functions) { + let tools_placeholder: &str = "{{__tools__}}"; + if self.instructions.contains(tools_placeholder) { + let tools = functions + .declarations() + .iter() + .enumerate() + .map(|(i, v)| { + let description = match v.description.split_once('\n') { + Some((v, _)) => v, + None => &v.description, + }; + format!("{}. {}: {description}", i + 1, v.name) + }) + .collect::>() + .join("\n"); + self.instructions = self.instructions.replace(tools_placeholder, &tools); + } + } +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AgentVariable { + pub name: String, + pub description: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_deserializing, default)] + pub value: String, +} + +pub fn list_agents() -> Vec { + let agents_file = Config::config_dir().join("agents.txt"); + let contents = match read_to_string(agents_file) { + Ok(v) => v, + Err(_) => return vec![], + }; + contents + .split('\n') + .filter_map(|line| { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + None + } else { + Some(line.to_string()) + } + }) + .collect() +} + +pub fn complete_agent_variables(agent_name: &str) -> Vec<(String, Option)> { + let config_path = Config::agent_config_file(agent_name); + if !config_path.exists() { + return vec![]; + } + let Ok(config) = AgentConfig::load(&config_path) else { + return vec![]; + }; + config + .variables + .iter() + .map(|v| { + let description = match &v.default { + Some(default) => format!("{} [default: {default}]", v.description), + None => v.description.clone(), + }; 
+ (format!("{}=", v.name), Some(description)) + }) + .collect() +} diff --git a/src/config/input.rs b/src/config/input.rs new file mode 100644 index 0000000..c35ff4c --- /dev/null +++ b/src/config/input.rs @@ -0,0 +1,545 @@ +use super::*; + +use crate::client::{ + init_client, patch_messages, ChatCompletionsData, Client, ImageUrl, Message, MessageContent, + MessageContentPart, MessageContentToolCalls, MessageRole, Model, +}; +use crate::function::ToolResult; +use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal}; + +use anyhow::{bail, Context, Result}; +use indexmap::IndexSet; +use std::{collections::HashMap, fs::File, io::Read}; +use unicode_width::{UnicodeWidthChar, UnicodeWidthStr}; + +const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"]; +const SUMMARY_MAX_WIDTH: usize = 80; + +#[derive(Debug, Clone)] +pub struct Input { + config: GlobalConfig, + text: String, + raw: (String, Vec), + patched_text: Option, + last_reply: Option, + continue_output: Option, + regenerate: bool, + medias: Vec, + data_urls: HashMap, + tool_calls: Option, + role: Role, + rag_name: Option, + with_session: bool, + with_agent: bool, +} + +impl Input { + pub fn from_str(config: &GlobalConfig, text: &str, role: Option) -> Self { + let (role, with_session, with_agent) = resolve_role(&config.read(), role); + Self { + config: config.clone(), + text: text.to_string(), + raw: (text.to_string(), vec![]), + patched_text: None, + last_reply: None, + continue_output: None, + regenerate: false, + medias: Default::default(), + data_urls: Default::default(), + tool_calls: None, + role, + rag_name: None, + with_session, + with_agent, + } + } + + pub async fn from_files( + config: &GlobalConfig, + raw_text: &str, + paths: Vec, + role: Option, + ) -> Result { + let loaders = config.read().document_loaders.clone(); + let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) = + resolve_paths(&loaders, paths)?; + let mut last_reply = None; 
+ let (documents, medias, data_urls) = load_documents( + &loaders, + local_paths, + remote_urls, + external_cmds, + protocol_paths, + ) + .await + .context("Failed to load files")?; + let mut texts = vec![]; + if !raw_text.is_empty() { + texts.push(raw_text.to_string()); + }; + if with_last_reply { + if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() { + if !output.is_empty() { + last_reply = Some(output.clone()) + } else if let Some(v) = input.last_reply.as_ref() { + last_reply = Some(v.clone()); + } + if let Some(v) = last_reply.clone() { + texts.push(format!("\n{v}")); + } + } + if last_reply.is_none() && documents.is_empty() && medias.is_empty() { + bail!("No last reply found"); + } + } + let documents_len = documents.len(); + for (kind, path, contents) in documents { + if documents_len == 1 && raw_text.is_empty() { + texts.push(format!("\n{contents}")); + } else { + texts.push(format!( + "\n============ {kind}: {path} ============\n{contents}" + )); + } + } + let (role, with_session, with_agent) = resolve_role(&config.read(), role); + Ok(Self { + config: config.clone(), + text: texts.join("\n"), + raw: (raw_text.to_string(), raw_paths), + patched_text: None, + last_reply, + continue_output: None, + regenerate: false, + medias, + data_urls, + tool_calls: Default::default(), + role, + rag_name: None, + with_session, + with_agent, + }) + } + + pub async fn from_files_with_spinner( + config: &GlobalConfig, + raw_text: &str, + paths: Vec, + role: Option, + abort_signal: AbortSignal, + ) -> Result { + abortable_run_with_spinner( + Input::from_files(config, raw_text, paths, role), + "Loading files", + abort_signal, + ) + .await + } + + pub fn is_empty(&self) -> bool { + self.text.is_empty() && self.medias.is_empty() + } + + pub fn data_urls(&self) -> HashMap { + self.data_urls.clone() + } + + pub fn tool_calls(&self) -> &Option { + &self.tool_calls + } + + pub fn text(&self) -> String { + match self.patched_text.clone() { + 
Some(text) => text, + None => self.text.clone(), + } + } + + pub fn clear_patch(&mut self) { + self.patched_text = None; + } + + pub fn set_text(&mut self, text: String) { + self.text = text; + } + + pub fn stream(&self) -> bool { + self.config.read().stream && !self.role().model().no_stream() + } + + pub fn continue_output(&self) -> Option<&str> { + self.continue_output.as_deref() + } + + pub fn set_continue_output(&mut self, output: &str) { + let output = match &self.continue_output { + Some(v) => format!("{v}{output}"), + None => output.to_string(), + }; + self.continue_output = Some(output); + } + + pub fn regenerate(&self) -> bool { + self.regenerate + } + + pub fn set_regenerate(&mut self) { + let role = self.config.read().extract_role(); + if role.name() == self.role().name() { + self.role = role; + } + self.regenerate = true; + self.tool_calls = None; + } + + pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> { + if self.text.is_empty() { + return Ok(()); + } + let rag = self.config.read().rag.clone(); + if let Some(rag) = rag { + let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?; + self.patched_text = Some(result); + self.rag_name = Some(rag.name().to_string()); + } + Ok(()) + } + + pub fn rag_name(&self) -> Option<&str> { + self.rag_name.as_deref() + } + + pub fn merge_tool_results(mut self, output: String, tool_results: Vec) -> Self { + match self.tool_calls.as_mut() { + Some(exist_tool_results) => { + exist_tool_results.merge(tool_results, output); + } + None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)), + } + self + } + + pub fn create_client(&self) -> Result> { + init_client(&self.config, Some(self.role().model().clone())) + } + + pub async fn fetch_chat_text(&self) -> Result { + let client = self.create_client()?; + let text = client.chat_completions(self.clone()).await?.text; + let text = strip_think_tag(&text).to_string(); + Ok(text) + } + + pub fn 
prepare_completion_data( + &self, + model: &Model, + stream: bool, + ) -> Result { + let mut messages = self.build_messages()?; + patch_messages(&mut messages, model); + model.guard_max_input_tokens(&messages)?; + let (temperature, top_p) = (self.role().temperature(), self.role().top_p()); + let functions = self.config.read().select_functions(self.role()); + if let Some(vec) = &functions { + for def in vec { + debug!("Function definition: {:?}", def.name); + } + } + Ok(ChatCompletionsData { + messages, + temperature, + top_p, + functions, + stream, + }) + } + + pub fn build_messages(&self) -> Result> { + let mut messages = if let Some(session) = self.session(&self.config.read().session) { + session.build_messages(self) + } else { + self.role().build_messages(self) + }; + if let Some(tool_calls) = &self.tool_calls { + messages.push(Message::new( + MessageRole::Assistant, + MessageContent::ToolCalls(tool_calls.clone()), + )) + } + Ok(messages) + } + + pub fn echo_messages(&self) -> String { + if let Some(session) = self.session(&self.config.read().session) { + session.echo_messages(self) + } else { + self.role().echo_messages(self) + } + } + + pub fn role(&self) -> &Role { + &self.role + } + + pub fn session<'a>(&self, session: &'a Option) -> Option<&'a Session> { + if self.with_session { + session.as_ref() + } else { + None + } + } + + pub fn session_mut<'a>(&self, session: &'a mut Option) -> Option<&'a mut Session> { + if self.with_session { + session.as_mut() + } else { + None + } + } + + pub fn with_agent(&self) -> bool { + self.with_agent + } + + pub fn summary(&self) -> String { + let text: String = self + .text + .trim() + .chars() + .map(|c| if c.is_control() { ' ' } else { c }) + .collect(); + if text.width_cjk() > SUMMARY_MAX_WIDTH { + let mut sum_width = 0; + let mut chars = vec![]; + for c in text.chars() { + sum_width += c.width_cjk().unwrap_or(1); + if sum_width > SUMMARY_MAX_WIDTH - 3 { + chars.extend(['.', '.', '.']); + break; + } + chars.push(c); + } 
+ chars.into_iter().collect() + } else { + text + } + } + + pub fn raw(&self) -> String { + let (text, files) = &self.raw; + let mut segments = files.to_vec(); + if !segments.is_empty() { + segments.insert(0, ".file".into()); + } + if !text.is_empty() { + if !segments.is_empty() { + segments.push("--".into()); + } + segments.push(text.clone()); + } + segments.join(" ") + } + + pub fn render(&self) -> String { + let text = self.text(); + if self.medias.is_empty() { + return text; + } + let tail_text = if text.is_empty() { + String::new() + } else { + format!(" -- {text}") + }; + let files: Vec = self + .medias + .iter() + .cloned() + .map(|url| resolve_data_url(&self.data_urls, url)) + .collect(); + format!(".file {}{}", files.join(" "), tail_text) + } + + pub fn message_content(&self) -> MessageContent { + if self.medias.is_empty() { + MessageContent::Text(self.text()) + } else { + let mut list: Vec = self + .medias + .iter() + .cloned() + .map(|url| MessageContentPart::ImageUrl { + image_url: ImageUrl { url }, + }) + .collect(); + if !self.text.is_empty() { + list.insert(0, MessageContentPart::Text { text: self.text() }); + } + MessageContent::Array(list) + } + } +} + +fn resolve_role(config: &Config, role: Option) -> (Role, bool, bool) { + match role { + Some(v) => (v, false, false), + None => ( + config.extract_role(), + config.session.is_some(), + config.agent.is_some(), + ), + } +} + +type ResolvePathsOutput = ( + Vec, + Vec, + Vec, + Vec, + Vec, + bool, +); + +fn resolve_paths( + loaders: &HashMap, + paths: Vec, +) -> Result { + let mut raw_paths = IndexSet::new(); + let mut local_paths = IndexSet::new(); + let mut remote_urls = IndexSet::new(); + let mut external_cmds = IndexSet::new(); + let mut protocol_paths = IndexSet::new(); + let mut with_last_reply = false; + for path in paths { + if path == "%%" { + with_last_reply = true; + raw_paths.insert(path); + } else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') { + 
external_cmds.insert(path[1..path.len() - 1].to_string()); + raw_paths.insert(path); + } else if is_url(&path) { + if path.strip_suffix("**").is_some() { + bail!("Invalid website '{path}'"); + } + remote_urls.insert(path.clone()); + raw_paths.insert(path); + } else if is_loader_protocol(loaders, &path) { + protocol_paths.insert(path.clone()); + raw_paths.insert(path); + } else { + let resolved_path = resolve_home_dir(&path); + let absolute_path = to_absolute_path(&resolved_path) + .with_context(|| format!("Invalid path '{path}'"))?; + local_paths.insert(resolved_path); + raw_paths.insert(absolute_path); + } + } + Ok(( + raw_paths.into_iter().collect(), + local_paths.into_iter().collect(), + remote_urls.into_iter().collect(), + external_cmds.into_iter().collect(), + protocol_paths.into_iter().collect(), + with_last_reply, + )) +} + +async fn load_documents( + loaders: &HashMap, + local_paths: Vec, + remote_urls: Vec, + external_cmds: Vec, + protocol_paths: Vec, +) -> Result<( + Vec<(&'static str, String, String)>, + Vec, + HashMap, +)> { + let mut files = vec![]; + let mut medias = vec![]; + let mut data_urls = HashMap::new(); + + for cmd in external_cmds { + let output = duct::cmd(&SHELL.cmd, &[&SHELL.arg, &cmd]) + .stderr_to_stdout() + .unchecked() + .read() + .unwrap_or_else(|err| err.to_string()); + files.push(("CMD", cmd, output)); + } + + let local_files = expand_glob_paths(&local_paths, true).await?; + for file_path in local_files { + if is_image(&file_path) { + let contents = read_media_to_data_url(&file_path) + .with_context(|| format!("Unable to read media '{file_path}'"))?; + data_urls.insert(sha256(&contents), file_path); + medias.push(contents) + } else { + let document = load_file(loaders, &file_path) + .await + .with_context(|| format!("Unable to read file '{file_path}'"))?; + files.push(("FILE", file_path, document.contents)); + } + } + + for file_url in remote_urls { + let (contents, extension) = fetch_with_loaders(loaders, &file_url, true) + .await 
+ .with_context(|| format!("Failed to load url '{file_url}'"))?; + if extension == MEDIA_URL_EXTENSION { + data_urls.insert(sha256(&contents), file_url); + medias.push(contents) + } else { + files.push(("URL", file_url, contents)); + } + } + + for protocol_path in protocol_paths { + let documents = load_protocol_path(loaders, &protocol_path) + .with_context(|| format!("Failed to load from '{protocol_path}'"))?; + files.extend( + documents + .into_iter() + .map(|document| ("FROM", document.path, document.contents)), + ); + } + + Ok((files, medias, data_urls)) +} + +pub fn resolve_data_url(data_urls: &HashMap, data_url: String) -> String { + if data_url.starts_with("data:") { + let hash = sha256(&data_url); + if let Some(path) = data_urls.get(&hash) { + return path.to_string(); + } + data_url + } else { + data_url + } +} + +fn is_image(path: &str) -> bool { + get_patch_extension(path) + .map(|v| IMAGE_EXTS.contains(&v.as_str())) + .unwrap_or_default() +} + +fn read_media_to_data_url(image_path: &str) -> Result { + let extension = get_patch_extension(image_path).unwrap_or_default(); + let mime_type = match extension.as_str() { + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + "webp" => "image/webp", + "gif" => "image/gif", + _ => bail!("Unexpected media type"), + }; + let mut file = File::open(image_path)?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer)?; + + let encoded_image = base64_encode(buffer); + let data_url = format!("data:{mime_type};base64,{encoded_image}"); + + Ok(data_url) +} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..db73c60 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,3034 @@ +mod agent; +mod input; +mod role; +mod session; + +pub use self::agent::{complete_agent_variables, list_agents, Agent, AgentVariables}; +pub use self::input::Input; +pub use self::role::{ + Role, RoleLike, CODE_ROLE, CREATE_TITLE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, +}; +use self::session::Session; +use 
mem::take; + +use crate::client::{ + create_client_config, list_client_types, list_models, ClientConfig, MessageContentToolCalls, + Model, ModelType, ProviderModels, OPENAI_COMPATIBLE_PROVIDERS, +}; +use crate::function::{FunctionDeclaration, Functions, ToolResult}; +use crate::rag::Rag; +use crate::render::{MarkdownRender, RenderOptions}; +use crate::repl::{run_repl_command, split_args_text}; +use crate::utils::*; + +use crate::mcp::{ + McpRegistry, MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX, +}; +use anyhow::{anyhow, bail, Context, Result}; +use indexmap::IndexMap; +use inquire::{list_option::ListOption, validator::Validation, Confirm, MultiSelect, Select, Text}; +use log::LevelFilter; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::{HashMap, HashSet}; +use std::{ + env, + fs::{ + create_dir_all, read_dir, read_to_string, remove_dir_all, remove_file, File, OpenOptions, + }, + io::Write, + mem, + path::{Path, PathBuf}, + process, + sync::{Arc, OnceLock}, +}; +use syntect::highlighting::ThemeSet; +use terminal_colorsaurus::{color_scheme, ColorScheme, QueryOptions}; +use tokio::runtime::Handle; + +pub const TEMP_ROLE_NAME: &str = "temp"; +pub const TEMP_RAG_NAME: &str = "temp"; +pub const TEMP_SESSION_NAME: &str = "temp"; + +/// Monokai Extended +const DARK_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bin"); +const LIGHT_THEME: &[u8] = include_bytes!("../../assets/monokai-extended-light.theme.bin"); + +const CONFIG_FILE_NAME: &str = "config.yaml"; +const ROLES_DIR_NAME: &str = "roles"; +const MACROS_DIR_NAME: &str = "macros"; +const ENV_FILE_NAME: &str = ".env"; +const MESSAGES_FILE_NAME: &str = "messages.md"; +const SESSIONS_DIR_NAME: &str = "sessions"; +const RAGS_DIR_NAME: &str = "rags"; +const FUNCTIONS_DIR_NAME: &str = "functions"; +const FUNCTIONS_BIN_DIR_NAME: &str = "bin"; +const AGENTS_DIR_NAME: &str = "agents"; +const GLOBAL_TOOLS_DIR_NAME: 
&str = "tools"; +const GLOBAL_TOOLS_FILE_NAME: &str = "tools.txt"; +const MCP_FILE_NAME: &str = "mcp.json"; + +const CLIENTS_FIELD: &str = "clients"; + +const SERVE_ADDR: &str = "127.0.0.1:8000"; + +const SYNC_MODELS_URL: &str = + "https://raw.githubusercontent.com/Dark-Alex-17/loki/refs/heads/main/models.yaml"; + +const SUMMARIZE_PROMPT: &str = + "Summarize the discussion briefly in 200 words or less to use as a prompt for future context."; +const SUMMARY_PROMPT: &str = "This is a summary of the chat history as a recap: "; + +const RAG_TEMPLATE: &str = r#"Answer the query based on the context while respecting the rules. (user query, some textual context and rules, all inside xml tags) + + +__CONTEXT__ + + + +- If you don't know, just say so. +- If you are not sure, ask for clarification. +- Answer in the same language as the user query. +- If the context appears unreadable or of poor quality, tell the user then answer as best as you can. +- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge. +- Answer directly and without using xml tags. 
+ + + +__INPUT__ +"#; + +const LEFT_PROMPT: &str = "{color.green}{?session {?agent {agent}>}{session}{?role /}}{!session {?agent {agent}>}}{role}{?rag @{rag}}{color.cyan}{?session )}{!session >}{color.reset} "; +const RIGHT_PROMPT: &str = "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"; + +static EDITOR: OnceLock> = OnceLock::new(); + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct Config { + #[serde(rename(serialize = "model", deserialize = "model"))] + #[serde(default)] + pub model_id: String, + pub temperature: Option, + pub top_p: Option, + + pub dry_run: bool, + pub stream: bool, + pub save: bool, + pub keybindings: String, + pub editor: Option, + pub wrap: Option, + pub wrap_code: bool, + + pub function_calling: bool, + pub mapping_tools: IndexMap, + pub use_tools: Option, + + pub mcp_servers: bool, + pub mapping_mcp_servers: IndexMap, + pub use_mcp_servers: Option, + + pub repl_prelude: Option, + pub cmd_prelude: Option, + pub agent_prelude: Option, + + pub save_session: Option, + pub compress_threshold: usize, + pub summarize_prompt: Option, + pub summary_prompt: Option, + + pub rag_embedding_model: Option, + pub rag_reranker_model: Option, + pub rag_top_k: usize, + pub rag_chunk_size: Option, + pub rag_chunk_overlap: Option, + pub rag_template: Option, + + #[serde(default)] + pub document_loaders: HashMap, + + pub highlight: bool, + pub theme: Option, + pub left_prompt: Option, + pub right_prompt: Option, + + pub serve_addr: Option, + pub user_agent: Option, + pub save_shell_history: bool, + pub sync_models_url: Option, + + pub clients: Vec, + + #[serde(skip)] + pub macro_flag: bool, + #[serde(skip)] + pub info_flag: bool, + #[serde(skip)] + pub agent_variables: Option, + + #[serde(skip)] + pub model: Model, + #[serde(skip)] + pub functions: Functions, + #[serde(skip)] + pub mcp_registry: Option, + #[serde(skip)] + pub working_mode: WorkingMode, + 
#[serde(skip)] + pub last_message: Option, + + #[serde(skip)] + pub role: Option, + #[serde(skip)] + pub session: Option, + #[serde(skip)] + pub rag: Option>, + #[serde(skip)] + pub agent: Option, +} + +impl Default for Config { + fn default() -> Self { + Self { + model_id: Default::default(), + temperature: None, + top_p: None, + + dry_run: false, + stream: true, + save: false, + keybindings: "emacs".into(), + editor: None, + wrap: None, + wrap_code: false, + + function_calling: true, + mapping_tools: Default::default(), + use_tools: None, + + mcp_servers: true, + mapping_mcp_servers: Default::default(), + use_mcp_servers: None, + + repl_prelude: None, + cmd_prelude: None, + agent_prelude: None, + + save_session: None, + compress_threshold: 4000, + summarize_prompt: None, + summary_prompt: None, + + rag_embedding_model: None, + rag_reranker_model: None, + rag_top_k: 5, + rag_chunk_size: None, + rag_chunk_overlap: None, + rag_template: None, + + document_loaders: Default::default(), + + highlight: true, + theme: None, + left_prompt: None, + right_prompt: None, + + serve_addr: None, + user_agent: None, + save_shell_history: true, + sync_models_url: None, + + clients: vec![], + + macro_flag: false, + info_flag: false, + agent_variables: None, + + model: Default::default(), + functions: Default::default(), + mcp_registry: Default::default(), + working_mode: WorkingMode::Cmd, + last_message: None, + + role: None, + session: None, + rag: None, + agent: None, + } + } +} + +pub type GlobalConfig = Arc>; + +impl Config { + pub fn init_bare() -> Result { + let h = Handle::current(); + tokio::task::block_in_place(|| { + h.block_on(Self::init( + WorkingMode::Cmd, + true, + false, + None, + create_abort_signal(), + )) + }) + } + + pub async fn init( + working_mode: WorkingMode, + info_flag: bool, + start_mcp_servers: bool, + log_path: Option, + abort_signal: AbortSignal, + ) -> Result { + let config_path = Self::config_file(); + let mut config = if !config_path.exists() { + 
match env::var(get_env_name("provider")) + .ok() + .or_else(|| env::var(get_env_name("platform")).ok()) + { + Some(v) => Self::load_dynamic(&v)?, + None => { + if *IS_STDOUT_TERMINAL { + create_config_file(&config_path).await?; + } + Self::load_from_file(&config_path)? + } + } + } else { + Self::load_from_file(&config_path)? + }; + + config.working_mode = working_mode; + config.info_flag = info_flag; + + let setup = async |config: &mut Self| -> Result<()> { + config.load_envs(); + + if let Some(wrap) = config.wrap.clone() { + config.set_wrap(&wrap)?; + } + + config.load_functions()?; + config + .load_mcp_servers(log_path, start_mcp_servers, abort_signal) + .await?; + + config.setup_model()?; + config.setup_document_loaders(); + config.setup_user_agent(); + Ok(()) + }; + let ret = setup(&mut config).await; + if !info_flag { + ret?; + } + Ok(config) + } + + pub fn config_dir() -> PathBuf { + if let Ok(v) = env::var(get_env_name("config_dir")) { + PathBuf::from(v) + } else if let Ok(v) = env::var("XDG_CONFIG_HOME") { + PathBuf::from(v).join(env!("CARGO_CRATE_NAME")) + } else { + let dir = dirs::config_dir().expect("No user's config directory"); + dir.join(env!("CARGO_CRATE_NAME")) + } + } + + pub fn local_path(name: &str) -> PathBuf { + Self::config_dir().join(name) + } + + pub fn cache_path() -> PathBuf { + let base_dir = dirs::cache_dir().unwrap_or_else(env::temp_dir); + + base_dir.join(env!("CARGO_CRATE_NAME")) + } + + pub fn log_path() -> PathBuf { + Config::cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME"))) + } + + pub fn config_file() -> PathBuf { + match env::var(get_env_name("config_file")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(CONFIG_FILE_NAME), + } + } + + pub fn roles_dir() -> PathBuf { + match env::var(get_env_name("roles_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(ROLES_DIR_NAME), + } + } + + pub fn role_file(name: &str) -> PathBuf { + Self::roles_dir().join(format!("{name}.md")) + } 
+ + pub fn macros_dir() -> PathBuf { + match env::var(get_env_name("macros_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(MACROS_DIR_NAME), + } + } + + pub fn macro_file(name: &str) -> PathBuf { + Self::macros_dir().join(format!("{name}.yaml")) + } + + pub fn env_file() -> PathBuf { + match env::var(get_env_name("env_file")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(ENV_FILE_NAME), + } + } + + pub fn messages_file(&self) -> PathBuf { + match &self.agent { + None => match env::var(get_env_name("messages_file")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::cache_path().join(MESSAGES_FILE_NAME), + }, + Some(agent) => Self::cache_path() + .join(AGENTS_DIR_NAME) + .join(agent.name()) + .join(MESSAGES_FILE_NAME), + } + } + + pub fn sessions_dir(&self) -> PathBuf { + match &self.agent { + None => match env::var(get_env_name("sessions_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(SESSIONS_DIR_NAME), + }, + Some(agent) => Self::agent_data_dir(agent.name()).join(SESSIONS_DIR_NAME), + } + } + + pub fn rags_dir() -> PathBuf { + match env::var(get_env_name("rags_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(RAGS_DIR_NAME), + } + } + + pub fn functions_dir() -> PathBuf { + match env::var(get_env_name("functions_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(FUNCTIONS_DIR_NAME), + } + } + + pub fn functions_bin_dir() -> PathBuf { + Self::functions_dir().join(FUNCTIONS_BIN_DIR_NAME) + } + + pub fn mcp_config_file() -> PathBuf { + Self::functions_dir().join(MCP_FILE_NAME) + } + + pub fn global_tools_file() -> PathBuf { + Self::functions_dir().join(GLOBAL_TOOLS_FILE_NAME) + } + + pub fn global_tools_dir() -> PathBuf { + Self::functions_dir().join(GLOBAL_TOOLS_DIR_NAME) + } + + pub fn session_file(&self, name: &str) -> PathBuf { + match name.split_once("/") { + Some((dir, name)) => 
self.sessions_dir().join(dir).join(format!("{name}.yaml")), + None => self.sessions_dir().join(format!("{name}.yaml")), + } + } + + pub fn rag_file(&self, name: &str) -> PathBuf { + match &self.agent { + Some(agent) => Self::agent_rag_file(agent.name(), name), + None => Self::rags_dir().join(format!("{name}.yaml")), + } + } + + pub fn agents_data_dir() -> PathBuf { + Self::local_path(AGENTS_DIR_NAME) + } + + pub fn agent_data_dir(name: &str) -> PathBuf { + match env::var(format!("{}_DATA_DIR", normalize_env_name(name))) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::agents_data_dir().join(name), + } + } + + pub fn agent_config_file(name: &str) -> PathBuf { + match env::var(format!("{}_CONFIG_FILE", normalize_env_name(name))) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::agent_data_dir(name).join(CONFIG_FILE_NAME), + } + } + + pub fn agent_bin_dir(name: &str) -> PathBuf { + Self::agent_data_dir(name).join(FUNCTIONS_BIN_DIR_NAME) + } + + pub fn agent_rag_file(agent_name: &str, rag_name: &str) -> PathBuf { + Self::agent_data_dir(agent_name).join(format!("{rag_name}.yaml")) + } + + pub fn agent_functions_file(name: &str) -> Result { + let allowed = ["tools.sh", "tools.py", "tools.js"]; + + for entry in read_dir(Self::agent_data_dir(name))? 
{ + let entry = entry?; + if let Some(file) = entry.file_name().to_str() { + if allowed.contains(&file) { + return Ok(entry.path()); + } + } + } + + Err(anyhow!( + "No tools script found in agent functions directory" + )) + } + + pub fn models_override_file() -> PathBuf { + Self::local_path("models-override.yaml") + } + + pub fn state(&self) -> StateFlags { + let mut flags = StateFlags::empty(); + if let Some(session) = &self.session { + if session.is_empty() { + flags |= StateFlags::SESSION_EMPTY; + } else { + flags |= StateFlags::SESSION; + } + if session.role_name().is_some() { + flags |= StateFlags::ROLE; + } + } else if self.role.is_some() { + flags |= StateFlags::ROLE; + } + if self.agent.is_some() { + flags |= StateFlags::AGENT; + } + if self.rag.is_some() { + flags |= StateFlags::RAG; + } + flags + } + + pub fn serve_addr(&self) -> String { + self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into()) + } + + pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option)> { + let log_level = env::var(get_env_name("log_level")) + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(match cfg!(debug_assertions) { + true => LevelFilter::Debug, + false => { + if is_serve { + LevelFilter::Off + } else { + LevelFilter::Info + } + } + }); + let log_path = match env::var(get_env_name("log_path")) { + Ok(v) => Some(PathBuf::from(v)), + Err(_) => match is_serve { + true => None, + false => Some(Config::log_path()), + }, + }; + Ok((log_level, log_path)) + } + + pub fn edit_config(&self) -> Result<()> { + let config_path = Self::config_file(); + let editor = self.editor()?; + edit_file(&editor, &config_path)?; + println!( + "NOTE: Remember to restart {} if there are changes made to '{}'", + env!("CARGO_CRATE_NAME"), + config_path.display(), + ); + Ok(()) + } + + pub fn current_model(&self) -> &Model { + if let Some(session) = self.session.as_ref() { + session.model() + } else if let Some(agent) = self.agent.as_ref() { + agent.model() + } else if let Some(role) = 
self.role.as_ref() { + role.model() + } else { + &self.model + } + } + + pub fn role_like_mut(&mut self) -> Option<&mut dyn RoleLike> { + if let Some(session) = self.session.as_mut() { + Some(session) + } else if let Some(agent) = self.agent.as_mut() { + Some(agent) + } else if let Some(role) = self.role.as_mut() { + Some(role) + } else { + None + } + } + + pub fn extract_role(&self) -> Role { + if let Some(session) = self.session.as_ref() { + session.to_role() + } else if let Some(agent) = self.agent.as_ref() { + agent.to_role() + } else if let Some(role) = self.role.as_ref() { + role.clone() + } else { + let mut role = Role::default(); + role.batch_set( + &self.model, + self.temperature, + self.top_p, + self.use_tools.clone(), + self.use_mcp_servers.clone(), + ); + role + } + } + + pub fn info(&self) -> Result { + if let Some(agent) = &self.agent { + let output = agent.export()?; + if let Some(session) = &self.session { + let session = session + .export()? + .split('\n') + .map(|v| format!(" {v}")) + .collect::>() + .join("\n"); + Ok(format!("{output}session:\n{session}")) + } else { + Ok(output) + } + } else if let Some(session) = &self.session { + session.export() + } else if let Some(role) = &self.role { + Ok(role.export()) + } else if let Some(rag) = &self.rag { + rag.export() + } else { + self.sysinfo() + } + } + + pub fn sysinfo(&self) -> Result { + let display_path = |path: &Path| path.display().to_string(); + let wrap = self + .wrap + .clone() + .map_or_else(|| String::from("no"), |v| v.to_string()); + let (rag_reranker_model, rag_top_k) = match &self.rag { + Some(rag) => rag.get_config(), + None => (self.rag_reranker_model.clone(), self.rag_top_k), + }; + let role = self.extract_role(); + let mut items = vec![ + ("model", role.model().id()), + ("temperature", format_option_value(&role.temperature())), + ("top_p", format_option_value(&role.top_p())), + ("use_tools", format_option_value(&role.use_tools())), + ( + "use_mcp_servers", + 
format_option_value(&role.use_mcp_servers()), + ), + ( + "max_output_tokens", + role.model() + .max_tokens_param() + .map(|v| format!("{v} (current model)")) + .unwrap_or_else(|| "null".into()), + ), + ("save_session", format_option_value(&self.save_session)), + ("compress_threshold", self.compress_threshold.to_string()), + ( + "rag_reranker_model", + format_option_value(&rag_reranker_model), + ), + ("rag_top_k", rag_top_k.to_string()), + ("dry_run", self.dry_run.to_string()), + ("function_calling", self.function_calling.to_string()), + ("mcp_servers", self.mcp_servers.to_string()), + ("stream", self.stream.to_string()), + ("save", self.save.to_string()), + ("keybindings", self.keybindings.clone()), + ("wrap", wrap), + ("wrap_code", self.wrap_code.to_string()), + ("highlight", self.highlight.to_string()), + ("theme", format_option_value(&self.theme)), + ("config_file", display_path(&Self::config_file())), + ("env_file", display_path(&Self::env_file())), + ("roles_dir", display_path(&Self::roles_dir())), + ("sessions_dir", display_path(&self.sessions_dir())), + ("rags_dir", display_path(&Self::rags_dir())), + ("macros_dir", display_path(&Self::macros_dir())), + ("functions_dir", display_path(&Self::functions_dir())), + ("messages_file", display_path(&self.messages_file())), + ]; + if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) { + items.push(("log_path", display_path(&log_path))); + } + let output = items + .iter() + .map(|(name, value)| format!("{name:<24}{value}\n")) + .collect::>() + .join(""); + Ok(output) + } + + pub fn update(config: &GlobalConfig, data: &str) -> Result<()> { + let parts: Vec<&str> = data.split_whitespace().collect(); + if parts.len() != 2 { + bail!("Usage: .set . 
If value is null, unset key."); + } + let key = parts[0]; + let value = parts[1]; + match key { + "temperature" => { + let value = parse_value(value)?; + config.write().set_temperature(value); + } + "top_p" => { + let value = parse_value(value)?; + config.write().set_top_p(value); + } + "use_tools" => { + let value = parse_value(value)?; + config.write().set_use_tools(value); + } + "use_mcp_servers" => { + let value = parse_value(value)?; + config.write().set_use_mcp_servers(value); + } + "max_output_tokens" => { + let value = parse_value(value)?; + config.write().set_max_output_tokens(value); + } + "save_session" => { + let value = parse_value(value)?; + config.write().set_save_session(value); + } + "compress_threshold" => { + let value = parse_value(value)?; + config.write().set_compress_threshold(value); + } + "rag_reranker_model" => { + let value = parse_value(value)?; + Self::set_rag_reranker_model(config, value)?; + } + "rag_top_k" => { + let value = value.parse().with_context(|| "Invalid value")?; + Self::set_rag_top_k(config, value)?; + } + "dry_run" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().dry_run = value; + } + "function_calling" => { + let value = value.parse().with_context(|| "Invalid value")?; + if value && config.write().functions.is_empty() { + bail!("Function calling cannot be enabled because no functions are installed.") + } + config.write().function_calling = value; + } + "mcp_servers" => { + let value = value.parse().with_context(|| "Invalid value")?; + if value && !config.write().functions.has_mcp_functions() { + bail!("MCP servers cannot be enabled because no MCP servers are installed.") + } + config.write().mcp_servers = value; + } + "stream" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().stream = value; + } + "save" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().save = value; + } + "highlight" => { + let value = 
value.parse().with_context(|| "Invalid value")?; + config.write().highlight = value; + } + _ => bail!("Unknown key '{key}'"), + } + Ok(()) + } + + pub fn delete(config: &GlobalConfig, kind: &str) -> Result<()> { + let (dir, file_ext) = match kind { + "role" => (Self::roles_dir(), Some(".md")), + "session" => (config.read().sessions_dir(), Some(".yaml")), + "rag" => (Self::rags_dir(), Some(".yaml")), + "macro" => (Self::macros_dir(), Some(".yaml")), + "agent-data" => (Self::agents_data_dir(), None), + _ => bail!("Unknown kind '{kind}'"), + }; + let names = match read_dir(&dir) { + Ok(rd) => { + let mut names = vec![]; + for entry in rd.flatten() { + let name = entry.file_name(); + match file_ext { + Some(file_ext) => { + if let Some(name) = name.to_string_lossy().strip_suffix(file_ext) { + names.push(name.to_string()); + } + } + None => { + if entry.path().is_dir() { + names.push(name.to_string_lossy().to_string()); + } + } + } + } + names.sort_unstable(); + names + } + Err(_) => vec![], + }; + + if names.is_empty() { + bail!("No {kind} to delete") + } + + let select_names = MultiSelect::new(&format!("Select {kind} to delete:"), names) + .with_validator(|list: &[ListOption<&String>]| { + if list.is_empty() { + Ok(Validation::Invalid( + "At least one item must be selected".into(), + )) + } else { + Ok(Validation::Valid) + } + }) + .prompt()?; + + for name in select_names { + match file_ext { + Some(ext) => { + let path = dir.join(format!("{name}{ext}")); + remove_file(&path).with_context(|| { + format!("Failed to delete {kind} at '{}'", path.display()) + })?; + } + None => { + let path = dir.join(name); + remove_dir_all(&path).with_context(|| { + format!("Failed to delete {kind} at '{}'", path.display()) + })?; + } + } + } + println!("✓ Successfully deleted {kind}."); + Ok(()) + } + + pub fn set_temperature(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_temperature(value), + None => self.temperature = value, + } + } + + 
pub fn set_top_p(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_top_p(value), + None => self.top_p = value, + } + } + + pub fn set_use_tools(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_use_tools(value), + None => self.use_tools = value, + } + } + + pub fn set_use_mcp_servers(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_use_mcp_servers(value), + None => self.use_mcp_servers = value, + } + } + + pub fn set_save_session(&mut self, value: Option) { + if let Some(session) = self.session.as_mut() { + session.set_save_session(value); + } else { + self.save_session = value; + } + } + + pub fn set_compress_threshold(&mut self, value: Option) { + if let Some(session) = self.session.as_mut() { + session.set_compress_threshold(value); + } else { + self.compress_threshold = value.unwrap_or_default(); + } + } + + pub fn set_rag_reranker_model(config: &GlobalConfig, value: Option) -> Result<()> { + if let Some(id) = &value { + Model::retrieve_model(&config.read(), id, ModelType::Reranker)?; + } + let has_rag = config.read().rag.is_some(); + match has_rag { + true => update_rag(config, |rag| { + rag.set_reranker_model(value)?; + Ok(()) + })?, + false => config.write().rag_reranker_model = value, + } + Ok(()) + } + + pub fn set_rag_top_k(config: &GlobalConfig, value: usize) -> Result<()> { + let has_rag = config.read().rag.is_some(); + match has_rag { + true => update_rag(config, |rag| { + rag.set_top_k(value)?; + Ok(()) + })?, + false => config.write().rag_top_k = value, + } + Ok(()) + } + + pub fn set_wrap(&mut self, value: &str) -> Result<()> { + if value == "no" { + self.wrap = None; + } else if value == "auto" { + self.wrap = Some(value.into()); + } else { + value + .parse::() + .map_err(|_| anyhow!("Invalid wrap value"))?; + self.wrap = Some(value.into()) + } + Ok(()) + } + + pub fn set_max_output_tokens(&mut self, value: Option) { 
+ match self.role_like_mut() { + Some(role_like) => { + let mut model = role_like.model().clone(); + model.set_max_tokens(value, true); + role_like.set_model(model); + } + None => { + self.model.set_max_tokens(value, true); + } + }; + } + + pub fn set_model(&mut self, model_id: &str) -> Result<()> { + let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; + match self.role_like_mut() { + Some(role_like) => role_like.set_model(model), + None => { + self.model = model; + } + } + Ok(()) + } + + pub fn use_prompt(&mut self, prompt: &str) -> Result<()> { + let mut role = Role::new(TEMP_ROLE_NAME, prompt); + role.set_model(self.current_model().clone()); + self.use_role_obj(role) + } + + pub async fn use_role_safely( + config: &GlobalConfig, + name: &str, + abort_signal: AbortSignal, + ) -> Result<()> { + let mut cfg = { + let mut guard = config.write(); + take(&mut *guard) + }; + + cfg.use_role(name, abort_signal.clone()).await?; + + { + let mut guard = config.write(); + *guard = cfg; + } + + Ok(()) + } + + pub async fn use_role(&mut self, name: &str, abort_signal: AbortSignal) -> Result<()> { + let role = self.retrieve_role(name)?; + self.functions.clear_mcp_meta_functions(); + let mcp_servers = role.use_mcp_servers(); + let registry = self + .mcp_registry + .take() + .expect("MCP registry should be initialized"); + let new_mcp_registry = + McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?; + + if !new_mcp_registry.is_empty() { + self.functions + .append_mcp_meta_functions(new_mcp_registry.list_servers()); + } + + self.mcp_registry = Some(new_mcp_registry); + self.use_role_obj(role) + } + + pub fn use_role_obj(&mut self, role: Role) -> Result<()> { + if self.agent.is_some() { + bail!("Cannot perform this operation because you are using a agent") + } + if let Some(session) = self.session.as_mut() { + session.guard_empty()?; + session.set_role(role); + } else { + self.role = Some(role); + } + Ok(()) + } + + pub fn role_info(&self) -> 
Result { + if let Some(session) = &self.session { + if session.role_name().is_some() { + let role = session.to_role(); + Ok(role.export()) + } else { + bail!("No session role") + } + } else if let Some(role) = &self.role { + Ok(role.export()) + } else { + bail!("No role") + } + } + + pub fn exit_role(&mut self) -> Result<()> { + if let Some(session) = self.session.as_mut() { + session.guard_empty()?; + session.clear_role(); + } else if self.role.is_some() { + self.role = None; + } + Ok(()) + } + + pub fn retrieve_role(&self, name: &str) -> Result { + let names = Self::list_roles(false); + let mut role = if names.contains(&name.to_string()) { + let path = Self::role_file(name); + let content = read_to_string(&path)?; + Role::new(name, &content) + } else { + Role::builtin(name)? + }; + let current_model = self.current_model().clone(); + match role.model_id() { + Some(model_id) => { + if current_model.id() != model_id { + let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; + role.set_model(model); + } else { + role.set_model(current_model); + } + } + None => { + role.set_model(current_model); + if role.temperature().is_none() { + role.set_temperature(self.temperature); + } + if role.top_p().is_none() { + role.set_top_p(self.top_p); + } + } + } + Ok(role) + } + + pub fn new_role(&mut self, name: &str) -> Result<()> { + if self.macro_flag { + bail!("No role"); + } + let ans = Confirm::new("Create a new role?") + .with_default(true) + .prompt()?; + if ans { + self.upsert_role(name)?; + } else { + bail!("No role"); + } + Ok(()) + } + + pub async fn edit_role(&mut self, abort_signal: AbortSignal) -> Result<()> { + let role_name; + if let Some(session) = self.session.as_ref() { + if let Some(name) = session.role_name().map(|v| v.to_string()) { + if session.is_empty() { + role_name = Some(name); + } else { + bail!("Cannot perform this operation because you are in a non-empty session") + } + } else { + bail!("No role") + } + } else { + role_name = 
self.role.as_ref().map(|v| v.name().to_string()); + } + let name = role_name.ok_or_else(|| anyhow!("No role"))?; + self.upsert_role(&name)?; + self.use_role(&name, abort_signal.clone()).await + } + + pub fn upsert_role(&mut self, name: &str) -> Result<()> { + let role_path = Self::role_file(name); + ensure_parent_exists(&role_path)?; + let editor = self.editor()?; + edit_file(&editor, &role_path)?; + if self.working_mode.is_repl() { + println!("✓ Saved the role to '{}'.", role_path.display()); + } + Ok(()) + } + + pub fn save_role(&mut self, name: Option<&str>) -> Result<()> { + let mut role_name = match &self.role { + Some(role) => { + if role.has_args() { + bail!("Unable to save the role with arguments (whose name contains '#')") + } + match name { + Some(v) => v.to_string(), + None => role.name().to_string(), + } + } + None => bail!("No role"), + }; + if role_name == TEMP_ROLE_NAME { + role_name = Text::new("Role name:") + .with_validator(|input: &str| { + let input = input.trim(); + if input.is_empty() { + Ok(Validation::Invalid("This name is required".into())) + } else if input == TEMP_ROLE_NAME { + Ok(Validation::Invalid("This name is reserved".into())) + } else { + Ok(Validation::Valid) + } + }) + .prompt()?; + } + let role_path = Self::role_file(&role_name); + if let Some(role) = self.role.as_mut() { + role.save(&role_name, &role_path, self.working_mode.is_repl())?; + } + + Ok(()) + } + + pub fn all_roles() -> Vec { + let mut roles: HashMap = Role::list_builtin_roles() + .iter() + .map(|v| (v.name().to_string(), v.clone())) + .collect(); + let names = Self::list_roles(false); + for name in names { + if let Ok(content) = read_to_string(Self::role_file(&name)) { + let role = Role::new(&name, &content); + roles.insert(name, role); + } + } + let mut roles: Vec<_> = roles.into_values().collect(); + roles.sort_unstable_by(|a, b| a.name().cmp(b.name())); + roles + } + + pub fn list_roles(with_builtin: bool) -> Vec { + let mut names = HashSet::new(); + if let 
Ok(rd) = read_dir(Self::roles_dir()) { + for entry in rd.flatten() { + if let Some(name) = entry + .file_name() + .to_str() + .and_then(|v| v.strip_suffix(".md")) + { + names.insert(name.to_string()); + } + } + } + if with_builtin { + names.extend(Role::list_builtin_role_names()); + } + let mut names: Vec<_> = names.into_iter().collect(); + names.sort_unstable(); + names + } + + pub fn has_role(name: &str) -> bool { + let names = Self::list_roles(true); + names.contains(&name.to_string()) + } + + pub fn use_session(&mut self, session_name: Option<&str>) -> Result<()> { + if self.session.is_some() { + bail!( + "Already in a session, please run '.exit session' first to exit the current session." + ); + } + let mut session; + match session_name { + None | Some(TEMP_SESSION_NAME) => { + let session_file = self.session_file(TEMP_SESSION_NAME); + if session_file.exists() { + remove_file(session_file).with_context(|| { + format!("Failed to cleanup previous '{TEMP_SESSION_NAME}' session") + })?; + } + session = Some(Session::new(self, TEMP_SESSION_NAME)); + } + Some(name) => { + let session_path = self.session_file(name); + if !session_path.exists() { + session = Some(Session::new(self, name)); + } else { + session = Some(Session::load(self, name, &session_path)?); + } + } + } + let mut new_session = false; + if let Some(session) = session.as_mut() { + if session.is_empty() { + new_session = true; + if let Some(LastMessage { + input, + output, + continuous, + }) = &self.last_message + { + if (*continuous && !output.is_empty()) + && self.agent.is_some() == input.with_agent() + { + let ans = Confirm::new( + "Start a session that incorporates the last question and answer?", + ) + .with_default(false) + .prompt()?; + if ans { + session.add_message(input, output)?; + } + } + } + } + } + self.session = session; + self.init_agent_session_variables(new_session)?; + Ok(()) + } + + pub fn session_info(&self) -> Result { + if let Some(session) = &self.session { + let render_options = 
self.render_options()?; + let mut markdown_render = MarkdownRender::init(render_options)?; + let agent_info: Option<(String, Vec)> = self.agent.as_ref().map(|agent| { + let functions = agent + .functions() + .declarations() + .iter() + .filter_map(|v| if v.agent { Some(v.name.clone()) } else { None }) + .collect(); + (agent.name().to_string(), functions) + }); + session.render(&mut markdown_render, &agent_info) + } else { + bail!("No session") + } + } + + pub fn exit_session(&mut self) -> Result<()> { + if let Some(mut session) = self.session.take() { + let sessions_dir = self.sessions_dir(); + session.exit(&sessions_dir, self.working_mode.is_repl())?; + self.discontinuous_last_message(); + } + Ok(()) + } + + pub fn save_session(&mut self, name: Option<&str>) -> Result<()> { + let session_name = match &self.session { + Some(session) => match name { + Some(v) => v.to_string(), + None => session + .autoname() + .unwrap_or_else(|| session.name()) + .to_string(), + }, + None => bail!("No session"), + }; + let session_path = self.session_file(&session_name); + if let Some(session) = self.session.as_mut() { + session.save(&session_name, &session_path, self.working_mode.is_repl())?; + } + Ok(()) + } + + pub fn edit_session(&mut self) -> Result<()> { + let name = match &self.session { + Some(session) => session.name().to_string(), + None => bail!("No session"), + }; + let session_path = self.session_file(&name); + self.save_session(Some(&name))?; + let editor = self.editor()?; + edit_file(&editor, &session_path).with_context(|| { + format!( + "Failed to edit '{}' with '{editor}'", + session_path.display() + ) + })?; + self.session = Some(Session::load(self, &name, &session_path)?); + self.discontinuous_last_message(); + Ok(()) + } + + pub fn empty_session(&mut self) -> Result<()> { + if let Some(session) = self.session.as_mut() { + if let Some(agent) = self.agent.as_ref() { + session.sync_agent(agent); + } + session.clear_messages(); + } else { + bail!("No session") + } + 
self.discontinuous_last_message(); + Ok(()) + } + + pub fn set_save_session_this_time(&mut self) -> Result<()> { + if let Some(session) = self.session.as_mut() { + session.set_save_session_this_time(); + } else { + bail!("No session") + } + Ok(()) + } + + pub fn list_sessions(&self) -> Vec { + list_file_names(self.sessions_dir(), ".yaml") + } + + pub fn list_autoname_sessions(&self) -> Vec { + list_file_names(self.sessions_dir().join("_"), ".yaml") + } + + pub fn maybe_compress_session(config: GlobalConfig) { + let mut need_compress = false; + { + let mut config = config.write(); + let compress_threshold = config.compress_threshold; + if let Some(session) = config.session.as_mut() { + if session.need_compress(compress_threshold) { + session.set_compressing(true); + need_compress = true; + } + } + }; + if !need_compress { + return; + } + let color = if config.read().light_theme() { + nu_ansi_term::Color::LightGray + } else { + nu_ansi_term::Color::DarkGray + }; + print!( + "\n📢 {}\n", + color.italic().paint("Compressing the session."), + ); + tokio::spawn(async move { + if let Err(err) = Config::compress_session(&config).await { + warn!("Failed to compress the session: {err}"); + } + if let Some(session) = config.write().session.as_mut() { + session.set_compressing(false); + } + }); + } + + pub async fn compress_session(config: &GlobalConfig) -> Result<()> { + match config.read().session.as_ref() { + Some(session) => { + if !session.has_user_messages() { + bail!("No need to compress since there are no messages in the session") + } + } + None => bail!("No session"), + } + + let prompt = config + .read() + .summarize_prompt + .clone() + .unwrap_or_else(|| SUMMARIZE_PROMPT.into()); + let input = Input::from_str(config, &prompt, None); + let summary = input.fetch_chat_text().await?; + let summary_prompt = config + .read() + .summary_prompt + .clone() + .unwrap_or_else(|| SUMMARY_PROMPT.into()); + if let Some(session) = config.write().session.as_mut() { + 
session.compress(format!("{summary_prompt}{summary}")); + } + config.write().discontinuous_last_message(); + Ok(()) + } + + pub fn is_compressing_session(&self) -> bool { + self.session + .as_ref() + .map(|v| v.compressing()) + .unwrap_or_default() + } + + pub fn maybe_autoname_session(config: GlobalConfig) { + let mut need_autoname = false; + if let Some(session) = config.write().session.as_mut() { + if session.need_autoname() { + session.set_autonaming(true); + need_autoname = true; + } + } + if !need_autoname { + return; + } + let color = if config.read().light_theme() { + nu_ansi_term::Color::LightGray + } else { + nu_ansi_term::Color::DarkGray + }; + print!("\n📢 {}\n", color.italic().paint("Autonaming the session."),); + tokio::spawn(async move { + if let Err(err) = Config::autoname_session(&config).await { + warn!("Failed to autonaming the session: {err}"); + } + if let Some(session) = config.write().session.as_mut() { + session.set_autonaming(false); + } + }); + } + + pub async fn autoname_session(config: &GlobalConfig) -> Result<()> { + let text = match config + .read() + .session + .as_ref() + .and_then(|v| v.chat_history_for_autonaming()) + { + Some(v) => v, + None => bail!("No chat history"), + }; + let role = config.read().retrieve_role(CREATE_TITLE_ROLE)?; + let input = Input::from_str(config, &text, Some(role)); + let text = input.fetch_chat_text().await?; + if let Some(session) = config.write().session.as_mut() { + session.set_autoname(&text); + } + Ok(()) + } + + pub async fn use_rag( + config: &GlobalConfig, + rag: Option<&str>, + abort_signal: AbortSignal, + ) -> Result<()> { + if config.read().agent.is_some() { + bail!("Cannot perform this operation because you are using a agent") + } + let rag = match rag { + None => { + let rag_path = config.read().rag_file(TEMP_RAG_NAME); + if rag_path.exists() { + remove_file(&rag_path).with_context(|| { + format!("Failed to cleanup previous '{TEMP_RAG_NAME}' rag") + })?; + } + Rag::init(config, 
TEMP_RAG_NAME, &rag_path, &[], abort_signal).await? + } + Some(name) => { + let rag_path = config.read().rag_file(name); + if !rag_path.exists() { + if config.read().working_mode.is_cmd() { + bail!("Unknown RAG '{name}'") + } + Rag::init(config, name, &rag_path, &[], abort_signal).await? + } else { + Rag::load(config, name, &rag_path)? + } + } + }; + config.write().rag = Some(Arc::new(rag)); + Ok(()) + } + + pub async fn edit_rag_docs(config: &GlobalConfig, abort_signal: AbortSignal) -> Result<()> { + let mut rag = match config.read().rag.clone() { + Some(v) => v.as_ref().clone(), + None => bail!("No RAG"), + }; + + let document_paths = rag.document_paths(); + let temp_file = temp_file(&format!("-rag-{}", rag.name()), ".txt"); + tokio::fs::write(&temp_file, &document_paths.join("\n")) + .await + .with_context(|| format!("Failed to write to '{}'", temp_file.display()))?; + let editor = config.read().editor()?; + edit_file(&editor, &temp_file)?; + let new_document_paths = tokio::fs::read_to_string(&temp_file) + .await + .with_context(|| format!("Failed to read '{}'", temp_file.display()))?; + let new_document_paths = new_document_paths + .split('\n') + .filter_map(|v| { + let v = v.trim(); + if v.is_empty() { + None + } else { + Some(v.to_string()) + } + }) + .collect::>(); + if new_document_paths.is_empty() || new_document_paths == document_paths { + bail!("No changes") + } + rag.refresh_document_paths(&new_document_paths, false, config, abort_signal) + .await?; + config.write().rag = Some(Arc::new(rag)); + Ok(()) + } + + pub async fn rebuild_rag(config: &GlobalConfig, abort_signal: AbortSignal) -> Result<()> { + let mut rag = match config.read().rag.clone() { + Some(v) => v.as_ref().clone(), + None => bail!("No RAG"), + }; + let document_paths = rag.document_paths().to_vec(); + rag.refresh_document_paths(&document_paths, true, config, abort_signal) + .await?; + config.write().rag = Some(Arc::new(rag)); + Ok(()) + } + + pub fn rag_sources(config: &GlobalConfig) -> 
Result { + match config.read().rag.as_ref() { + Some(rag) => match rag.get_last_sources() { + Some(v) => Ok(v), + None => bail!("No sources"), + }, + None => bail!("No RAG"), + } + } + + pub fn rag_info(&self) -> Result { + if let Some(rag) = &self.rag { + rag.export() + } else { + bail!("No RAG") + } + } + + pub fn exit_rag(&mut self) -> Result<()> { + self.rag.take(); + Ok(()) + } + + pub async fn search_rag( + config: &GlobalConfig, + rag: &Rag, + text: &str, + abort_signal: AbortSignal, + ) -> Result { + let (reranker_model, top_k) = rag.get_config(); + let (embeddings, ids) = rag + .search(text, top_k, reranker_model.as_deref(), abort_signal) + .await?; + let text = config.read().rag_template(&embeddings, text); + rag.set_last_sources(&ids); + Ok(text) + } + + pub fn list_rags() -> Vec { + match read_dir(Self::rags_dir()) { + Ok(rd) => { + let mut names = vec![]; + for entry in rd.flatten() { + let name = entry.file_name(); + if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") { + names.push(name.to_string()); + } + } + names.sort_unstable(); + names + } + Err(_) => vec![], + } + } + + pub fn rag_template(&self, embeddings: &str, text: &str) -> String { + if embeddings.is_empty() { + return text.to_string(); + } + self.rag_template + .as_deref() + .unwrap_or(RAG_TEMPLATE) + .replace("__CONTEXT__", embeddings) + .replace("__INPUT__", text) + } + + pub async fn use_agent( + config: &GlobalConfig, + agent_name: &str, + session_name: Option<&str>, + abort_signal: AbortSignal, + ) -> Result<()> { + if !config.read().function_calling { + bail!("Please enable function calling before using the agent."); + } + if config.read().agent.is_some() { + bail!("Already in an agent, please run '.exit agent' first to exit the current agent."); + } + let agent = Agent::init(config, agent_name, abort_signal).await?; + let session = session_name.map(|v| v.to_string()).or_else(|| { + if config.read().macro_flag { + None + } else { + agent.agent_prelude().map(|v| 
v.to_string()) + } + }); + config.write().rag = agent.rag(); + config.write().agent = Some(agent); + if let Some(session) = session { + config.write().use_session(Some(&session))?; + } else { + config.write().init_agent_shared_variables()?; + } + Ok(()) + } + + pub fn agent_info(&self) -> Result { + if let Some(agent) = &self.agent { + agent.export() + } else { + bail!("No agent") + } + } + + pub fn agent_banner(&self) -> Result { + if let Some(agent) = &self.agent { + Ok(agent.banner()) + } else { + bail!("No agent") + } + } + + pub fn edit_agent_config(&self) -> Result<()> { + let agent_name = match &self.agent { + Some(agent) => agent.name(), + None => bail!("No agent"), + }; + let agent_config_path = Config::agent_config_file(agent_name); + ensure_parent_exists(&agent_config_path)?; + if !agent_config_path.exists() { + std::fs::write( + &agent_config_path, + "# see https://github.com/Dark-Alex-17/loki/blob/main/config.agent.example.yaml\n", + ) + .with_context(|| format!("Failed to write to '{}'", agent_config_path.display()))?; + } + let editor = self.editor()?; + edit_file(&editor, &agent_config_path)?; + println!( + "NOTE: Remember to reload the agent if there are changes made to '{}'", + agent_config_path.display() + ); + Ok(()) + } + + pub fn exit_agent(&mut self) -> Result<()> { + self.exit_session()?; + self.load_functions()?; + if self.agent.take().is_some() { + self.rag.take(); + self.discontinuous_last_message(); + } + Ok(()) + } + + pub fn exit_agent_session(&mut self) -> Result<()> { + self.exit_session()?; + if let Some(agent) = self.agent.as_mut() { + agent.exit_session(); + if self.working_mode.is_repl() { + self.init_agent_shared_variables()?; + } + } + Ok(()) + } + + pub fn list_macros() -> Vec { + list_file_names(Self::macros_dir(), ".yaml") + } + + pub fn load_macro(name: &str) -> Result { + let path = Self::macro_file(name); + let err = || format!("Failed to load macro '{name}' at '{}'", path.display()); + let content = 
read_to_string(&path).with_context(err)?; + let value: Macro = serde_yaml::from_str(&content).with_context(err)?; + Ok(value) + } + + pub fn has_macro(name: &str) -> bool { + let names = Self::list_macros(); + names.contains(&name.to_string()) + } + + pub fn new_macro(&mut self, name: &str) -> Result<()> { + if self.macro_flag { + bail!("No macro"); + } + let ans = Confirm::new("Create a new macro?") + .with_default(true) + .prompt()?; + if ans { + let macro_path = Self::macro_file(name); + ensure_parent_exists(¯o_path)?; + let editor = self.editor()?; + edit_file(&editor, ¯o_path)?; + } else { + bail!("No macro"); + } + Ok(()) + } + + pub async fn apply_prelude(&mut self, abort_signal: AbortSignal) -> Result<()> { + if self.macro_flag || !self.state().is_empty() { + return Ok(()); + } + let prelude = match self.working_mode { + WorkingMode::Repl => self.repl_prelude.as_ref(), + WorkingMode::Cmd => self.cmd_prelude.as_ref(), + WorkingMode::Serve => return Ok(()), + }; + let prelude = match prelude { + Some(v) => { + if v.is_empty() { + return Ok(()); + } + v.to_string() + } + None => return Ok(()), + }; + + let err_msg = || format!("Invalid prelude '{prelude}"); + match prelude.split_once(':') { + Some(("role", name)) => { + self.use_role(name, abort_signal) + .await + .with_context(err_msg)?; + } + Some(("session", name)) => { + self.use_session(Some(name)).with_context(err_msg)?; + } + Some((session_name, role_name)) => { + self.use_session(Some(session_name)).with_context(err_msg)?; + if let Some(true) = self.session.as_ref().map(|v| v.is_empty()) { + self.use_role(role_name, abort_signal) + .await + .with_context(err_msg)?; + } + } + _ => { + bail!("{}", err_msg()) + } + } + Ok(()) + } + + pub fn select_functions(&self, role: &Role) -> Option> { + let mut functions = vec![]; + functions.extend(self.select_enabled_functions(role)); + functions.extend(self.select_enabled_mcp_servers(role)); + + if functions.is_empty() { + None + } else { + Some(functions) + } + } 
+ + fn select_enabled_functions(&self, role: &Role) -> Vec { + let mut functions = vec![]; + if self.function_calling { + if let Some(use_tools) = role.use_tools() { + let mut tool_names: HashSet = Default::default(); + let declaration_names: HashSet = self + .functions + .declarations() + .iter() + .map(|v| v.name.to_string()) + .collect(); + if use_tools == "all" { + tool_names.extend(declaration_names); + } else { + for item in use_tools.split(',') { + let item = item.trim(); + if let Some(values) = self.mapping_tools.get(item) { + tool_names.extend( + values + .split(',') + .map(|v| v.to_string()) + .filter(|v| declaration_names.contains(v)), + ) + } else if declaration_names.contains(item) { + tool_names.insert(item.to_string()); + } + } + } + functions = self + .functions + .declarations() + .iter() + .filter_map(|v| { + if tool_names.contains(&v.name) { + Some(v.clone()) + } else { + None + } + }) + .collect(); + } + + if let Some(agent) = &self.agent { + let mut agent_functions = agent.functions().declarations().to_vec(); + let tool_names: HashSet = agent_functions + .iter() + .filter_map(|v| { + if v.agent { + None + } else { + Some(v.name.to_string()) + } + }) + .collect(); + agent_functions.extend( + functions + .into_iter() + .filter(|v| !tool_names.contains(&v.name)), + ); + functions = agent_functions; + } + } + + functions + } + + fn select_enabled_mcp_servers(&self, role: &Role) -> Vec { + let mut mcp_functions = vec![]; + if self.mcp_servers { + if let Some(use_mcp_servers) = role.use_mcp_servers() { + let mut server_names: HashSet = Default::default(); + let mcp_declaration_names: HashSet = self + .functions + .declarations() + .iter() + .filter(|v| { + v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) + || v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + }) + .map(|v| v.name.to_string()) + .collect(); + if use_mcp_servers == "all" { + server_names.extend(mcp_declaration_names); + } else { + for item in use_mcp_servers.split(',') { 
+ let item = item.trim(); + let item_invoke_name = + format!("{}_{item}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX); + let item_list_name = + format!("{}_{item}", MCP_LIST_META_FUNCTION_NAME_PREFIX); + if let Some(values) = self.mapping_tools.get(item) { + server_names.extend( + values + .split(',') + .flat_map(|v| { + vec![ + format!( + "{}_{}", + MCP_INVOKE_META_FUNCTION_NAME_PREFIX, + v.to_string() + ), + format!( + "{}_{}", + MCP_LIST_META_FUNCTION_NAME_PREFIX, + v.to_string() + ), + ] + }) + .filter(|v| mcp_declaration_names.contains(v)), + ) + } else if mcp_declaration_names.contains(&item_invoke_name) { + server_names.insert(item_invoke_name); + server_names.insert(item_list_name); + } + } + } + mcp_functions = self + .functions + .declarations() + .iter() + .filter_map(|v| { + if server_names.contains(&v.name) { + Some(v.clone()) + } else { + None + } + }) + .collect(); + } + + if let Some(agent) = &self.agent { + let mut agent_functions = agent.functions().declarations().to_vec(); + let tool_names: HashSet = agent_functions + .iter() + .filter_map(|v| { + if v.agent { + None + } else { + Some(v.name.to_string()) + } + }) + .collect(); + agent_functions.extend( + mcp_functions + .into_iter() + .filter(|v| !tool_names.contains(&v.name)), + ); + mcp_functions = agent_functions; + } + } + + mcp_functions + } + + pub fn editor(&self) -> Result { + EDITOR.get_or_init(move || { + let editor = self.editor.clone() + .or_else(|| env::var("VISUAL").ok().or_else(|| env::var("EDITOR").ok())) + .unwrap_or_else(|| { + if cfg!(windows) { + "notepad".to_string() + } else { + "nano".to_string() + } + }); + which::which(&editor).ok().map(|_| editor) + }) + .clone() + .ok_or_else(|| anyhow!("Editor not found. 
Please add the `editor` configuration or set the $EDITOR or $VISUAL environment variable.")) + } + + pub fn repl_complete( + &self, + cmd: &str, + args: &[&str], + _line: &str, + ) -> Vec<(String, Option)> { + let mut values: Vec<(String, Option)> = vec![]; + let filter = args.last().unwrap_or(&""); + if args.len() == 1 { + values = match cmd { + ".role" => map_completion_values(Self::list_roles(true)), + ".model" => list_models(self, ModelType::Chat) + .into_iter() + .map(|v| (v.id(), Some(v.description()))) + .collect(), + ".session" => { + if args[0].starts_with("_/") { + map_completion_values( + self.list_autoname_sessions() + .iter() + .rev() + .map(|v| format!("_/{v}")) + .collect::>(), + ) + } else { + map_completion_values(self.list_sessions()) + } + } + ".rag" => map_completion_values(Self::list_rags()), + ".agent" => map_completion_values(list_agents()), + ".macro" => map_completion_values(Self::list_macros()), + ".starter" => match &self.agent { + Some(agent) => agent + .conversation_starters() + .iter() + .enumerate() + .map(|(i, v)| ((i + 1).to_string(), Some(v.to_string()))) + .collect(), + None => vec![], + }, + ".set" => { + let mut values = vec![ + "temperature", + "top_p", + "use_tools", + "use_mcp_servers", + "save_session", + "compress_threshold", + "rag_reranker_model", + "rag_top_k", + "max_output_tokens", + "dry_run", + "function_calling", + "mcp_servers", + "stream", + "save", + "highlight", + ]; + values.sort_unstable(); + values + .into_iter() + .map(|v| (format!("{v} "), None)) + .collect() + } + ".delete" => { + map_completion_values(vec!["role", "session", "rag", "macro", "agent-data"]) + } + _ => vec![], + }; + } else if cmd == ".set" && args.len() == 2 { + let candidates = match args[0] { + "max_output_tokens" => match self.current_model().max_output_tokens() { + Some(v) => vec![v.to_string()], + None => vec![], + }, + "dry_run" => complete_bool(self.dry_run), + "stream" => complete_bool(self.stream), + "save" => 
complete_bool(self.save), + "function_calling" => complete_bool(self.function_calling), + "use_tools" => { + let mut prefix = String::new(); + let mut ignores = HashSet::new(); + if let Some((v, _)) = args[1].rsplit_once(',') { + ignores = v.split(',').collect(); + prefix = format!("{v},"); + } + let mut values = vec![]; + if prefix.is_empty() { + values.push("all".to_string()); + } + values.extend(self.functions.declarations().iter().map(|v| v.name.clone())); + values.extend(self.mapping_tools.keys().map(|v| v.to_string())); + values + .into_iter() + .filter(|v| !ignores.contains(v.as_str())) + .map(|v| format!("{prefix}{v}")) + .collect() + } + "mcp_servers" => complete_bool(self.mcp_servers), + "use_mcp_servers" => { + let mut prefix = String::new(); + let mut ignores = HashSet::new(); + if let Some((v, _)) = args[1].rsplit_once(',') { + ignores = v.split(',').collect(); + prefix = format!("{v},"); + } + let mut values = vec![]; + if prefix.is_empty() { + values.push("all".to_string()); + } + values.extend( + self.functions + .declarations() + .iter() + .filter(|v| { + v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) + || v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + }) + .map(|v| { + v.name + .strip_prefix( + format!("{MCP_LIST_META_FUNCTION_NAME_PREFIX}_").as_str(), + ) + .or_else(|| { + v.name.strip_prefix( + format!("{MCP_INVOKE_META_FUNCTION_NAME_PREFIX}_") + .as_str(), + ) + }) + .unwrap() + .to_string() + }), + ); + values.extend(self.mapping_mcp_servers.keys().map(|v| v.to_string())); + values + .into_iter() + .filter(|v| !ignores.contains(v.as_str())) + .map(|v| format!("{prefix}{v}")) + .collect() + } + "save_session" => { + let save_session = if let Some(session) = &self.session { + session.save_session() + } else { + self.save_session + }; + complete_option_bool(save_session) + } + "rag_reranker_model" => list_models(self, ModelType::Reranker) + .iter() + .map(|v| v.id()) + .collect(), + "highlight" => 
complete_bool(self.highlight), + _ => vec![], + }; + values = candidates.into_iter().map(|v| (v, None)).collect(); + } else if cmd == ".agent" { + if args.len() == 2 { + let dir = Self::agent_data_dir(args[0]).join(SESSIONS_DIR_NAME); + values = list_file_names(dir, ".yaml") + .into_iter() + .map(|v| (v, None)) + .collect(); + } + values.extend(complete_agent_variables(args[0])); + }; + fuzzy_filter(values, |v| v.0.as_str(), filter) + } + + pub fn sync_models_url(&self) -> String { + self.sync_models_url + .clone() + .unwrap_or_else(|| SYNC_MODELS_URL.into()) + } + + pub async fn sync_models(url: &str, abort_signal: AbortSignal) -> Result<()> { + let content = abortable_run_with_spinner(fetch(url), "Fetching models.yaml", abort_signal) + .await + .with_context(|| format!("Failed to fetch '{url}'"))?; + println!("✓ Fetched '{url}'"); + let list = serde_yaml::from_str::>(&content) + .with_context(|| "Failed to parse models.yaml")?; + let models_override = ModelsOverride { + version: env!("CARGO_PKG_VERSION").to_string(), + list, + }; + let models_override_data = + serde_yaml::to_string(&models_override).with_context(|| "Failed to serde {}")?; + + let model_override_path = Self::models_override_file(); + ensure_parent_exists(&model_override_path)?; + std::fs::write(&model_override_path, models_override_data) + .with_context(|| format!("Failed to write to '{}'", model_override_path.display()))?; + println!("✓ Updated '{}'", model_override_path.display()); + Ok(()) + } + + pub fn local_models_override() -> Result> { + let model_override_path = Self::models_override_file(); + let err = || { + format!( + "Failed to load models at '{}'", + model_override_path.display() + ) + }; + let content = read_to_string(&model_override_path).with_context(err)?; + let models_override: ModelsOverride = serde_yaml::from_str(&content).with_context(err)?; + if models_override.version != env!("CARGO_PKG_VERSION") { + bail!("Incompatible version") + } + Ok(models_override.list) + } + + pub 
fn light_theme(&self) -> bool { + matches!(self.theme.as_deref(), Some("light")) + } + + pub fn render_options(&self) -> Result { + let theme = if self.highlight { + let theme_mode = if self.light_theme() { "light" } else { "dark" }; + let theme_filename = format!("{theme_mode}.tmTheme"); + let theme_path = Self::local_path(&theme_filename); + if theme_path.exists() { + let theme = ThemeSet::get_theme(&theme_path) + .with_context(|| format!("Invalid theme at '{}'", theme_path.display()))?; + Some(theme) + } else { + let theme = if self.light_theme() { + decode_bin(LIGHT_THEME).context("Invalid builtin light theme")? + } else { + decode_bin(DARK_THEME).context("Invalid builtin dark theme")? + }; + Some(theme) + } + } else { + None + }; + let wrap = if *IS_STDOUT_TERMINAL { + self.wrap.clone() + } else { + None + }; + let truecolor = matches!( + env::var("COLORTERM").as_ref().map(|v| v.as_str()), + Ok("truecolor") + ); + Ok(RenderOptions::new(theme, wrap, self.wrap_code, truecolor)) + } + + pub fn render_prompt_left(&self) -> String { + let variables = self.generate_prompt_context(); + let left_prompt = self.left_prompt.as_deref().unwrap_or(LEFT_PROMPT); + render_prompt(left_prompt, &variables) + } + + pub fn render_prompt_right(&self) -> String { + let variables = self.generate_prompt_context(); + let right_prompt = self.right_prompt.as_deref().unwrap_or(RIGHT_PROMPT); + render_prompt(right_prompt, &variables) + } + + pub fn print_markdown(&self, text: &str) -> Result<()> { + if *IS_STDOUT_TERMINAL { + let render_options = self.render_options()?; + let mut markdown_render = MarkdownRender::init(render_options)?; + println!("{}", markdown_render.render(text)); + } else { + println!("{text}"); + } + Ok(()) + } + + fn generate_prompt_context(&self) -> HashMap<&str, String> { + let mut output = HashMap::new(); + let role = self.extract_role(); + output.insert("model", role.model().id()); + output.insert("client_name", role.model().client_name().to_string()); + 
output.insert("model_name", role.model().name().to_string()); + output.insert( + "max_input_tokens", + role.model() + .max_input_tokens() + .unwrap_or_default() + .to_string(), + ); + if let Some(temperature) = role.temperature() { + if temperature != 0.0 { + output.insert("temperature", temperature.to_string()); + } + } + if let Some(top_p) = role.top_p() { + if top_p != 0.0 { + output.insert("top_p", top_p.to_string()); + } + } + if self.dry_run { + output.insert("dry_run", "true".to_string()); + } + if self.stream { + output.insert("stream", "true".to_string()); + } + if self.save { + output.insert("save", "true".to_string()); + } + if let Some(wrap) = &self.wrap { + if wrap != "no" { + output.insert("wrap", wrap.clone()); + } + } + if !role.is_derived() { + output.insert("role", role.name().to_string()); + } + if let Some(session) = &self.session { + output.insert("session", session.name().to_string()); + if let Some(autoname) = session.autoname() { + output.insert("session_autoname", autoname.to_string()); + } + output.insert("dirty", session.dirty().to_string()); + let (tokens, percent) = session.tokens_usage(); + output.insert("consume_tokens", tokens.to_string()); + output.insert("consume_percent", percent.to_string()); + output.insert("user_messages_len", session.user_messages_len().to_string()); + } + if let Some(rag) = &self.rag { + output.insert("rag", rag.name().to_string()); + } + if let Some(agent) = &self.agent { + output.insert("agent", agent.name().to_string()); + } + + if self.highlight { + output.insert("color.reset", "\u{1b}[0m".to_string()); + output.insert("color.black", "\u{1b}[30m".to_string()); + output.insert("color.dark_gray", "\u{1b}[90m".to_string()); + output.insert("color.red", "\u{1b}[31m".to_string()); + output.insert("color.light_red", "\u{1b}[91m".to_string()); + output.insert("color.green", "\u{1b}[32m".to_string()); + output.insert("color.light_green", "\u{1b}[92m".to_string()); + output.insert("color.yellow", 
"\u{1b}[33m".to_string()); + output.insert("color.light_yellow", "\u{1b}[93m".to_string()); + output.insert("color.blue", "\u{1b}[34m".to_string()); + output.insert("color.light_blue", "\u{1b}[94m".to_string()); + output.insert("color.purple", "\u{1b}[35m".to_string()); + output.insert("color.light_purple", "\u{1b}[95m".to_string()); + output.insert("color.magenta", "\u{1b}[35m".to_string()); + output.insert("color.light_magenta", "\u{1b}[95m".to_string()); + output.insert("color.cyan", "\u{1b}[36m".to_string()); + output.insert("color.light_cyan", "\u{1b}[96m".to_string()); + output.insert("color.white", "\u{1b}[37m".to_string()); + output.insert("color.light_gray", "\u{1b}[97m".to_string()); + } + + output + } + + pub fn before_chat_completion(&mut self, input: &Input) -> Result<()> { + self.last_message = Some(LastMessage::new(input.clone(), String::new())); + Ok(()) + } + + pub fn after_chat_completion( + &mut self, + input: &Input, + output: &str, + tool_results: &[ToolResult], + ) -> Result<()> { + if !tool_results.is_empty() { + return Ok(()); + } + self.last_message = Some(LastMessage::new(input.clone(), output.to_string())); + if !self.dry_run { + self.save_message(input, output)?; + } + Ok(()) + } + + fn discontinuous_last_message(&mut self) { + if let Some(last_message) = self.last_message.as_mut() { + last_message.continuous = false; + } + } + + fn save_message(&mut self, input: &Input, output: &str) -> Result<()> { + let mut input = input.clone(); + input.clear_patch(); + if let Some(session) = input.session_mut(&mut self.session) { + session.add_message(&input, output)?; + return Ok(()); + } + + if !self.save { + return Ok(()); + } + let mut file = self.open_message_file()?; + if output.is_empty() && input.tool_calls().is_none() { + return Ok(()); + } + let now = now(); + let summary = input.summary(); + let raw_input = input.raw(); + let scope = if self.agent.is_none() { + let role_name = if input.role().is_derived() { + None + } else { + 
Some(input.role().name()) + }; + match (role_name, input.rag_name()) { + (Some(role), Some(rag_name)) => format!(" ({role}#{rag_name})"), + (Some(role), _) => format!(" ({role})"), + (None, Some(rag_name)) => format!(" (#{rag_name})"), + _ => String::new(), + } + } else { + String::new() + }; + let tool_calls = match input.tool_calls() { + Some(MessageContentToolCalls { + tool_results, text, .. + }) => { + let mut lines = vec!["".to_string()]; + if !text.is_empty() { + lines.push(text.clone()); + } + lines.push(serde_json::to_string(&tool_results).unwrap_or_default()); + lines.push("\n".to_string()); + lines.join("\n") + } + None => String::new(), + }; + let output = format!( + "# CHAT: {summary} [{now}]{scope}\n{raw_input}\n--------\n{tool_calls}{output}\n--------\n\n", + ); + file.write_all(output.as_bytes()) + .with_context(|| "Failed to save message") + } + + fn init_agent_shared_variables(&mut self) -> Result<()> { + let agent = match self.agent.as_mut() { + Some(v) => v, + None => return Ok(()), + }; + if !agent.defined_variables().is_empty() && agent.shared_variables().is_empty() { + let new_variables = + Agent::init_agent_variables(agent.defined_variables(), self.info_flag)?; + agent.set_shared_variables(new_variables); + } + if !self.info_flag { + agent.update_shared_dynamic_instructions(false)?; + } + Ok(()) + } + + fn init_agent_session_variables(&mut self, new_session: bool) -> Result<()> { + let (agent, session) = match (self.agent.as_mut(), self.session.as_mut()) { + (Some(agent), Some(session)) => (agent, session), + _ => return Ok(()), + }; + if new_session { + let shared_variables = agent.shared_variables().clone(); + let session_variables = + if !agent.defined_variables().is_empty() && shared_variables.is_empty() { + let new_variables = + Agent::init_agent_variables(agent.defined_variables(), self.info_flag)?; + agent.set_shared_variables(new_variables.clone()); + new_variables + } else { + shared_variables + }; + 
agent.set_session_variables(session_variables); + if !self.info_flag { + agent.update_session_dynamic_instructions(None)?; + } + session.sync_agent(agent); + } else { + let variables = session.agent_variables(); + agent.set_session_variables(variables.clone()); + agent.update_session_dynamic_instructions(Some( + session.agent_instructions().to_string(), + ))?; + } + Ok(()) + } + + fn open_message_file(&self) -> Result { + let path = self.messages_file(); + ensure_parent_exists(&path)?; + OpenOptions::new() + .create(true) + .append(true) + .open(&path) + .with_context(|| format!("Failed to create/append {}", path.display())) + } + + fn load_from_file(config_path: &Path) -> Result { + let err = || format!("Failed to load config at '{}'", config_path.display()); + let content = read_to_string(config_path).with_context(err)?; + let config: Self = serde_yaml::from_str(&content) + .map_err(|err| { + let err_msg = err.to_string(); + let err_msg = if err_msg.starts_with(&format!("{CLIENTS_FIELD}: ")) { + // location is incorrect, get rid of it + err_msg + .split_once(" at line") + .map(|(v, _)| { + format!("{v} (Sorry for being unable to provide an exact location)") + }) + .unwrap_or_else(|| "clients: invalid value".into()) + } else { + err_msg + }; + anyhow!("{err_msg}") + }) + .with_context(err)?; + + Ok(config) + } + + fn load_dynamic(model_id: &str) -> Result { + let provider = match model_id.split_once(':') { + Some((v, _)) => v, + _ => model_id, + }; + let is_openai_compatible = OPENAI_COMPATIBLE_PROVIDERS + .into_iter() + .any(|(name, _)| provider == name); + let client = if is_openai_compatible { + json!({ "type": "openai-compatible", "name": provider }) + } else { + json!({ "type": provider }) + }; + let config = json!({ + "model": model_id.to_string(), + "save": false, + "clients": vec![client], + }); + let config = + serde_json::from_value(config).with_context(|| "Failed to load config from env")?; + Ok(config) + } + + fn load_envs(&mut self) { + if let Ok(v) = 
env::var(get_env_name("model")) { + self.model_id = v; + } + if let Some(v) = read_env_value::(&get_env_name("temperature")) { + self.temperature = v; + } + if let Some(v) = read_env_value::(&get_env_name("top_p")) { + self.top_p = v; + } + + if let Some(Some(v)) = read_env_bool(&get_env_name("dry_run")) { + self.dry_run = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("stream")) { + self.stream = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("save")) { + self.save = v; + } + if let Ok(v) = env::var(get_env_name("keybindings")) { + if v == "vi" { + self.keybindings = v; + } + } + if let Some(v) = read_env_value::(&get_env_name("editor")) { + self.editor = v; + } + if let Some(v) = read_env_value::(&get_env_name("wrap")) { + self.wrap = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("wrap_code")) { + self.wrap_code = v; + } + + if let Some(Some(v)) = read_env_bool(&get_env_name("function_calling")) { + self.function_calling = v; + } + if let Ok(v) = env::var(get_env_name("mapping_tools")) { + if let Ok(v) = serde_json::from_str(&v) { + self.mapping_tools = v; + } + } + if let Some(v) = read_env_value::(&get_env_name("use_tools")) { + self.use_tools = v; + } + + if let Some(v) = read_env_value::(&get_env_name("repl_prelude")) { + self.repl_prelude = v; + } + if let Some(v) = read_env_value::(&get_env_name("cmd_prelude")) { + self.cmd_prelude = v; + } + if let Some(v) = read_env_value::(&get_env_name("agent_prelude")) { + self.agent_prelude = v; + } + + if let Some(v) = read_env_bool(&get_env_name("save_session")) { + self.save_session = v; + } + if let Some(Some(v)) = read_env_value::(&get_env_name("compress_threshold")) { + self.compress_threshold = v; + } + if let Some(v) = read_env_value::(&get_env_name("summarize_prompt")) { + self.summarize_prompt = v; + } + if let Some(v) = read_env_value::(&get_env_name("summary_prompt")) { + self.summary_prompt = v; + } + + if let Some(v) = 
read_env_value::(&get_env_name("rag_embedding_model")) { + self.rag_embedding_model = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_reranker_model")) { + self.rag_reranker_model = v; + } + if let Some(Some(v)) = read_env_value::(&get_env_name("rag_top_k")) { + self.rag_top_k = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_chunk_size")) { + self.rag_chunk_size = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_chunk_overlap")) { + self.rag_chunk_overlap = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_template")) { + self.rag_template = v; + } + + if let Ok(v) = env::var(get_env_name("document_loaders")) { + if let Ok(v) = serde_json::from_str(&v) { + self.document_loaders = v; + } + } + + if let Some(Some(v)) = read_env_bool(&get_env_name("highlight")) { + self.highlight = v; + } + if *NO_COLOR { + self.highlight = false; + } + if self.highlight && self.theme.is_none() { + if let Some(v) = read_env_value::(&get_env_name("theme")) { + self.theme = v; + } else if *IS_STDOUT_TERMINAL { + if let Ok(color_scheme) = color_scheme(QueryOptions::default()) { + let theme = match color_scheme { + ColorScheme::Dark => "dark", + ColorScheme::Light => "light", + }; + self.theme = Some(theme.into()); + } + } + } + if let Some(v) = read_env_value::(&get_env_name("left_prompt")) { + self.left_prompt = v; + } + if let Some(v) = read_env_value::(&get_env_name("right_prompt")) { + self.right_prompt = v; + } + + if let Some(v) = read_env_value::(&get_env_name("serve_addr")) { + self.serve_addr = v; + } + if let Some(v) = read_env_value::(&get_env_name("user_agent")) { + self.user_agent = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("save_shell_history")) { + self.save_shell_history = v; + } + if let Some(v) = read_env_value::(&get_env_name("sync_models_url")) { + self.sync_models_url = v; + } + } + + fn load_functions(&mut self) -> Result<()> { + self.functions = Functions::init()?; + Ok(()) + } + + async fn 
load_mcp_servers( + &mut self, + log_path: Option, + start_mcp_servers: bool, + abort_signal: AbortSignal, + ) -> Result<()> { + if !self.mcp_servers { + return Ok(()); + } + + let mcp_registry = McpRegistry::init( + log_path, + start_mcp_servers, + self.use_mcp_servers.clone(), + abort_signal.clone(), + ) + .await?; + match mcp_registry.is_empty() { + false => { + self.functions + .append_mcp_meta_functions(mcp_registry.list_servers()); + } + _ => debug!( + "Skipping global MCP functions registration since start_mcp_servers was 'false'" + ), + } + self.mcp_registry = Some(mcp_registry); + + Ok(()) + } + + fn setup_model(&mut self) -> Result<()> { + let mut model_id = self.model_id.clone(); + if model_id.is_empty() { + let models = list_models(self, ModelType::Chat); + if models.is_empty() { + bail!("No available model"); + } + model_id = models[0].id() + } + self.set_model(&model_id)?; + self.model_id = model_id; + + Ok(()) + } + + fn setup_document_loaders(&mut self) { + [("pdf", "pdftotext $1 -"), ("docx", "pandoc --to plain $1")] + .into_iter() + .for_each(|(k, v)| { + let (k, v) = (k.to_string(), v.to_string()); + self.document_loaders.entry(k).or_insert(v); + }); + } + + fn setup_user_agent(&mut self) { + if let Some("auto") = self.user_agent.as_deref() { + self.user_agent = Some(format!( + "{}/{}", + env!("CARGO_CRATE_NAME"), + env!("CARGO_PKG_VERSION") + )); + } + } +} + +pub fn load_env_file() -> Result<()> { + let env_file_path = Config::env_file(); + let contents = match read_to_string(&env_file_path) { + Ok(v) => v, + Err(_) => return Ok(()), + }; + debug!("Use env file '{}'", env_file_path.display()); + for line in contents.lines() { + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { + continue; + } + if let Some((key, value)) = line.split_once('=') { + env::set_var(key.trim(), value.trim()); + } + } + Ok(()) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum WorkingMode { + Cmd, + Repl, + Serve, +} + +impl 
WorkingMode { + pub fn is_cmd(&self) -> bool { + *self == WorkingMode::Cmd + } + pub fn is_repl(&self) -> bool { + *self == WorkingMode::Repl + } + pub fn is_serve(&self) -> bool { + *self == WorkingMode::Serve + } +} + +#[async_recursion::async_recursion] +pub async fn macro_execute( + config: &GlobalConfig, + name: &str, + args: Option<&str>, + abort_signal: AbortSignal, +) -> Result<()> { + let macro_value = Config::load_macro(name)?; + let (mut new_args, text) = split_args_text(args.unwrap_or_default(), cfg!(windows)); + if !text.is_empty() { + new_args.push(text.to_string()); + } + let variables = macro_value + .resolve_variables(&new_args) + .map_err(|err| anyhow!("{err}. Usage: {}", macro_value.usage(name)))?; + let role = config.read().extract_role(); + let mut config = config.read().clone(); + config.temperature = role.temperature(); + config.top_p = role.top_p(); + config.use_tools = role.use_tools().clone(); + config.use_mcp_servers = role.use_mcp_servers().clone(); + config.macro_flag = true; + config.model = role.model().clone(); + config.role = None; + config.session = None; + config.rag = None; + config.agent = None; + config.discontinuous_last_message(); + let config = Arc::new(RwLock::new(config)); + config.write().macro_flag = true; + for step in ¯o_value.steps { + let command = Macro::interpolate_command(step, &variables); + println!(">> {}", multiline_text(&command)); + run_repl_command(&config, abort_signal.clone(), &command).await?; + } + Ok(()) +} + +#[derive(Debug, Clone, Deserialize)] +pub struct Macro { + #[serde(default)] + pub variables: Vec, + pub steps: Vec, +} + +impl Macro { + pub fn resolve_variables(&self, args: &[String]) -> Result> { + let mut output = IndexMap::new(); + for (i, variable) in self.variables.iter().enumerate() { + let value = if variable.rest && i == self.variables.len() - 1 { + if args.len() > i { + Some(args[i..].join(" ")) + } else { + variable.default.clone() + } + } else { + args.get(i) + .map(|v| 
v.to_string()) + .or_else(|| variable.default.clone()) + }; + let value = + value.ok_or_else(|| anyhow!("Missing value for variable '{}'", variable.name))?; + output.insert(variable.name.clone(), value); + } + Ok(output) + } + + pub fn usage(&self, name: &str) -> String { + let mut parts = vec![name.to_string()]; + for (i, variable) in self.variables.iter().enumerate() { + let part = match ( + variable.rest && i == self.variables.len() - 1, + variable.default.is_some(), + ) { + (true, true) => format!("[{}]...", variable.name), + (true, false) => format!("<{}>...", variable.name), + (false, true) => format!("[{}]", variable.name), + (false, false) => format!("<{}>", variable.name), + }; + parts.push(part); + } + parts.join(" ") + } + + pub fn interpolate_command(command: &str, variables: &IndexMap) -> String { + let mut output = command.to_string(); + for (key, value) in variables { + output = output.replace(&format!("{{{{{key}}}}}"), value); + } + output + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct MacroVariable { + pub name: String, + #[serde(default)] + pub rest: bool, + pub default: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelsOverride { + pub version: String, + pub list: Vec, +} + +#[derive(Debug, Clone)] +pub struct LastMessage { + pub input: Input, + pub output: String, + pub continuous: bool, +} + +impl LastMessage { + pub fn new(input: Input, output: String) -> Self { + Self { + input, + output, + continuous: true, + } + } +} + +bitflags::bitflags! 
{ + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub struct StateFlags: u32 { + const ROLE = 1 << 0; + const SESSION_EMPTY = 1 << 1; + const SESSION = 1 << 2; + const RAG = 1 << 3; + const AGENT = 1 << 4; + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AssertState { + True(StateFlags), + False(StateFlags), + TrueFalse(StateFlags, StateFlags), + Equal(StateFlags), +} + +impl AssertState { + pub fn pass() -> Self { + AssertState::False(StateFlags::empty()) + } + + pub fn bare() -> Self { + AssertState::Equal(StateFlags::empty()) + } + + pub fn assert(self, flags: StateFlags) -> bool { + match self { + AssertState::True(true_flags) => true_flags & flags != StateFlags::empty(), + AssertState::False(false_flags) => false_flags & flags == StateFlags::empty(), + AssertState::TrueFalse(true_flags, false_flags) => { + (true_flags & flags != StateFlags::empty()) + && (false_flags & flags == StateFlags::empty()) + } + AssertState::Equal(check_flags) => check_flags == flags, + } + } +} + +async fn create_config_file(config_path: &Path) -> Result<()> { + let ans = Confirm::new("No config file, create a new one?") + .with_default(true) + .prompt()?; + if !ans { + process::exit(0); + } + + let client = Select::new("API Provider (required):", list_client_types()).prompt()?; + + let mut config = json!({}); + let (model, clients_config) = create_client_config(client).await?; + config["model"] = model.into(); + config[CLIENTS_FIELD] = clients_config; + + let config_data = serde_yaml::to_string(&config).with_context(|| "Failed to create config")?; + let config_data = format!( + "# see https://github.com/Dark-Alex-17/loki/blob/main/config.example.yaml\n\n{config_data}" + ); + + ensure_parent_exists(config_path)?; + std::fs::write(config_path, config_data) + .with_context(|| format!("Failed to write to '{}'", config_path.display()))?; + #[cfg(unix)] + { + use std::os::unix::prelude::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + 
/// Completion candidate for a boolean setting: the opposite of its current value.
fn complete_bool(value: bool) -> Vec<String> {
    vec![(!value).to_string()]
}

/// Completion candidates for an optional boolean setting; `null` clears it.
fn complete_option_bool(value: Option<bool>) -> Vec<String> {
    let candidates: &[&str] = match value {
        Some(true) => &["false", "null"],
        Some(false) => &["true", "null"],
        None => &["true", "false"],
    };
    candidates.iter().map(|v| v.to_string()).collect()
}

/// Wrap plain values as (name, description) completion pairs with no description.
fn map_completion_values<T: ToString>(value: Vec<T>) -> Vec<(String, Option<String>)> {
    value.into_iter().map(|v| (v.to_string(), None)).collect()
}
None => "null".to_string(), + } +} diff --git a/src/config/role.rs b/src/config/role.rs new file mode 100644 index 0000000..ca496a4 --- /dev/null +++ b/src/config/role.rs @@ -0,0 +1,416 @@ +use super::*; + +use crate::client::{Message, MessageContent, MessageRole, Model}; + +use anyhow::Result; +use fancy_regex::Regex; +use rust_embed::Embed; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::sync::LazyLock; + +pub const SHELL_ROLE: &str = "shell"; +pub const EXPLAIN_SHELL_ROLE: &str = "explain-shell"; +pub const CODE_ROLE: &str = "code"; +pub const CREATE_TITLE_ROLE: &str = "create-title"; + +pub const INPUT_PLACEHOLDER: &str = "__INPUT__"; + +#[derive(Embed)] +#[folder = "assets/roles/"] +struct RolesAsset; + +static RE_METADATA: LazyLock = + LazyLock::new(|| Regex::new(r"(?s)-{3,}\s*(.*?)\s*-{3,}\s*(.*)").unwrap()); + +pub trait RoleLike { + fn to_role(&self) -> Role; + fn model(&self) -> &Model; + fn temperature(&self) -> Option; + fn top_p(&self) -> Option; + fn use_tools(&self) -> Option; + fn use_mcp_servers(&self) -> Option; + fn set_model(&mut self, model: Model); + fn set_temperature(&mut self, value: Option); + fn set_top_p(&mut self, value: Option); + fn set_use_tools(&mut self, value: Option); + fn set_use_mcp_servers(&mut self, value: Option); +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct Role { + name: String, + #[serde(default)] + prompt: String, + #[serde( + rename(serialize = "model", deserialize = "model"), + skip_serializing_if = "Option::is_none" + )] + model_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + use_tools: Option, + #[serde(skip_serializing_if = "Option::is_none")] + use_mcp_servers: Option, + + #[serde(skip)] + model: Model, +} + +impl Role { + pub fn new(name: &str, content: &str) -> Self { + let mut metadata = ""; + let 
mut prompt = content.trim();
+        if let Ok(Some(caps)) = RE_METADATA.captures(content) {
+            if let (Some(metadata_value), Some(prompt_value)) = (caps.get(1), caps.get(2)) {
+                metadata = metadata_value.as_str().trim();
+                prompt = prompt_value.as_str().trim();
+            }
+        }
+        let mut prompt = prompt.to_string();
+        interpolate_variables(&mut prompt);
+        let mut role = Self {
+            name: name.to_string(),
+            prompt,
+            ..Default::default()
+        };
+        // Apply the optional YAML front-matter (between `---` fences) as role
+        // settings; unknown keys are silently ignored.
+        if !metadata.is_empty() {
+            if let Ok(value) = serde_yaml::from_str::<Value>(metadata) {
+                if let Some(value) = value.as_object() {
+                    for (key, value) in value {
+                        match key.as_str() {
+                            "model" => role.model_id = value.as_str().map(|v| v.to_string()),
+                            "temperature" => role.temperature = value.as_f64(),
+                            "top_p" => role.top_p = value.as_f64(),
+                            "use_tools" => role.use_tools = value.as_str().map(|v| v.to_string()),
+                            "use_mcp_servers" => {
+                                role.use_mcp_servers = value.as_str().map(|v| v.to_string())
+                            }
+                            _ => (),
+                        }
+                    }
+                }
+            }
+        }
+        role
+    }
+
+    /// Load a built-in role bundled into the binary via `rust_embed`.
+    ///
+    /// Returns an error when the role is unknown or when the embedded file is
+    /// not valid UTF-8. The embedded assets are replaceable at build time, so
+    /// their bytes carry no UTF-8 guarantee — calling
+    /// `str::from_utf8_unchecked` on them (as the previous version did) is
+    /// undefined behavior for a malformed asset; validate instead.
+    pub fn builtin(name: &str) -> Result<Self> {
+        let content = RolesAsset::get(&format!("{name}.md"))
+            .ok_or_else(|| anyhow!("Unknown role `{name}`"))?;
+        let content = std::str::from_utf8(&content.data)
+            .map_err(|_| anyhow!("Role `{name}` is not valid UTF-8"))?;
+        Ok(Role::new(name, content))
+    }
+
+    /// Names of all built-in roles (the embedded `*.md` files, sans extension).
+    pub fn list_builtin_role_names() -> Vec<String> {
+        RolesAsset::iter()
+            .filter_map(|v| v.strip_suffix(".md").map(|v| v.to_string()))
+            .collect()
+    }
+
+    /// All built-in roles; any asset that fails to load is skipped.
+    pub fn list_builtin_roles() -> Vec<Role> {
+        RolesAsset::iter()
+            .filter_map(|v| Role::builtin(&v).ok())
+            .collect()
+    }
+
+    /// A `#` in the role name marks a parameterized (argument-taking) role.
+    pub fn has_args(&self) -> bool {
+        self.name.contains('#')
+    }
+
+    /// Serialize the role back to its on-disk markdown form: optional `---`
+    /// metadata front-matter followed by the prompt text.
+    pub fn export(&self) -> String {
+        let mut metadata = vec![];
+        if let Some(model) = self.model_id() {
+            metadata.push(format!("model: {model}"));
+        }
+        if let Some(temperature) = self.temperature() {
+            metadata.push(format!("temperature: {temperature}"));
+        }
+        if let Some(top_p) = self.top_p() {
+            metadata.push(format!("top_p: {top_p}"));
+        }
+        if let Some(use_tools) = self.use_tools() {
+
metadata.push(format!("use_tools: {use_tools}")); + } + if let Some(use_mcp_servers) = self.use_mcp_servers() { + metadata.push(format!("use_mcp_servers: {use_mcp_servers}")); + } + if metadata.is_empty() { + format!("{}\n", self.prompt) + } else if self.prompt.is_empty() { + format!("---\n{}\n---\n", metadata.join("\n")) + } else { + format!("---\n{}\n---\n\n{}\n", metadata.join("\n"), self.prompt) + } + } + + pub fn save(&mut self, role_name: &str, role_path: &Path, is_repl: bool) -> Result<()> { + ensure_parent_exists(role_path)?; + + let content = self.export(); + std::fs::write(role_path, content).with_context(|| { + format!( + "Failed to write role {} to {}", + self.name, + role_path.display() + ) + })?; + + if is_repl { + println!("✓ Saved role to '{}'.", role_path.display()); + } + + if role_name != self.name { + self.name = role_name.to_string(); + } + + Ok(()) + } + + pub fn sync(&mut self, role_like: &T) { + let model = role_like.model(); + let temperature = role_like.temperature(); + let top_p = role_like.top_p(); + let use_tools = role_like.use_tools(); + let use_mcp_servers = role_like.use_mcp_servers(); + self.batch_set(model, temperature, top_p, use_tools, use_mcp_servers); + } + + pub fn batch_set( + &mut self, + model: &Model, + temperature: Option, + top_p: Option, + use_tools: Option, + use_mcp_servers: Option, + ) { + self.set_model(model.clone()); + if temperature.is_some() { + self.set_temperature(temperature); + } + if top_p.is_some() { + self.set_top_p(top_p); + } + if use_tools.is_some() { + self.set_use_tools(use_tools); + } + if use_mcp_servers.is_some() { + self.set_use_mcp_servers(use_mcp_servers); + } + } + + pub fn is_derived(&self) -> bool { + self.name.is_empty() + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn model_id(&self) -> Option<&str> { + self.model_id.as_deref() + } + + pub fn prompt(&self) -> &str { + &self.prompt + } + + pub fn is_empty_prompt(&self) -> bool { + self.prompt.is_empty() + } + + pub fn 
is_embedded_prompt(&self) -> bool { + self.prompt.contains(INPUT_PLACEHOLDER) + } + + pub fn echo_messages(&self, input: &Input) -> String { + let input_markdown = input.render(); + if self.is_empty_prompt() { + input_markdown + } else if self.is_embedded_prompt() { + self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown) + } else { + format!("{}\n\n{}", self.prompt, input_markdown) + } + } + + pub fn build_messages(&self, input: &Input) -> Vec { + let mut content = input.message_content(); + let mut messages = if self.is_empty_prompt() { + vec![Message::new(MessageRole::User, content)] + } else if self.is_embedded_prompt() { + content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v)); + vec![Message::new(MessageRole::User, content)] + } else { + let mut messages = vec![]; + let (system, cases) = parse_structure_prompt(&self.prompt); + if !system.is_empty() { + messages.push(Message::new( + MessageRole::System, + MessageContent::Text(system.to_string()), + )); + } + if !cases.is_empty() { + messages.extend(cases.into_iter().flat_map(|(i, o)| { + vec![ + Message::new(MessageRole::User, MessageContent::Text(i.to_string())), + Message::new(MessageRole::Assistant, MessageContent::Text(o.to_string())), + ] + })); + } + messages.push(Message::new(MessageRole::User, content)); + messages + }; + if let Some(text) = input.continue_output() { + messages.push(Message::new( + MessageRole::Assistant, + MessageContent::Text(text.into()), + )); + } + messages + } +} + +impl RoleLike for Role { + fn to_role(&self) -> Role { + self.clone() + } + + fn model(&self) -> &Model { + &self.model + } + + fn temperature(&self) -> Option { + self.temperature + } + + fn top_p(&self) -> Option { + self.top_p + } + + fn use_tools(&self) -> Option { + self.use_tools.clone() + } + + fn use_mcp_servers(&self) -> Option { + self.use_mcp_servers.clone() + } + + fn set_model(&mut self, model: Model) { + if !self.model().id().is_empty() { + self.model_id = 
Some(model.id().to_string());
+        }
+        self.model = model;
+    }
+
+    fn set_temperature(&mut self, value: Option<f64>) {
+        self.temperature = value;
+    }
+
+    fn set_top_p(&mut self, value: Option<f64>) {
+        self.top_p = value;
+    }
+
+    fn set_use_tools(&mut self, value: Option<String>) {
+        self.use_tools = value;
+    }
+
+    fn set_use_mcp_servers(&mut self, value: Option<String>) {
+        self.use_mcp_servers = value;
+    }
+}
+
+// Split a role prompt written in the few-shot format
+//
+//     <system text>
+//     ### INPUT: <example input>
+//     ### OUTPUT: <example output>   (pairs may repeat)
+//
+// into the system text plus the (input, output) example pairs.
+// If the markers do not pair up evenly, the entire prompt is returned
+// unchanged as the system text with no examples (see the final fallback).
+fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
+    let mut text = prompt;
+    let mut search_input = true;
+    let mut system = None;
+    let mut parts = vec![];
+    loop {
+        // Alternate between scanning for the INPUT and OUTPUT markers.
+        let search = if search_input {
+            "### INPUT:"
+        } else {
+            "### OUTPUT:"
+        };
+        match text.find(search) {
+            Some(idx) => {
+                // Text before the first marker is the system prompt; every
+                // later segment is collected as an example part.
+                if system.is_none() {
+                    system = Some(&text[..idx])
+                } else {
+                    parts.push(&text[..idx])
+                }
+                search_input = !search_input;
+                text = &text[(idx + search.len())..];
+            }
+            None => {
+                // No more markers: the trailing text (if any) is the last part.
+                if !text.is_empty() {
+                    if system.is_none() {
+                        system = Some(text)
+                    } else {
+                        parts.push(text)
+                    }
+                }
+                break;
+            }
+        }
+    }
+    let parts_len = parts.len();
+    // Only a non-empty, even number of parts forms valid (input, output)
+    // pairs; pair them up positionally (0-1, 2-3, ...).
+    if parts_len > 0 && parts_len % 2 == 0 {
+        let cases: Vec<(&str, &str)> = parts
+            .iter()
+            .step_by(2)
+            .zip(parts.iter().skip(1).step_by(2))
+            .map(|(i, o)| (i.trim(), o.trim()))
+            .collect();
+        let system = system.map(|v| v.trim()).unwrap_or_default();
+        return (system, cases);
+    }
+
+    // Fallback (including an unmatched INPUT without OUTPUT): treat the whole
+    // prompt, untrimmed, as the system message.
+    (prompt, vec![])
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_structure_prompt1() {
+        let prompt = r#"
+System message
+### INPUT:
+Input 1
+### OUTPUT:
+Output 1
+"#;
+        assert_eq!(
+            parse_structure_prompt(prompt),
+            ("System message", vec![("Input 1", "Output 1")])
+        );
+    }
+
+    #[test]
+    fn test_parse_structure_prompt2() {
+        let prompt = r#"
+### INPUT:
+Input 1
+### OUTPUT:
+Output 1
+"#;
+        assert_eq!(
+            parse_structure_prompt(prompt),
+            ("", vec![("Input 1", "Output 1")])
+        );
+    }
+
+    #[test]
+    fn test_parse_structure_prompt3() {
+        let prompt = r#"
+System message
+### INPUT:
+Input 1
+"#;
+
assert_eq!(parse_structure_prompt(prompt), (prompt, vec![])); + } +} diff --git a/src/config/session.rs b/src/config/session.rs new file mode 100644 index 0000000..0190251 --- /dev/null +++ b/src/config/session.rs @@ -0,0 +1,659 @@ +use super::input::*; +use super::*; + +use crate::client::{Message, MessageContent, MessageRole}; +use crate::render::MarkdownRender; + +use anyhow::{bail, Context, Result}; +use fancy_regex::Regex; +use inquire::{validator::Validation, Confirm, Text}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::HashMap; +use std::fs::{read_to_string, write}; +use std::path::Path; +use std::sync::LazyLock; + +static RE_AUTONAME_PREFIX: LazyLock = LazyLock::new(|| Regex::new(r"\d{8}T\d{6}-").unwrap()); + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct Session { + #[serde(rename(serialize = "model", deserialize = "model"))] + model_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + use_tools: Option, + #[serde(skip_serializing_if = "Option::is_none")] + use_mcp_servers: Option, + #[serde(skip_serializing_if = "Option::is_none")] + save_session: Option, + #[serde(skip_serializing_if = "Option::is_none")] + compress_threshold: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + role_name: Option, + #[serde(default, skip_serializing_if = "IndexMap::is_empty")] + agent_variables: AgentVariables, + #[serde(default, skip_serializing_if = "String::is_empty")] + agent_instructions: String, + + #[serde(default, skip_serializing_if = "Vec::is_empty")] + compressed_messages: Vec, + messages: Vec, + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + data_urls: HashMap, + + #[serde(skip)] + model: Model, + #[serde(skip)] + role_prompt: String, + #[serde(skip)] + name: String, + #[serde(skip)] + path: Option, + #[serde(skip)] + 
dirty: bool, + #[serde(skip)] + save_session_this_time: bool, + #[serde(skip)] + compressing: bool, + #[serde(skip)] + autoname: Option, + #[serde(skip)] + tokens: usize, +} + +impl Session { + pub fn new(config: &Config, name: &str) -> Self { + let role = config.extract_role(); + let mut session = Self { + name: name.to_string(), + save_session: config.save_session, + ..Default::default() + }; + session.set_role(role); + session.dirty = false; + session + } + + pub fn load(config: &Config, name: &str, path: &Path) -> Result { + let content = read_to_string(path) + .with_context(|| format!("Failed to load session {} at {}", name, path.display()))?; + let mut session: Self = + serde_yaml::from_str(&content).with_context(|| format!("Invalid session {name}"))?; + + session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?; + + if let Some(autoname) = name.strip_prefix("_/") { + session.name = TEMP_SESSION_NAME.to_string(); + session.path = None; + if let Ok(true) = RE_AUTONAME_PREFIX.is_match(autoname) { + session.autoname = Some(AutoName::new(autoname[16..].to_string())); + } + } else { + session.name = name.to_string(); + session.path = Some(path.display().to_string()); + } + + if let Some(role_name) = &session.role_name { + if let Ok(role) = config.retrieve_role(role_name) { + session.role_prompt = role.prompt().to_string(); + } + } + + session.update_tokens(); + + Ok(session) + } + + pub fn is_empty(&self) -> bool { + self.messages.is_empty() && self.compressed_messages.is_empty() + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn role_name(&self) -> Option<&str> { + self.role_name.as_deref() + } + + pub fn dirty(&self) -> bool { + self.dirty + } + + pub fn save_session(&self) -> Option { + self.save_session + } + + pub fn tokens(&self) -> usize { + self.tokens + } + + pub fn update_tokens(&mut self) { + self.tokens = self.model().total_tokens(&self.messages); + } + + pub fn has_user_messages(&self) -> bool { + 
self.messages.iter().any(|v| v.role.is_user()) + } + + pub fn user_messages_len(&self) -> usize { + self.messages.iter().filter(|v| v.role.is_user()).count() + } + + pub fn export(&self) -> Result { + let mut data = json!({ + "path": self.path, + "model": self.model().id(), + }); + if let Some(temperature) = self.temperature() { + data["temperature"] = temperature.into(); + } + if let Some(top_p) = self.top_p() { + data["top_p"] = top_p.into(); + } + if let Some(use_tools) = self.use_tools() { + data["use_tools"] = use_tools.into(); + } + if let Some(use_mcp_servers) = self.use_mcp_servers() { + data["use_mcp_servers"] = use_mcp_servers.into(); + } + if let Some(save_session) = self.save_session() { + data["save_session"] = save_session.into(); + } + let (tokens, percent) = self.tokens_usage(); + data["total_tokens"] = tokens.into(); + if let Some(max_input_tokens) = self.model().max_input_tokens() { + data["max_input_tokens"] = max_input_tokens.into(); + } + if percent != 0.0 { + data["total/max"] = format!("{percent}%").into(); + } + data["messages"] = json!(self.messages); + + let output = serde_yaml::to_string(&data) + .with_context(|| format!("Unable to show info about session '{}'", &self.name))?; + Ok(output) + } + + pub fn render( + &self, + render: &mut MarkdownRender, + agent_info: &Option<(String, Vec)>, + ) -> Result { + let mut items = vec![]; + + if let Some(path) = &self.path { + items.push(("path", path.to_string())); + } + + if let Some(autoname) = self.autoname() { + items.push(("autoname", autoname.to_string())); + } + + items.push(("model", self.model().id())); + + if let Some(temperature) = self.temperature() { + items.push(("temperature", temperature.to_string())); + } + if let Some(top_p) = self.top_p() { + items.push(("top_p", top_p.to_string())); + } + + if let Some(use_tools) = self.use_tools() { + items.push(("use_tools", use_tools)); + } + + if let Some(use_mcp_servers) = self.use_mcp_servers() { + items.push(("use_mcp_servers", 
use_mcp_servers)); + } + + if let Some(save_session) = self.save_session() { + items.push(("save_session", save_session.to_string())); + } + + if let Some(compress_threshold) = self.compress_threshold { + items.push(("compress_threshold", compress_threshold.to_string())); + } + + if let Some(max_input_tokens) = self.model().max_input_tokens() { + items.push(("max_input_tokens", max_input_tokens.to_string())); + } + + let mut lines: Vec = items + .iter() + .map(|(name, value)| format!("{name:<20}{value}")) + .collect(); + + lines.push(String::new()); + + if !self.is_empty() { + let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string()); + + for message in &self.messages { + match message.role { + MessageRole::System => { + lines.push( + render + .render(&message.content.render_input(resolve_url_fn, agent_info)), + ); + } + MessageRole::Assistant => { + if let MessageContent::Text(text) = &message.content { + lines.push(render.render(text)); + } + lines.push("".into()); + } + MessageRole::User => { + lines.push(format!( + ">> {}", + message.content.render_input(resolve_url_fn, agent_info) + )); + } + MessageRole::Tool => { + lines.push(message.content.render_input(resolve_url_fn, agent_info)); + } + } + } + } + + Ok(lines.join("\n")) + } + + pub fn tokens_usage(&self) -> (usize, f32) { + let tokens = self.tokens(); + let max_input_tokens = self.model().max_input_tokens().unwrap_or_default(); + let percent = if max_input_tokens == 0 { + 0.0 + } else { + let percent = tokens as f32 / max_input_tokens as f32 * 100.0; + (percent * 100.0).round() / 100.0 + }; + (tokens, percent) + } + + pub fn set_role(&mut self, role: Role) { + self.model_id = role.model().id(); + self.temperature = role.temperature(); + self.top_p = role.top_p(); + self.use_tools = role.use_tools(); + self.use_mcp_servers = role.use_mcp_servers(); + self.model = role.model().clone(); + self.role_name = convert_option_string(role.name()); + self.role_prompt = 
role.prompt().to_string();
+        self.dirty = true;
+        self.update_tokens();
+    }
+
+    // NOTE(review): unlike the setters below, this does not mark the session
+    // dirty — confirm that dropping the role is intentionally not persisted
+    // on its own.
+    pub fn clear_role(&mut self) {
+        self.role_name = None;
+        self.role_prompt.clear();
+    }
+
+    // Replace the role state with the agent's interpolated instructions and
+    // snapshot its variables. NOTE(review): also leaves `dirty` untouched.
+    pub fn sync_agent(&mut self, agent: &Agent) {
+        self.role_name = None;
+        self.role_prompt = agent.interpolated_instructions();
+        self.agent_variables = agent.variables().clone();
+        self.agent_instructions = self.role_prompt.clone();
+    }
+
+    pub fn agent_variables(&self) -> &AgentVariables {
+        &self.agent_variables
+    }
+
+    pub fn agent_instructions(&self) -> &str {
+        &self.agent_instructions
+    }
+
+    // Setters only flip `dirty` when the value actually changes, so untouched
+    // sessions are not re-saved.
+    pub fn set_save_session(&mut self, value: Option<bool>) {
+        if self.save_session != value {
+            self.save_session = value;
+            self.dirty = true;
+        }
+    }
+
+    // One-shot flag forcing a save on exit regardless of `save_session`.
+    pub fn set_save_session_this_time(&mut self) {
+        self.save_session_this_time = true;
+    }
+
+    pub fn set_compress_threshold(&mut self, value: Option<usize>) {
+        if self.compress_threshold != value {
+            self.compress_threshold = value;
+            self.dirty = true;
+        }
+    }
+
+    // Whether the session's token count has crossed the compression
+    // threshold. The per-session threshold overrides the global one; a
+    // threshold of 0 disables compression, and no new compression starts
+    // while one is already in flight.
+    pub fn need_compress(&self, global_compress_threshold: usize) -> bool {
+        if self.compressing {
+            return false;
+        }
+        let threshold = self.compress_threshold.unwrap_or(global_compress_threshold);
+        if threshold < 1 {
+            return false;
+        }
+        self.tokens() > threshold
+    }
+
+    pub fn compressing(&self) -> bool {
+        self.compressing
+    }
+
+    pub fn set_compressing(&mut self, compressing: bool) {
+        self.compressing = compressing;
+    }
+
+    // Archive all current messages into `compressed_messages` and replace
+    // them with a single system message containing `prompt` (the summary).
+    // If the conversation already started with a non-empty system prompt, it
+    // is prepended to the summary so that instruction is not lost.
+    pub fn compress(&mut self, mut prompt: String) {
+        if let Some(system_prompt) = self.messages.first().and_then(|v| {
+            if MessageRole::System == v.role {
+                let content = v.content.to_text();
+                if !content.is_empty() {
+                    return Some(content);
+                }
+            }
+            None
+        }) {
+            prompt = format!("{system_prompt}\n\n{prompt}",);
+        }
+        self.compressed_messages.append(&mut self.messages);
+        self.messages.push(Message::new(
+            MessageRole::System,
+            MessageContent::Text(prompt),
+        ));
+        self.dirty = true;
+        self.update_tokens();
+    }
+
+    pub fn need_autoname(&self) -> bool {
+
self.autoname.as_ref().map(|v| v.need()).unwrap_or_default() + } + + pub fn set_autonaming(&mut self, naming: bool) { + if let Some(v) = self.autoname.as_mut() { + v.naming = naming; + } + } + + pub fn chat_history_for_autonaming(&self) -> Option { + self.autoname.as_ref().and_then(|v| v.chat_history.clone()) + } + + pub fn autoname(&self) -> Option<&str> { + self.autoname.as_ref().and_then(|v| v.name.as_deref()) + } + + pub fn set_autoname(&mut self, value: &str) { + let name = value + .chars() + .map(|v| if v.is_alphanumeric() { v } else { '-' }) + .collect(); + self.autoname = Some(AutoName::new(name)); + } + + pub fn exit(&mut self, session_dir: &Path, is_repl: bool) -> Result<()> { + let mut save_session = self.save_session(); + if self.save_session_this_time { + save_session = Some(true); + } + if self.dirty && save_session != Some(false) { + let mut session_dir = session_dir.to_path_buf(); + let mut session_name = self.name().to_string(); + if save_session.is_none() { + if !is_repl { + return Ok(()); + } + let ans = Confirm::new("Save session?").with_default(false).prompt()?; + if !ans { + return Ok(()); + } + if session_name == TEMP_SESSION_NAME { + session_name = Text::new("Session name:") + .with_validator(|input: &str| { + let input = input.trim(); + if input.is_empty() { + Ok(Validation::Invalid("This name is required".into())) + } else if input == TEMP_SESSION_NAME { + Ok(Validation::Invalid("This name is reserved".into())) + } else { + Ok(Validation::Valid) + } + }) + .prompt()?; + } + } else if save_session == Some(true) && session_name == TEMP_SESSION_NAME { + session_dir = session_dir.join("_"); + ensure_parent_exists(&session_dir).with_context(|| { + format!("Failed to create directory '{}'", session_dir.display()) + })?; + + let now = chrono::Local::now(); + session_name = now.format("%Y%m%dT%H%M%S").to_string(); + if let Some(autoname) = self.autoname() { + session_name = format!("{session_name}-{autoname}") + } + } + let session_path = 
session_dir.join(format!("{session_name}.yaml")); + self.save(&session_name, &session_path, is_repl)?; + } + Ok(()) + } + + pub fn save(&mut self, session_name: &str, session_path: &Path, is_repl: bool) -> Result<()> { + ensure_parent_exists(session_path)?; + + self.path = Some(session_path.display().to_string()); + + let content = serde_yaml::to_string(&self) + .with_context(|| format!("Failed to serde session '{}'", self.name))?; + write(session_path, content).with_context(|| { + format!( + "Failed to write session '{}' to '{}'", + self.name, + session_path.display() + ) + })?; + + if is_repl { + println!("✓ Saved the session to '{}'.", session_path.display()); + } + + if self.name() != session_name { + self.name = session_name.to_string() + } + + self.dirty = false; + + Ok(()) + } + + pub fn guard_empty(&self) -> Result<()> { + if !self.is_empty() { + bail!("Cannot perform this operation because the session has messages, please `.empty session` first."); + } + Ok(()) + } + + pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> { + if input.continue_output().is_some() { + if let Some(message) = self.messages.last_mut() { + if let MessageContent::Text(text) = &mut message.content { + *text = format!("{text}{output}"); + } + } + } else if input.regenerate() { + if let Some(message) = self.messages.last_mut() { + if let MessageContent::Text(text) = &mut message.content { + *text = output.to_string(); + } + } + } else { + if self.messages.is_empty() { + if self.name == TEMP_SESSION_NAME && self.save_session == Some(true) { + let raw_input = input.raw(); + let chat_history = format!("USER: {raw_input}\nASSISTANT: {output}\n"); + self.autoname = Some(AutoName::new_from_chat_history(chat_history)); + } + self.messages.extend(input.role().build_messages(input)); + } else { + self.messages + .push(Message::new(MessageRole::User, input.message_content())); + } + self.data_urls.extend(input.data_urls()); + if let Some(tool_calls) = input.tool_calls() { + 
self.messages.push(Message::new( + MessageRole::Tool, + MessageContent::ToolCalls(tool_calls.clone()), + )) + } + self.messages.push(Message::new( + MessageRole::Assistant, + MessageContent::Text(output.to_string()), + )); + } + self.dirty = true; + self.update_tokens(); + Ok(()) + } + + pub fn clear_messages(&mut self) { + self.messages.clear(); + self.compressed_messages.clear(); + self.data_urls.clear(); + self.autoname = None; + self.dirty = true; + self.update_tokens(); + } + + pub fn echo_messages(&self, input: &Input) -> String { + let messages = self.build_messages(input); + serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) + } + + pub fn build_messages(&self, input: &Input) -> Vec { + let mut messages = self.messages.clone(); + if input.continue_output().is_some() { + return messages; + } else if input.regenerate() { + while let Some(last) = messages.last() { + if !last.role.is_user() { + messages.pop(); + } else { + break; + } + } + return messages; + } + let mut need_add_msg = true; + let len = messages.len(); + if len == 0 { + messages = input.role().build_messages(input); + need_add_msg = false; + } else if len == 1 && self.compressed_messages.len() >= 2 { + if let Some(index) = self + .compressed_messages + .iter() + .rposition(|v| v.role == MessageRole::User) + { + messages.extend(self.compressed_messages[index..].to_vec()); + } + } + if need_add_msg { + messages.push(Message::new(MessageRole::User, input.message_content())); + } + messages + } +} + +impl RoleLike for Session { + fn to_role(&self) -> Role { + let role_name = self.role_name.as_deref().unwrap_or_default(); + let mut role = Role::new(role_name, &self.role_prompt); + role.sync(self); + role + } + + fn model(&self) -> &Model { + &self.model + } + + fn temperature(&self) -> Option { + self.temperature + } + + fn top_p(&self) -> Option { + self.top_p + } + + fn use_tools(&self) -> Option { + self.use_tools.clone() + } + + fn use_mcp_servers(&self) -> 
Option { + self.use_mcp_servers.clone() + } + + fn set_model(&mut self, model: Model) { + if self.model().id() != model.id() { + self.model_id = model.id(); + self.model = model; + self.dirty = true; + self.update_tokens(); + } + } + + fn set_temperature(&mut self, value: Option) { + if self.temperature != value { + self.temperature = value; + self.dirty = true; + } + } + + fn set_top_p(&mut self, value: Option) { + if self.top_p != value { + self.top_p = value; + self.dirty = true; + } + } + + fn set_use_tools(&mut self, value: Option) { + if self.use_tools != value { + self.use_tools = value; + self.dirty = true; + } + } + + fn set_use_mcp_servers(&mut self, value: Option) { + if self.use_mcp_servers != value { + self.use_mcp_servers = value; + self.dirty = true; + } + } +} + +#[derive(Debug, Clone, Default)] +struct AutoName { + naming: bool, + chat_history: Option, + name: Option, +} + +impl AutoName { + pub fn new(name: String) -> Self { + Self { + name: Some(name), + ..Default::default() + } + } + pub fn new_from_chat_history(chat_history: String) -> Self { + Self { + chat_history: Some(chat_history), + ..Default::default() + } + } + pub fn need(&self) -> bool { + !self.naming && self.chat_history.is_some() && self.name.is_none() + } +} diff --git a/src/function.rs b/src/function.rs new file mode 100644 index 0000000..0d61fb0 --- /dev/null +++ b/src/function.rs @@ -0,0 +1,825 @@ +use crate::{ + config::{Agent, Config, GlobalConfig}, + utils::*, +}; + +use crate::mcp::{MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX}; +use crate::parsers::{bash, python}; +use anyhow::{anyhow, bail, Context, Result}; +use indexmap::IndexMap; +use indoc::formatdoc; +use rust_embed::Embed; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::ffi::OsStr; +use std::fs::File; +use std::io::Write; +use std::{ + collections::{HashMap, HashSet}, + env, fs, io, + path::{Path, PathBuf}, +}; +use strum_macros::AsRefStr; + 
+#[derive(Embed)] +#[folder = "assets/functions/"] +struct FunctionAsset; + +#[cfg(windows)] +const PATH_SEP: &str = ";"; +#[cfg(not(windows))] +const PATH_SEP: &str = ":"; + +#[derive(AsRefStr)] +enum BinaryType { + Tool, + Agent, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsRefStr)] +enum Language { + Bash, + Python, + Javascript, + Unsupported, +} + +impl From<&String> for Language { + fn from(s: &String) -> Self { + match s.to_lowercase().as_str() { + "sh" => Language::Bash, + "py" => Language::Python, + "js" => Language::Javascript, + _ => Language::Unsupported, + } + } +} + +#[cfg_attr(not(windows), expect(dead_code))] +impl Language { + fn to_cmd(self) -> &'static str { + match self { + Language::Bash => "bash", + Language::Python => "python", + Language::Javascript => "node", + Language::Unsupported => "sh", + } + } + + fn to_extension(self) -> &'static str { + match self { + Language::Bash => "sh", + Language::Python => "py", + Language::Javascript => "js", + _ => "sh", + } + } +} + +pub async fn eval_tool_calls( + config: &GlobalConfig, + mut calls: Vec, +) -> Result> { + let mut output = vec![]; + if calls.is_empty() { + return Ok(output); + } + calls = ToolCall::dedup(calls); + if calls.is_empty() { + bail!("The request was aborted because an infinite loop of function calls was detected.") + } + let mut is_all_null = true; + for call in calls { + let mut result = call.eval(config).await?; + if result.is_null() { + result = json!("DONE"); + } else { + is_all_null = false; + } + output.push(ToolResult::new(call, result)); + } + if is_all_null { + output = vec![]; + } + Ok(output) +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolResult { + pub call: ToolCall, + pub output: Value, +} + +impl ToolResult { + pub fn new(call: ToolCall, output: Value) -> Self { + Self { call, output } + } +} + +#[derive(Debug, Clone, Default)] +pub struct Functions { + declarations: Vec, +} + +impl Functions { + pub fn init() -> Result { + info!( + 
"Initializing global functions from {}", + Config::global_tools_file().display() + ); + + let declarations = Self { + declarations: Self::build_global_tool_declarations_from_path( + &Config::global_tools_file(), + )?, + }; + + info!( + "Building global function binaries in {}", + Config::functions_bin_dir().display() + ); + Self::build_global_function_binaries_from_path(Config::global_tools_file())?; + + Ok(declarations) + } + + pub fn init_agent(name: &str, global_tools: &[String]) -> Result { + let global_tools_declarations = if !global_tools.is_empty() { + let enabled_tools = global_tools.join("\n"); + info!("Loading global tools for agent: {name}: {enabled_tools}"); + let tools_declarations = Self::build_global_tool_declarations(&enabled_tools)?; + + info!( + "Building global function binaries required by agent: {name} in {}", + Config::functions_bin_dir().display() + ); + Self::build_global_function_binaries(&enabled_tools)?; + tools_declarations + } else { + debug!("No global tools found for agent: {}", name); + Vec::new() + }; + let agent_script_declarations = match Config::agent_functions_file(name) { + Ok(path) if path.exists() => { + info!( + "Loading functions script for agent: {name} from {}", + path.display() + ); + let script_declarations = Self::generate_declarations(&path)?; + debug!("agent_declarations: {:#?}", script_declarations); + + info!( + "Building function binary for agent: {name} in {}", + Config::agent_bin_dir(name).display() + ); + Self::build_agent_tool_binaries(name)?; + script_declarations + } + _ => { + debug!("No functions script found for agent: {}", name); + Vec::new() + } + }; + let declarations = [global_tools_declarations, agent_script_declarations].concat(); + + Ok(Self { declarations }) + } + + pub fn find(&self, name: &str) -> Option<&FunctionDeclaration> { + self.declarations.iter().find(|v| v.name == name) + } + + pub fn contains(&self, name: &str) -> bool { + self.declarations.iter().any(|v| v.name == name) + } + + pub fn 
declarations(&self) -> &[FunctionDeclaration] {
        &self.declarations
    }

    pub fn is_empty(&self) -> bool {
        self.declarations.is_empty()
    }

    /// True if any MCP meta function (invoke/list) is currently registered.
    pub fn has_mcp_functions(&self) -> bool {
        self.declarations.iter().any(|d| {
            d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX)
                || d.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX)
        })
    }

    /// Remove every MCP meta function (invoke/list) from the set.
    pub fn clear_mcp_meta_functions(&mut self) {
        self.declarations.retain(|d| {
            !d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX)
                && !d.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX)
        });
    }

    /// Register one invoke and one list meta function per MCP server.
    ///
    /// Fix: the invoke function's description previously told the model to
    /// call the *invoke* function first to discover tool names; it now points
    /// at the matching *list* function.
    pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec<String>) {
        // Parameter schema shared by every invoke meta function.
        let mut invoke_function_properties = IndexMap::new();
        invoke_function_properties.insert(
            "server".to_string(),
            JsonSchema {
                type_value: Some("string".to_string()),
                ..Default::default()
            },
        );
        invoke_function_properties.insert(
            "tool".to_string(),
            JsonSchema {
                type_value: Some("string".to_string()),
                ..Default::default()
            },
        );
        invoke_function_properties.insert(
            "arguments".to_string(),
            JsonSchema {
                type_value: Some("object".to_string()),
                ..Default::default()
            },
        );

        for server in mcp_servers {
            let invoke_function_name = format!("{}_{server}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX);
            let list_function_name = format!("{}_{server}", MCP_LIST_META_FUNCTION_NAME_PREFIX);
            let invoke_function_declaration = FunctionDeclaration {
                name: invoke_function_name.clone(),
                description: formatdoc!(
                    r#"
                    Invoke the specified tool on the {server} MCP server. Always call {list_function_name} first to find the
                    correct names of tools before calling '{invoke_function_name}'.
                    "#
                ),
                parameters: JsonSchema {
                    type_value: Some("object".to_string()),
                    properties: Some(invoke_function_properties.clone()),
                    // "arguments" stays optional: some tools take none.
                    required: Some(vec!["server".to_string(), "tool".to_string()]),
                    ..Default::default()
                },
                agent: false,
            };
            let list_functions_declaration = FunctionDeclaration {
                name: list_function_name,
                description: format!("List all the available tools for the {server} MCP server"),
                parameters: JsonSchema::default(),
                agent: false,
            };
            self.declarations.push(invoke_function_declaration);
            self.declarations.push(list_functions_declaration);
        }
    }

    /// Parse declarations for every tool listed (one file name per line) in
    /// the global tools manifest content.
    ///
    /// Fix: blank/whitespace-only lines are now skipped alongside `#`
    /// comments; previously an empty line produced a bogus path join and a
    /// confusing downstream error.
    fn build_global_tool_declarations(enabled_tools: &str) -> Result<Vec<FunctionDeclaration>> {
        let global_tools_directory = Config::global_tools_dir();
        let mut function_declarations = Vec::new();

        for line in enabled_tools.lines() {
            let line = line.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }

            let declaration = Self::generate_declarations(&global_tools_directory.join(line))?;
            function_declarations.extend(declaration);
        }

        Ok(function_declarations)
    }

    fn build_global_tool_declarations_from_path(
        tools_txt_path: &PathBuf,
    ) -> Result<Vec<FunctionDeclaration>> {
        let enabled_tools = fs::read_to_string(tools_txt_path)
            .with_context(|| format!("failed to load functions at {}", tools_txt_path.display()))?;

        Self::build_global_tool_declarations(&enabled_tools)
    }

    /// Parse a tool definition file into declarations, dispatching on the
    /// file extension.
    fn generate_declarations(tools_file_path: &Path) -> Result<Vec<FunctionDeclaration>> {
        info!(
            "Loading tool definitions from {}",
            tools_file_path.display()
        );
        let file_name = tools_file_path
            .file_stem()
            .and_then(|s| s.to_str())
            .ok_or_else(|| {
                anyhow::format_err!("Unable to extract file name from path: {tools_file_path:?}")
            })?;

        match File::open(tools_file_path) {
            Ok(tool_file) => {
                let language = Language::from(
                    &tools_file_path
                        .extension()
                        .and_then(OsStr::to_str)
                        .map(|s| s.to_lowercase())
                        .ok_or_else(|| {
                            anyhow!("Unable to extract language from tool file: {file_name}")
                        })?,
                );

                match
language {
                    Language::Bash => {
                        bash::generate_bash_declarations(tool_file, tools_file_path, file_name)
                    }
                    Language::Python => python::generate_python_declarations(
                        tool_file,
                        file_name,
                        tools_file_path.parent(),
                    ),
                    Language::Unsupported => {
                        bail!("Unsupported tool file extension: {}", language.as_ref())
                    }
                    _ => bail!("Unsupported tool language: {}", language.as_ref()),
                }
            }
            Err(err) if err.kind() == io::ErrorKind::NotFound => {
                bail!(
                    "Tool definition file not found: {}",
                    tools_file_path.display()
                );
            }
            Err(err) => bail!("Unable to open tool definition file. {}", err),
        }
    }

    /// (Re)build the launcher binaries for every enabled global tool.
    ///
    /// The bin directory is cleared first so stale launchers for tools that
    /// were removed from the manifest do not linger.
    ///
    /// Fix: blank/whitespace-only manifest lines are now skipped alongside
    /// `#` comments; previously an empty line failed with a confusing
    /// "Unable to extract file extension" error.
    fn build_global_function_binaries(enabled_tools: &str) -> Result<()> {
        let bin_dir = Config::functions_bin_dir();
        if !bin_dir.exists() {
            fs::create_dir_all(&bin_dir)?;
        }
        info!(
            "Clearing existing function binaries in {}",
            bin_dir.display()
        );
        clear_dir(&bin_dir)?;

        for line in enabled_tools.lines() {
            let line = line.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }

            let tool_path = Path::new(line);
            let language = Language::from(
                &tool_path
                    .extension()
                    .and_then(OsStr::to_str)
                    .map(|s| s.to_lowercase())
                    .ok_or_else(|| {
                        anyhow::format_err!("Unable to extract file extension from path: {line:?}")
                    })?,
            );
            let binary_name = tool_path
                .file_stem()
                .and_then(OsStr::to_str)
                .ok_or_else(|| {
                    anyhow::format_err!("Unable to extract file name from path: {line:?}")
                })?;

            if language == Language::Unsupported {
                bail!("Unsupported tool file extension: {}", language.as_ref());
            }

            Self::build_binaries(binary_name, language, BinaryType::Tool)?;
        }

        Ok(())
    }

    fn build_global_function_binaries_from_path(tools_txt_path: PathBuf) -> Result<()> {
        let enabled_tools = fs::read_to_string(&tools_txt_path)
            .with_context(|| format!("failed to load functions at {}", tools_txt_path.display()))?;

        Self::build_global_function_binaries(&enabled_tools)
    }

    /// Prepare the agent's bin directory and build its launcher binary.
    fn build_agent_tool_binaries(name: &str) -> Result<()> {
        let agent_bin_directory =
Config::agent_bin_dir(name);
        if !agent_bin_directory.exists() {
            debug!(
                "Creating agent bin directory: {}",
                agent_bin_directory.display()
            );
            fs::create_dir_all(&agent_bin_directory)?;
        } else {
            debug!(
                "Clearing existing agent bin directory: {}",
                agent_bin_directory.display()
            );
            clear_dir(&agent_bin_directory)?;
        }

        let language = Language::from(
            &Config::agent_functions_file(name)?
                .extension()
                .and_then(OsStr::to_str)
                .map(|s| s.to_lowercase())
                .ok_or_else(|| {
                    anyhow::format_err!("Unable to extract file extension from path: {name:?}")
                })?,
        );

        if language == Language::Unsupported {
            bail!("Unsupported tool file extension: {}", language.as_ref());
        }

        Self::build_binaries(name, language, BinaryType::Agent)
    }

    /// Windows variant: emit a `run-<name>.<ext>` wrapper script plus a
    /// `<name>.cmd` launcher that prepares the environment and delegates to
    /// the language runtime.
    #[cfg(windows)]
    fn build_binaries(
        binary_name: &str,
        language: Language,
        binary_type: BinaryType,
    ) -> Result<()> {
        use native::runtime;
        let (binary_file, binary_script_file) = match binary_type {
            BinaryType::Tool => (
                Config::functions_bin_dir().join(format!("{binary_name}.cmd")),
                Config::functions_bin_dir()
                    .join(format!("run-{binary_name}.{}", language.to_extension())),
            ),
            BinaryType::Agent => (
                Config::agent_bin_dir(binary_name).join(format!("{binary_name}.cmd")),
                Config::agent_bin_dir(binary_name)
                    .join(format!("run-{binary_name}.{}", language.to_extension())),
            ),
        };
        info!(
            "Building binary runner for function: {} ({})",
            binary_name,
            binary_script_file.display(),
        );
        let embedded_file = FunctionAsset::get(&format!(
            "scripts/run-{}.{}",
            binary_type.as_ref().to_lowercase(),
            language.to_extension()
        ))
        .ok_or_else(|| {
            anyhow!(
                "Failed to load embedded script for run-{}.{}",
                binary_type.as_ref().to_lowercase(),
                language.to_extension()
            )
        })?;
        // Fix: validate UTF-8 instead of `from_utf8_unchecked`, which is UB
        // if an embedded asset is ever not valid UTF-8.
        let content_template = std::str::from_utf8(&embedded_file.data)
            .context("Embedded script is not valid UTF-8")?;
        let content = match binary_type {
            BinaryType::Tool => content_template.replace("{function_name}", binary_name),
            BinaryType::Agent => content_template.replace("{agent_name}", binary_name),
        }
        .replace("{config_dir}", &Config::config_dir().to_string_lossy());
        if binary_script_file.exists() {
            fs::remove_file(&binary_script_file)?;
        }
        let mut script_file = File::create(&binary_script_file)?;
        script_file.write_all(content.as_bytes())?;

        info!(
            "Building binary for function: {} ({})",
            binary_name,
            binary_file.display()
        );

        // Resolve the command line the .cmd launcher will use.
        let run = match language {
            Language::Bash => {
                let shell = runtime::bash_path().ok_or_else(|| anyhow!("Shell not found"))?;
                format!("{shell} --noprofile --norc")
            }
            Language::Python if Path::new(".venv").exists() => {
                // NOTE(review): Python is only supported here when a local
                // .venv exists; plain Python falls through to the
                // "Unsupported language" arm — confirm this is intentional.
                let executable_path = env::current_dir()?
                    .join(".venv")
                    .join("Scripts")
                    .join("activate.bat");
                let canonicalized_path = fs::canonicalize(&executable_path)?;
                format!(
                    "call \"{}\" && {}",
                    canonicalized_path.to_string_lossy(),
                    language.to_cmd()
                )
            }
            Language::Javascript => runtime::which(language.to_cmd())
                .ok_or_else(|| anyhow!("Unable to find {} in PATH", language.to_cmd()))?,
            _ => bail!("Unsupported language: {}", language.as_ref()),
        };
        let bin_dir = binary_file
            .parent()
            .expect("Failed to get parent directory of binary file")
            .canonicalize()?
            .to_string_lossy()
            .into_owned();
        let wrapper_binary = binary_script_file
            .canonicalize()?
            .to_string_lossy()
            .into_owned();
        let content = formatdoc!(
            r#"
            @echo off
            setlocal

            set "bin_dir={bin_dir}"

            {run} "{wrapper_binary}" %*"#,
        );

        let mut file = File::create(&binary_file)?;
        file.write_all(content.as_bytes())?;

        Ok(())
    }

    /// Unix variant: write the filled-in runner script directly as the
    /// executable and mark it 0o755.
    #[cfg(not(windows))]
    fn build_binaries(
        binary_name: &str,
        language: Language,
        binary_type: BinaryType,
    ) -> Result<()> {
        use std::os::unix::prelude::PermissionsExt;

        let binary_file = match binary_type {
            BinaryType::Tool => Config::functions_bin_dir().join(binary_name),
            BinaryType::Agent => Config::agent_bin_dir(binary_name).join(binary_name),
        };
        info!(
            "Building binary for function: {} ({})",
            binary_name,
            binary_file.display()
        );
        let embedded_file = FunctionAsset::get(&format!(
            "scripts/run-{}.{}",
            binary_type.as_ref().to_lowercase(),
            language.to_extension()
        ))
        .ok_or_else(|| {
            anyhow!(
                "Failed to load embedded script for run-{}.{}",
                binary_type.as_ref().to_lowercase(),
                language.to_extension()
            )
        })?;
        // Fix: validate UTF-8 instead of `from_utf8_unchecked` (UB on
        // malformed data).
        let content_template = std::str::from_utf8(&embedded_file.data)
            .context("Embedded script is not valid UTF-8")?;
        let content = match binary_type {
            BinaryType::Tool => content_template.replace("{function_name}", binary_name),
            BinaryType::Agent => content_template.replace("{agent_name}", binary_name),
        }
        .replace("{config_dir}", &Config::config_dir().to_string_lossy());
        if binary_file.exists() {
            fs::remove_file(&binary_file)?;
        }
        let mut file = File::create(&binary_file)?;
        file.write_all(content.as_bytes())?;

        fs::set_permissions(&binary_file, fs::Permissions::from_mode(0o755))?;

        Ok(())
    }
}

/// A single callable function exposed to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: JsonSchema,
    // True when the function belongs to an agent's own script; never serialized.
    #[serde(skip_serializing, default)]
    pub agent: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct JsonSchema {
    #[serde(rename = "type", skip_serializing_if =
"Option::is_none")]
    pub type_value: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<IndexMap<String, JsonSchema>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<JsonSchema>>,
    #[serde(rename = "anyOf", skip_serializing_if = "Option::is_none")]
    pub any_of: Option<Vec<JsonSchema>>,
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enum_value: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}

impl JsonSchema {
    /// True when there is no `properties` map or it contains no entries.
    pub fn is_empty_properties(&self) -> bool {
        self.properties.as_ref().map_or(true, |p| p.is_empty())
    }
}

/// A tool invocation requested by the model.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ToolCall {
    pub name: String,
    pub arguments: Value,
    pub id: Option<String>,
}

// (display name, command name, command args, extra environment)
type CallConfig = (String, String, Vec<String>, HashMap<String, String>);

impl ToolCall {
    /// Drop duplicate calls that share an id, keeping the *last* occurrence
    /// of each id while preserving the overall order.
    pub fn dedup(calls: Vec<Self>) -> Vec<Self> {
        let mut seen_ids = HashSet::new();
        // Walk back-to-front so the last occurrence of each id wins, then
        // restore the original order.
        let mut deduped: Vec<Self> = calls
            .into_iter()
            .rev()
            .filter(|call| match &call.id {
                Some(id) => seen_ids.insert(id.clone()),
                None => true,
            })
            .collect();
        deduped.reverse();
        deduped
    }

    pub fn new(name: String, arguments: Value, id: Option<String>) -> Self {
        Self {
            name,
            arguments,
            id,
        }
    }

    /// Execute this call: MCP meta functions are routed to the registry,
    /// everything else runs through its function binary.
    pub async fn eval(&self, config: &GlobalConfig) -> Result<Value> {
        let (call_name, cmd_name, mut cmd_args, envs) = match &config.read().agent {
            Some(agent) => self.extract_call_config_from_agent(config, agent)?,
            None => self.extract_call_config_from_config(config)?,
        };

        // Normalize the arguments to a JSON object; they may arrive either
        // as an object or as a JSON-encoded string.
        let json_data = if self.arguments.is_object() {
            self.arguments.clone()
        } else if let Some(arguments) = self.arguments.as_str() {
            let arguments: Value = serde_json::from_str(arguments).map_err(|_| {
                anyhow!("The call '{call_name}' has invalid arguments: {arguments}")
            })?;
            arguments
        } else {
            bail!(
                "The call '{call_name}' has invalid arguments: {}",
                self.arguments
            );
        };

        cmd_args.push(json_data.to_string());

        let prompt = format!("Call {cmd_name} {}", cmd_args.join(" "));

        if *IS_STDOUT_TERMINAL {
            println!("{}", dimmed_text(&prompt));
        }

        // Grab a clone of the registry without holding the lock across awaits.
        let mcp_registry = || {
            let cfg = config.read();
            cfg.mcp_registry
                .clone()
                .with_context(|| "MCP is not configured")
        };

        let output = match cmd_name.as_str() {
            _ if cmd_name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) => {
                mcp_registry()?.catalog().await?
            }
            _ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => {
                let server = json_data
                    .get("server")
                    .ok_or_else(|| anyhow!("Missing 'server' in arguments"))?
                    .as_str()
                    .ok_or_else(|| anyhow!("Invalid 'server' in arguments"))?;
                let tool = json_data
                    .get("tool")
                    .ok_or_else(|| anyhow!("Missing 'tool' in arguments"))?
                    .as_str()
                    .ok_or_else(|| anyhow!("Invalid 'tool' in arguments"))?;
                let arguments = json_data
                    .get("arguments")
                    .cloned()
                    .unwrap_or_else(|| json!({}));
                let result = mcp_registry()?.invoke(server, tool, arguments).await?;
                serde_json::to_value(result)?
            }
            _ => match run_llm_function(cmd_name, cmd_args, envs)? {
                // Prefer structured output; fall back to wrapping raw text.
                Some(contents) => serde_json::from_str(&contents)
                    .unwrap_or_else(|_| json!({"output": contents})),
                None => Value::Null,
            },
        };

        Ok(output)
    }

    /// Resolve the command configuration when an agent is active.
    ///
    /// Agent-owned functions run through the agent binary with the agent's
    /// variables exported; shared functions use the plain per-function binary.
    fn extract_call_config_from_agent(
        &self,
        config: &GlobalConfig,
        agent: &Agent,
    ) -> Result<CallConfig> {
        let function_name = self.name.clone();
        match agent.functions().find(&function_name) {
            Some(function) if function.agent => {
                let agent_name = agent.name().to_string();
                Ok((
                    format!("{agent_name}-{function_name}"),
                    agent_name,
                    vec![function_name],
                    agent.variable_envs(),
                ))
            }
            Some(_) => Ok((
                function_name.clone(),
                function_name,
                vec![],
                Default::default(),
            )),
            None => self.extract_call_config_from_config(config),
        }
    }

    /// Resolve the command configuration from the global function set.
    fn extract_call_config_from_config(&self, config: &GlobalConfig) -> Result<CallConfig> {
        let function_name = self.name.clone();
        if config.read().functions.contains(&function_name) {
            Ok((
                function_name.clone(),
                function_name,
                vec![],
                Default::default(),
            ))
        } else {
            bail!("Unexpected call: {function_name} {}", self.arguments)
        }
    }
}

/// Run a function binary with the bin dirs prepended to PATH, returning the
/// contents it wrote to `LLM_OUTPUT` (if any).
pub fn run_llm_function(
    cmd_name: String,
    cmd_args: Vec<String>,
    mut envs: HashMap<String, String>,
) -> Result<Option<String>> {
    let mut bin_dirs: Vec<PathBuf> = vec![];
    if cmd_args.len() > 1 {
        let dir = Config::agent_bin_dir(&cmd_name);
        if dir.exists() {
            bin_dirs.push(dir);
        }
    }
    bin_dirs.push(Config::functions_bin_dir());
    let current_path = env::var("PATH").context("No PATH environment variable")?;
    let prepend_path = bin_dirs
        .iter()
        .map(|v| format!("{}{PATH_SEP}", v.display()))
        .collect::<Vec<String>>()
        .join("");
    envs.insert("PATH".into(), format!("{prepend_path}{current_path}"));

    let temp_file = temp_file("-eval-", "");
    envs.insert("LLM_OUTPUT".into(), temp_file.display().to_string());

    #[cfg(windows)]
    let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);

    let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
        .map_err(|err| anyhow!("Unable to run {cmd_name}, {err}"))?;
    if
exit_code != 0 {
        bail!("Tool call exited with {exit_code}");
    }
    let mut output = None;
    if temp_file.exists() {
        let contents =
            fs::read_to_string(temp_file).context("Failed to retrieve tool call output")?;
        if !contents.is_empty() {
            debug!("Tool {cmd_name} output: {}", contents);
            output = Some(contents);
        }
    };
    Ok(output)
}

/// Map a bare command name to an existing `<name><ext>` candidate (per
/// PATHEXT) inside one of the given bin dirs; Windows `Command` does not
/// apply PATHEXT to bare names the way the shell does.
#[cfg(windows)]
fn polyfill_cmd_name<T: AsRef<Path>>(cmd_name: &str, bin_dir: &[T]) -> String {
    if let Ok(exts) = env::var("PATHEXT") {
        for ext in exts.split(';') {
            let candidate = format!("{cmd_name}{ext}");
            if bin_dir
                .iter()
                .any(|dir| dir.as_ref().join(&candidate).exists())
            {
                return candidate;
            }
        }
    }
    cmd_name.to_string()
}
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..7cfb156
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,496 @@
mod cli;
mod client;
mod config;
mod function;
mod rag;
mod render;
mod repl;
mod serve;
#[macro_use]
mod utils;
mod mcp;
mod parsers;

#[macro_use]
extern crate log;

use crate::cli::Cli;
use crate::client::{
    call_chat_completions, call_chat_completions_streaming, list_models, ModelType,
};
use crate::config::{
    ensure_parent_exists, list_agents, load_env_file, macro_execute, Agent, Config, GlobalConfig,
    Input, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, TEMP_SESSION_NAME,
};
use crate::render::render_error;
use crate::repl::Repl;
use crate::utils::*;

use anyhow::{bail, Result};
use clap::{CommandFactory, Parser};
use clap_complete::CompleteEnv;
use inquire::Text;
use log::LevelFilter;
use log4rs::append::console::ConsoleAppender;
use log4rs::append::file::FileAppender;
use log4rs::config::{Appender, Logger, Root};
use log4rs::encode::pattern::PatternEncoder;
use parking_lot::RwLock;
use std::path::PathBuf;
use std::{env, mem, process, sync::Arc};

#[tokio::main]
async fn main() -> Result<()> {
    load_env_file()?;
    CompleteEnv::with_factory(Cli::command).complete();
let cli = Cli::parse(); + + if cli.tail_logs { + tail_logs(cli.disable_log_colors).await; + return Ok(()); + } + + let text = cli.text()?; + let working_mode = if cli.serve.is_some() { + WorkingMode::Serve + } else if text.is_none() && cli.file.is_empty() { + WorkingMode::Repl + } else { + WorkingMode::Cmd + }; + + let info_flag = cli.info + || cli.sync_models + || cli.list_models + || cli.list_roles + || cli.list_agents + || cli.list_rags + || cli.list_macros + || cli.list_sessions; + let log_path = setup_logger(working_mode.is_serve())?; + let abort_signal = create_abort_signal(); + let start_mcp_servers = cli.agent.is_none() && cli.role.is_none(); + let config = Arc::new(RwLock::new( + Config::init( + working_mode, + info_flag, + start_mcp_servers, + log_path, + abort_signal.clone(), + ) + .await?, + )); + if let Err(err) = run(config, cli, text, abort_signal).await { + render_error(err); + process::exit(1); + } + Ok(()) +} + +async fn run( + config: GlobalConfig, + cli: Cli, + text: Option, + abort_signal: AbortSignal, +) -> Result<()> { + if cli.sync_models { + let url = config.read().sync_models_url(); + return Config::sync_models(&url, abort_signal.clone()).await; + } + + if cli.list_models { + for model in list_models(&config.read(), ModelType::Chat) { + println!("{}", model.id()); + } + return Ok(()); + } + if cli.list_roles { + let roles = Config::list_roles(true).join("\n"); + println!("{roles}"); + return Ok(()); + } + if cli.list_agents { + let agents = list_agents().join("\n"); + println!("{agents}"); + return Ok(()); + } + if cli.list_rags { + let rags = Config::list_rags().join("\n"); + println!("{rags}"); + return Ok(()); + } + if cli.list_macros { + let macros = Config::list_macros().join("\n"); + println!("{macros}"); + return Ok(()); + } + + if cli.dry_run { + config.write().dry_run = true; + } + + if let Some(agent) = &cli.agent { + if cli.build_tools { + info!("Building tools for agent '{agent}'..."); + Agent::init(&config, agent, 
abort_signal.clone()).await?; + return Ok(()); + } + + let session = cli.session.as_ref().map(|v| match v { + Some(v) => v.as_str(), + None => TEMP_SESSION_NAME, + }); + if !cli.agent_variable.is_empty() { + config.write().agent_variables = Some( + cli.agent_variable + .chunks(2) + .map(|v| (v[0].to_string(), v[1].to_string())) + .collect(), + ); + } + + let ret = Config::use_agent(&config, agent, session, abort_signal.clone()).await; + config.write().agent_variables = None; + ret?; + } else { + if let Some(prompt) = &cli.prompt { + config.write().use_prompt(prompt)?; + } else if let Some(name) = &cli.role { + Config::use_role_safely(&config, name, abort_signal.clone()).await?; + } else if cli.execute { + Config::use_role_safely(&config, SHELL_ROLE, abort_signal.clone()).await?; + } else if cli.code { + Config::use_role_safely(&config, CODE_ROLE, abort_signal.clone()).await?; + } + if let Some(session) = &cli.session { + config + .write() + .use_session(session.as_ref().map(|v| v.as_str()))?; + } + if let Some(rag) = &cli.rag { + Config::use_rag(&config, Some(rag), abort_signal.clone()).await?; + } + } + + if cli.build_tools { + return Ok(()); + } + + if cli.list_sessions { + let sessions = config.read().list_sessions().join("\n"); + println!("{sessions}"); + return Ok(()); + } + if let Some(model_id) = &cli.model { + config.write().set_model(model_id)?; + } + if cli.no_stream { + config.write().stream = false; + } + if cli.empty_session { + config.write().empty_session()?; + } + if cli.save_session { + config.write().set_save_session_this_time()?; + } + if cli.info { + let info = config.read().info()?; + println!("{info}"); + return Ok(()); + } + if let Some(addr) = cli.serve { + return serve::run(config, addr).await; + } + let is_repl = config.read().working_mode.is_repl(); + if cli.rebuild_rag { + Config::rebuild_rag(&config, abort_signal.clone()).await?; + if is_repl { + return Ok(()); + } + } + if let Some(name) = &cli.macro_name { + macro_execute(&config, 
name, text.as_deref(), abort_signal.clone()).await?; + return Ok(()); + } + if cli.execute && !is_repl { + let input = create_input(&config, text, &cli.file, abort_signal.clone()).await?; + shell_execute(&config, &SHELL, input, abort_signal.clone()).await?; + return Ok(()); + } + + apply_prelude_safely(&config, abort_signal.clone()).await?; + + match is_repl { + false => { + let mut input = create_input(&config, text, &cli.file, abort_signal.clone()).await?; + input.use_embeddings(abort_signal.clone()).await?; + start_directive(&config, input, cli.code, abort_signal).await + } + true => { + if !*IS_STDOUT_TERMINAL { + bail!("No TTY for REPL") + } + start_interactive(&config).await + } + } +} + +async fn apply_prelude_safely(config: &RwLock, abort_signal: AbortSignal) -> Result<()> { + let mut cfg = { + let mut guard = config.write(); + mem::take(&mut *guard) + }; + + cfg.apply_prelude(abort_signal.clone()).await?; + + { + let mut guard = config.write(); + *guard = cfg; + } + + Ok(()) +} + +#[async_recursion::async_recursion] +async fn start_directive( + config: &GlobalConfig, + input: Input, + code_mode: bool, + abort_signal: AbortSignal, +) -> Result<()> { + let client = input.create_client()?; + let extract_code = !*IS_STDOUT_TERMINAL && code_mode; + config.write().before_chat_completion(&input)?; + let (output, tool_results) = if !input.stream() || extract_code { + call_chat_completions( + &input, + true, + extract_code, + client.as_ref(), + abort_signal.clone(), + ) + .await? + } else { + call_chat_completions_streaming(&input, client.as_ref(), abort_signal.clone()).await? 
+ }; + config + .write() + .after_chat_completion(&input, &output, &tool_results)?; + + if !tool_results.is_empty() { + start_directive( + config, + input.merge_tool_results(output, tool_results), + code_mode, + abort_signal, + ) + .await?; + } + + config.write().exit_session()?; + Ok(()) +} + +async fn start_interactive(config: &GlobalConfig) -> Result<()> { + let mut repl: Repl = Repl::init(config)?; + repl.run().await +} + +#[async_recursion::async_recursion] +async fn shell_execute( + config: &GlobalConfig, + shell: &Shell, + mut input: Input, + abort_signal: AbortSignal, +) -> Result<()> { + let client = input.create_client()?; + config.write().before_chat_completion(&input)?; + let (eval_str, _) = + call_chat_completions(&input, false, true, client.as_ref(), abort_signal.clone()).await?; + + config + .write() + .after_chat_completion(&input, &eval_str, &[])?; + if eval_str.is_empty() { + bail!("No command generated"); + } + if config.read().dry_run { + config.read().print_markdown(&eval_str)?; + return Ok(()); + } + if *IS_STDOUT_TERMINAL { + let options = ["execute", "revise", "describe", "copy", "quit"]; + let command = color_text(eval_str.trim(), nu_ansi_term::Color::Rgb(255, 165, 0)); + let first_letter_color = nu_ansi_term::Color::Cyan; + let prompt_text = options + .iter() + .map(|v| format!("{}{}", color_text(&v[0..1], first_letter_color), &v[1..])) + .collect::>() + .join(&dimmed_text(" | ")); + loop { + println!("{command}"); + let answer_char = + read_single_key(&['e', 'r', 'd', 'c', 'q'], 'e', &format!("{prompt_text}: "))?; + + match answer_char { + 'e' => { + debug!("{} {:?}", shell.cmd, &[&shell.arg, &eval_str]); + let code = run_command(&shell.cmd, &[&shell.arg, &eval_str], None)?; + if code == 0 && config.read().save_shell_history { + let _ = append_to_shell_history(&shell.name, &eval_str, code); + } + process::exit(code); + } + 'r' => { + let revision = Text::new("Enter your revision:").prompt()?; + let text = format!("{}\n{revision}", 
input.text()); + input.set_text(text); + return shell_execute(config, shell, input, abort_signal.clone()).await; + } + 'd' => { + let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?; + let input = Input::from_str(config, &eval_str, Some(role)); + if input.stream() { + call_chat_completions_streaming( + &input, + client.as_ref(), + abort_signal.clone(), + ) + .await?; + } else { + call_chat_completions( + &input, + true, + false, + client.as_ref(), + abort_signal.clone(), + ) + .await?; + } + println!(); + continue; + } + 'c' => { + set_text(&eval_str)?; + println!("{}", dimmed_text("✓ Copied the command.")); + } + _ => {} + } + break; + } + } else { + println!("{eval_str}"); + } + Ok(()) +} + +async fn create_input( + config: &GlobalConfig, + text: Option, + file: &[String], + abort_signal: AbortSignal, +) -> Result { + let input = if file.is_empty() { + Input::from_str(config, &text.unwrap_or_default(), None) + } else { + Input::from_files_with_spinner( + config, + &text.unwrap_or_default(), + file.to_vec(), + None, + abort_signal, + ) + .await? 
+ }; + if input.is_empty() { + bail!("No input"); + } + Ok(input) +} + +fn setup_logger(is_serve: bool) -> Result> { + let (log_level, log_path) = Config::log_config(is_serve)?; + if log_level == LevelFilter::Off { + return Ok(None); + } + let encoder = Box::new(PatternEncoder::new( + "{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}", + )); + let log_filter = match env::var(get_env_name("log_filter")) { + Ok(v) => Some(v), + Err(_) => match is_serve { + true => Some(format!("{}::serve", env!("CARGO_CRATE_NAME"))), + false => None, + }, + }; + + match log_path.clone() { + None => { + let console_appender = ConsoleAppender::builder().encoder(encoder).build(); + log4rs::init_config(init_console_logger(log_level, log_filter, console_appender))?; + } + Some(path) => { + ensure_parent_exists(&path)?; + let file_appender = FileAppender::builder().encoder(encoder.clone()).build(path); + + match file_appender { + Ok(appender) => { + log4rs::init_config(init_file_logger(log_level, log_filter, appender))? + } + Err(_) => { + let console_appender = ConsoleAppender::builder().encoder(encoder).build(); + log4rs::init_config(init_console_logger( + log_level, + log_filter, + console_appender, + ))? 
+ } + }; + } + } + Ok(log_path) +} + +fn init_file_logger( + log_level: LevelFilter, + log_filter: Option, + file_appender: FileAppender, +) -> log4rs::Config { + let root_log_level = if log_filter.is_some() { + LevelFilter::Off + } else { + log_level + }; + let mut config_builder = log4rs::Config::builder() + .appender(Appender::builder().build("logfile", Box::new(file_appender))); + + if let Some(filter) = log_filter { + config_builder = config_builder.logger(Logger::builder().build(filter, log_level)); + } + + config_builder + .build(Root::builder().appender("logfile").build(root_log_level)) + .unwrap() +} + +fn init_console_logger( + log_level: LevelFilter, + log_filter: Option, + console_appender: ConsoleAppender, +) -> log4rs::Config { + let root_log_level = if log_filter.is_some() { + LevelFilter::Off + } else { + log_level + }; + let mut config_builder = log4rs::Config::builder() + .appender(Appender::builder().build("console", Box::new(console_appender))); + + if let Some(filter) = log_filter { + config_builder = config_builder.logger(Logger::builder().build(filter, log_level)); + } + + config_builder + .build(Root::builder().appender("console").build(root_log_level)) + .unwrap() +} diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs new file mode 100644 index 0000000..e708679 --- /dev/null +++ b/src/mcp/mod.rs @@ -0,0 +1,290 @@ +use crate::config::Config; +use crate::utils::{abortable_run_with_spinner, AbortSignal}; +use anyhow::{anyhow, Context, Result}; +use futures_util::future::BoxFuture; +use futures_util::{stream, StreamExt, TryStreamExt}; +use rmcp::model::{CallToolRequestParam, CallToolResult}; +use rmcp::service::RunningService; +use rmcp::transport::TokioChildProcess; +use rmcp::{RoleClient, ServiceExt}; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; +use std::fs::OpenOptions; +use std::path::PathBuf; +use std::process::Stdio; +use std::sync::Arc; +use tokio::process::Command; 
+ +pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke"; +pub const MCP_LIST_META_FUNCTION_NAME_PREFIX: &str = "mcp_list"; + +type ConnectedServer = RunningService; + +#[derive(Debug, Clone, Deserialize)] +struct McpServersConfig { + #[serde(rename = "mcpServers")] + mcp_servers: HashMap, +} + +#[derive(Debug, Clone, Deserialize)] +struct McpServer { + command: String, + args: Option>, + env: Option>, + cwd: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +enum JsonField { + Str(String), + Bool(bool), + Int(i64), +} + +#[derive(Debug, Clone, Default)] +pub struct McpRegistry { + log_path: Option, + config: Option, + servers: HashMap>>, +} + +impl McpRegistry { + pub async fn init( + log_path: Option, + start_mcp_servers: bool, + use_mcp_servers: Option, + abort_signal: AbortSignal, + ) -> Result { + let mut registry = Self { + log_path, + ..Default::default() + }; + if !Config::mcp_config_file().try_exists().with_context(|| { + format!( + "Failed to check MCP config file at {}", + Config::mcp_config_file().display() + ) + })? 
{
+            debug!(
+                "MCP config file does not exist at {}, skipping MCP initialization",
+                Config::mcp_config_file().display()
+            );
+            return Ok(registry);
+        }
+        // Shared error-context closure for both the file read and the JSON parse.
+        let err = || {
+            format!(
+                "Failed to load MCP config file at {}",
+                Config::mcp_config_file().display()
+            )
+        };
+        let content = tokio::fs::read_to_string(Config::mcp_config_file())
+            .await
+            .with_context(err)?;
+        let config: McpServersConfig = serde_json::from_str(&content).with_context(err)?;
+        registry.config = Some(config);
+
+        if start_mcp_servers {
+            abortable_run_with_spinner(
+                registry.start_select_mcp_servers(use_mcp_servers),
+                "Loading MCP servers",
+                abort_signal,
+            )
+            .await?;
+        }
+
+        Ok(registry)
+    }
+
+    // Stops every running MCP server, then starts the selected set again.
+    // Consumes the old registry and returns the refreshed one.
+    pub async fn reinit(
+        registry: McpRegistry,
+        use_mcp_servers: Option,
+        abort_signal: AbortSignal,
+    ) -> Result {
+        debug!("Reinitializing MCP registry");
+        debug!("Stopping all MCP servers");
+        let mut new_registry = abortable_run_with_spinner(
+            registry.stop_all_servers(),
+            "Stopping MCP servers",
+            abort_signal.clone(),
+        )
+        .await?;
+
+        abortable_run_with_spinner(
+            new_registry.start_select_mcp_servers(use_mcp_servers),
+            "Loading MCP servers",
+            abort_signal,
+        )
+        .await?;
+
+        Ok(new_registry)
+    }
+
+    // Starts the servers named in `use_mcp_servers` (comma-separated list, or
+    // the literal "all"). A missing config means MCP is disabled globally,
+    // which is treated as success, not an error.
+    async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option) -> Result<()> {
+        if self.config.is_none() {
+            debug!("MCP config is not present; assuming MCP servers are disabled globally. skipping MCP initialization");
+            return Ok(());
+        }
+
+        debug!("Starting selected MCP servers: {:?}", use_mcp_servers);
+
+        if let Some(servers) = use_mcp_servers {
+            let config = self
+                .config
+                .as_ref()
+                .with_context(|| "MCP Config not defined. Cannot start servers")?;
+            let mcp_servers = config.mcp_servers.clone();
+
+            let enabled_servers: HashSet =
+                servers.split(',').map(|s| s.trim().to_string()).collect();
+            let server_ids: Vec = if servers == "all" {
+                mcp_servers.into_keys().collect()
+            } else {
+                mcp_servers
+                    .into_keys()
+                    .filter(|id| enabled_servers.contains(id))
+                    .collect()
+            };
+
+            // Spawn the selected servers concurrently, bounded by CPU count;
+            // try_collect short-circuits on the first startup failure.
+            let results: Vec<(String, Arc<_>)> = stream::iter(
+                server_ids
+                    .into_iter()
+                    .map(|id| async { self.start_server(id).await }),
+            )
+            .buffer_unordered(num_cpus::get())
+            .try_collect()
+            .await?;
+
+            self.servers = results.into_iter().collect();
+        }
+
+        Ok(())
+    }
+
+    // Spawns the child process for one configured MCP server and connects to
+    // it over a stdio transport, returning (server id, running service handle).
+    async fn start_server(&self, id: String) -> Result<(String, Arc)> {
+        let server = self
+            .config
+            .as_ref()
+            .and_then(|c| c.mcp_servers.get(&id))
+            .with_context(|| format!("MCP server not found in config: {id}"))?;
+        let mut cmd = Command::new(&server.command);
+        if let Some(args) = &server.args {
+            cmd.args(args);
+        }
+        if let Some(env) = &server.env {
+            // Normalize JSON env values (string/bool/int) to plain strings.
+            let env: HashMap = env
+                .iter()
+                .map(|(k, v)| match v {
+                    JsonField::Str(s) => (k.clone(), s.clone()),
+                    JsonField::Bool(b) => (k.clone(), b.to_string()),
+                    JsonField::Int(i) => (k.clone(), i.to_string()),
+                })
+                .collect();
+            cmd.envs(env);
+        }
+        if let Some(cwd) = &server.cwd {
+            cmd.current_dir(cwd);
+        }
+
+        // When a log path is configured, the child's stderr is appended to the
+        // log file; otherwise the default child-process transport is used.
+        let transport = if let Some(log_path) = self.log_path.as_ref() {
+            cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
+
+            let log_file = OpenOptions::new()
+                .create(true)
+                .append(true)
+                .open(log_path)?;
+            let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
+            transport
+        } else {
+            TokioChildProcess::new(cmd)?
+        };
+
+        let service = Arc::new(
+            ().serve(transport)
+                .await
+                .with_context(|| format!("Failed to start MCP server: {}", &server.command))?,
+        );
+        debug!(
+            "Available tools for MCP server {id}: {:?}",
+            service.list_tools(None).await?
+        );
+
+        info!("Started MCP server: {id}");
+
+        Ok((id.to_string(), service))
+    }
+
+    // Cancels every running server, consuming self and returning it with the
+    // server map cleared. NOTE(review): errors if any service Arc is still
+    // shared elsewhere (try_unwrap requires sole ownership).
+    pub async fn stop_all_servers(mut self) -> Result {
+        for (id, server) in self.servers {
+            Arc::try_unwrap(server)
+                .map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
+                .cancel()
+                .await
+                .with_context(|| format!("Failed to stop MCP server: {id}"))?;
+            info!("Stopped MCP server: {id}");
+        }
+
+        self.servers = HashMap::new();
+
+        Ok(self)
+    }
+
+    // Ids of all currently running servers.
+    pub fn list_servers(&self) -> Vec {
+        self.servers.keys().cloned().collect()
+    }
+
+    // Builds a JSON catalog of tools and resources per running server. The
+    // handles are cloned up front so the returned future is 'static.
+    pub fn catalog(&self) -> BoxFuture<'static, Result> {
+        let servers: Vec<(String, Arc)> = self
+            .servers
+            .iter()
+            .map(|(id, s)| (id.clone(), s.clone()))
+            .collect();
+
+        Box::pin(async move {
+            let mut out = Vec::with_capacity(servers.len());
+            for (id, server) in servers {
+                let tools = server.list_tools(None).await?;
+                // Resource listing is best-effort: servers without resources
+                // simply contribute an empty list.
+                let resources = server.list_resources(None).await.unwrap_or_default();
+                // TODO implement prompt sampling for MCP servers
+                // let prompts = server.service.list_prompts(None).await.unwrap_or_default();
+                out.push(json!({
+                    "server": id,
+                    "tools": tools,
+                    "resources": resources,
+                }));
+            }
+            Ok(Value::Array(out))
+        })
+    }
+
+    // Invokes `tool` on `server` with the given JSON arguments. The server
+    // lookup happens eagerly; its Result is deferred into the returned future.
+    pub fn invoke(
+        &self,
+        server: &str,
+        tool: &str,
+        arguments: Value,
+    ) -> BoxFuture<'static, Result> {
+        let server = self
+            .servers
+            .get(server)
+            .cloned()
+            .with_context(|| format!("Invoked MCP server does not exist: {server}"));
+
+        let tool = tool.to_owned();
+        Box::pin(async move {
+            let server = server?;
+            let call_tool_request = CallToolRequestParam {
+                name: Cow::Owned(tool.to_owned()),
+                arguments: arguments.as_object().cloned(),
+            };
+
+            let result = server.call_tool(call_tool_request).await?;
+            Ok(result)
+        })
+    }
+
+    // True when no MCP servers are running.
+    pub fn is_empty(&self) -> bool {
+        self.servers.is_empty()
+    }
+}
diff --git a/src/parsers/bash.rs b/src/parsers/bash.rs
new file mode 100644
index 0000000..c237c81
--- /dev/null
+++ b/src/parsers/bash.rs
@@ -0,0 +1,149 @@
+use
crate::function::{FunctionDeclaration, JsonSchema};
+use anyhow::{bail, Context, Result};
+use argc::{ChoiceValue, CommandValue, FlagOptionValue};
+use indexmap::IndexMap;
+use std::fs::File;
+use std::io::Read;
+use std::path::Path;
+use std::{env, fs};
+
+// Parses an argc-annotated bash script into function declarations.
+// A script with no subcommands yields one declaration; otherwise each public
+// subcommand (plus the special "_instructions") becomes an agent declaration.
+pub fn generate_bash_declarations(
+    mut tool_file: File,
+    tools_file_path: &Path,
+    file_name: &str,
+) -> Result> {
+    let mut src = String::new();
+    tool_file
+        .read_to_string(&mut src)
+        .with_context(|| format!("Failed to load script at '{tool_file:?}'"))?;
+
+    debug!("Building script at '{tool_file:?}'");
+    let build_script = argc::build(
+        &src,
+        "",
+        env::var("TERM_WIDTH").ok().and_then(|v| v.parse().ok()),
+    )?;
+    // The built (argc-expanded) script is written back so it can be executed later.
+    fs::write(tools_file_path, &build_script)
+        .with_context(|| format!("Failed to write built script to '{tools_file_path:?}'"))?;
+
+    let command_value = argc::export(&build_script, file_name)
+        .with_context(|| format!("Failed to parse script at '{tool_file:?}'"))?;
+    if command_value.subcommands.is_empty() {
+        let function_declaration =
+            command_to_function_declaration(&command_value).ok_or_else(|| {
+                anyhow::format_err!("Tool definition missing or empty description: {file_name}")
+            })?;
+        Ok(vec![function_declaration])
+    } else {
+        let mut declarations = vec![];
+        for subcommand in &command_value.subcommands {
+            // Underscore-prefixed subcommands are private, except "_instructions".
+            if subcommand.name.starts_with('_') && subcommand.name != "_instructions" {
+                continue;
+            }
+
+            if let Some(mut function_declaration) = command_to_function_declaration(subcommand) {
+                function_declaration.agent = true;
+                declarations.push(function_declaration);
+            } else {
+                bail!(
+                    "Tool definition missing or empty description: {} {}",
+                    file_name,
+                    subcommand.name
+                );
+            }
+        }
+
+        Ok(declarations)
+    }
+}
+
+// Converts one argc command into a declaration; None when it has no description.
+fn command_to_function_declaration(cmd: &CommandValue) -> Option {
+    if cmd.describe.is_empty() {
+        return None;
+    }
+
+    Some(FunctionDeclaration {
+        name: underscore(&cmd.name),
+        description: cmd.describe.clone(),
+        parameters: parse_parameters_schema(&cmd.flag_options),
+        agent: false,
+    })
+}
+
+// Dashes are not valid in function-call names, so normalize them to underscores.
+fn underscore(s: &str) -> String {
+    s.replace('-', "_")
+}
+
+// A bare JSON schema whose only populated field is `type`.
+fn schema_ty(t: &str) -> JsonSchema {
+    JsonSchema {
+        type_value: Some(t.to_string()),
+        description: None,
+        properties: None,
+        items: None,
+        any_of: None,
+        enum_value: None,
+        default: None,
+        required: None,
+    }
+}
+
+// Attaches a description when the argc describe text is non-empty.
+fn with_description(mut schema: JsonSchema, describe: &str) -> JsonSchema {
+    if !describe.is_empty() {
+        schema.description = Some(describe.to_string());
+    }
+    schema
+}
+
+// Attaches an enum list when the flag declares fixed choice values.
+fn with_enum(mut schema: JsonSchema, choice: &Option) -> JsonSchema {
+    if let Some(ChoiceValue::Values(values)) = choice {
+        if !values.is_empty() {
+            schema.enum_value = Some(values.clone());
+        }
+    }
+    schema
+}
+
+// Maps one argc flag/option to a JSON-schema property. The notation hint
+// ("INT"/"NUM") decides numeric types; repeatable options become string arrays.
+fn parse_property(flag: &FlagOptionValue) -> JsonSchema {
+    let mut schema = if flag.flag {
+        schema_ty("boolean")
+    } else if flag.multiple_occurs {
+        let mut arr = schema_ty("array");
+        arr.items = Some(Box::new(schema_ty("string")));
+        arr
+    } else if flag.notations.first().map(|s| s.as_str()) == Some("INT") {
+        schema_ty("integer")
+    } else if flag.notations.first().map(|s| s.as_str()) == Some("NUM") {
+        schema_ty("number")
+    } else {
+        schema_ty("string")
+    };
+
+    schema = with_description(schema, &flag.describe);
+    schema = with_enum(schema, &flag.choice);
+    schema
+}
+
+// Builds the top-level "object" parameters schema from the script's flags,
+// skipping the auto-generated help/version options.
+fn parse_parameters_schema(flags: &[FlagOptionValue]) -> JsonSchema {
+    let filtered = flags.iter().filter(|f| f.id != "help" && f.id != "version");
+    let mut props: IndexMap = IndexMap::new();
+    let mut required: Vec = Vec::new();
+
+    for f in filtered {
+        let key = underscore(&f.id);
+        if f.required {
+            required.push(key.clone());
+        }
+        props.insert(key, parse_property(f));
+    }
+
+    JsonSchema {
+        type_value: Some("object".to_string()),
+        description: None,
+        properties: Some(props),
+        items: None,
+        any_of: None,
+        enum_value: None,
+        default: None,
+        required: Some(required),
+    }
+}
diff --git a/src/parsers/mod.rs b/src/parsers/mod.rs
new file mode 100644
index 0000000..57ae15a
--- /dev/null
+++
b/src/parsers/mod.rs
@@ -0,0 +1,2 @@
+pub(crate) mod bash;
+pub(crate) mod python;
diff --git a/src/parsers/python.rs b/src/parsers/python.rs
new file mode 100644
index 0000000..d963e5e
--- /dev/null
+++ b/src/parsers/python.rs
@@ -0,0 +1,420 @@
+use crate::function::{FunctionDeclaration, JsonSchema};
+use anyhow::{bail, Context, Result};
+use ast::{Stmt, StmtFunctionDef};
+use indexmap::IndexMap;
+use rustpython_ast::{Constant, Expr, UnaryOp};
+use rustpython_parser::{ast, Mode};
+use serde_json::Value;
+use std::fs::File;
+use std::io::Read;
+use std::path::Path;
+
+// Intermediate record for one Python parameter: the annotation-derived type
+// hint plus any type/description recovered from the docstring's Args section.
+#[derive(Debug)]
+struct Param {
+    name: String,
+    ty_hint: String,
+    required: bool,
+    default: Option,
+    doc_type: Option,
+    doc_desc: Option,
+}
+
+// Parses a Python script into function declarations. Scripts that live under
+// a "tools" directory are treated as single-tool scripts: only their `run`
+// function is exported, flagged as an agent declaration.
+pub fn generate_python_declarations(
+    mut tool_file: File,
+    file_name: &str,
+    parent: Option<&Path>,
+) -> Result> {
+    let mut src = String::new();
+    tool_file
+        .read_to_string(&mut src)
+        .with_context(|| format!("Failed to load script at '{tool_file:?}'"))?;
+    let suite = parse_suite(&src, file_name)?;
+
+    let is_tool = parent
+        .and_then(|p| p.file_name())
+        .is_some_and(|n| n == "tools");
+    let mut declarations = python_to_function_declarations(file_name, &suite, is_tool)?;
+
+    if is_tool {
+        for d in &mut declarations {
+            d.agent = true;
+        }
+    }
+
+    Ok(declarations)
+}
+
+// Parses source text into a statement suite, accepting module or interactive input.
+fn parse_suite(src: &str, filename: &str) -> Result {
+    let mod_ast =
+        rustpython_parser::parse(src, Mode::Module, filename).context("failed to parse python")?;
+
+    let suite = match mod_ast {
+        ast::Mod::Module(m) => m.body,
+        ast::Mod::Interactive(m) => m.body,
+        ast::Mod::Expression(_) => bail!("expected a module; got a single expression"),
+        _ => bail!("unexpected parse mode/AST variant"),
+    };
+
+    Ok(suite)
+}
+
+// Walks top-level function defs and converts each eligible one into a
+// declaration. A missing/empty docstring description is a hard error.
+fn python_to_function_declarations(
+    file_name: &str,
+    module: &ast::Suite,
+    is_tool: bool,
+) -> Result> {
+    let mut out = Vec::new();
+
+    for stmt in module {
+        if let Stmt::FunctionDef(fd) = stmt {
+            let func_name = fd.name.to_string();
+
+            // Underscore-prefixed functions are private, except "_instructions".
+            if func_name.starts_with('_') && func_name != "_instructions" {
+                continue;
+            }
+
+            if is_tool && func_name != "run" {
+                continue;
+            }
+
+            let description = get_docstring_from_body(&fd.body).unwrap_or_default();
+            let params = collect_params(fd);
+            let schema = build_parameters_schema(&params, &description);
+            // A tool's `run` function is published under the file's name.
+            let name = if is_tool && func_name == "run" {
+                underscore(file_name)
+            } else {
+                underscore(&func_name)
+            };
+            let desc_trim = description.trim().to_string();
+            if desc_trim.is_empty() {
+                bail!("Missing or empty description on function: {func_name}");
+            }
+
+            out.push(FunctionDeclaration {
+                name,
+                description: desc_trim,
+                parameters: schema,
+                agent: !is_tool,
+            });
+        }
+    }
+
+    Ok(out)
+}
+
+// Returns the docstring when the body's first statement is a string constant.
+fn get_docstring_from_body(body: &[Stmt]) -> Option {
+    let first = body.first()?;
+    if let Stmt::Expr(expr_stmt) = first {
+        if let Expr::Constant(constant) = &*expr_stmt.value {
+            if let Constant::Str(s) = &constant.value {
+                return Some(s.clone());
+            }
+        }
+    }
+    None
+}
+
+// Gathers positional, keyword-only, *args and **kwargs parameters, then
+// overlays docstring Args metadata. A trailing '?' on a type marks optional.
+fn collect_params(fd: &StmtFunctionDef) -> Vec {
+    let mut out = Vec::new();
+
+    for a in fd.args.posonlyargs.iter().chain(fd.args.args.iter()) {
+        let name = a.def.arg.to_string();
+        let mut ty = get_arg_type(a.def.annotation.as_deref());
+        let mut required = a.default.is_none();
+
+        if ty.ends_with('?') {
+            ty.pop();
+            required = false;
+        }
+
+        // Only presence of a default is tracked; its value is not evaluated.
+        let default = if a.default.is_some() {
+            Some(Value::Null)
+        } else {
+            None
+        };
+
+        out.push(Param {
+            name,
+            ty_hint: ty,
+            required,
+            default,
+            doc_type: None,
+            doc_desc: None,
+        });
+    }
+
+    for a in &fd.args.kwonlyargs {
+        let name = a.def.arg.to_string();
+        let mut ty = get_arg_type(a.def.annotation.as_deref());
+        let mut required = a.default.is_none();
+
+        if ty.ends_with('?') {
+            ty.pop();
+            required = false;
+        }
+
+        let default = if a.default.is_some() {
+            Some(Value::Null)
+        } else {
+            None
+        };
+
+        out.push(Param {
+            name,
+            ty_hint: ty,
+            required,
+            default,
+            doc_type: None,
+            doc_desc: None,
+        });
+    }
+
+    if let Some(vararg) =
&fd.args.vararg {
+        let name = vararg.arg.to_string();
+        let inner = get_arg_type(vararg.annotation.as_deref());
+        // *args is surfaced as a list; element type falls back to str.
+        let ty = if inner.is_empty() {
+            "list[str]".into()
+        } else {
+            format!("list[{inner}]")
+        };
+
+        out.push(Param {
+            name,
+            ty_hint: ty,
+            required: false,
+            default: None,
+            doc_type: None,
+            doc_desc: None,
+        });
+    }
+
+    if let Some(kwarg) = &fd.args.kwarg {
+        let name = kwarg.arg.to_string();
+        // **kwargs is surfaced as a free-form object.
+        out.push(Param {
+            name,
+            ty_hint: "object".into(),
+            required: false,
+            default: None,
+            doc_type: None,
+            doc_desc: None,
+        });
+    }
+
+    // Overlay docstring Args metadata onto the collected parameters.
+    if let Some(doc) = get_docstring_from_body(&fd.body) {
+        let meta = parse_docstring_args(&doc);
+        for p in &mut out {
+            if let Some((t, d)) = meta.get(&p.name) {
+                if !t.is_empty() {
+                    p.doc_type = Some(t.clone());
+                }
+
+                if !d.is_empty() {
+                    p.doc_desc = Some(d.clone());
+                }
+
+                if t.ends_with('?') {
+                    p.required = false;
+                }
+            }
+        }
+    }
+
+    out
+}
+
+// Renders an annotation expression to the internal type-hint string syntax:
+// Optional[T] -> "T?", List[T] -> "list[T]", Literal[..] -> "literal:a|b".
+fn get_arg_type(annotation: Option<&Expr>) -> String {
+    match annotation {
+        None => "".to_string(),
+        Some(Expr::Name(n)) => n.id.to_string(),
+        Some(Expr::Subscript(sub)) => match &*sub.value {
+            Expr::Name(name) if &name.id == "Optional" => {
+                let inner = get_arg_type(Some(&sub.slice));
+                format!("{inner}?")
+            }
+            Expr::Name(name) if &name.id == "List" => {
+                let inner = get_arg_type(Some(&sub.slice));
+                format!("list[{inner}]")
+            }
+            Expr::Name(name) if &name.id == "Literal" => {
+                let vals = literal_members(&sub.slice);
+                format!("literal:{}", vals.join("|"))
+            }
+            _ => "any".to_string(),
+        },
+        _ => "any".to_string(),
+    }
+}
+
+// Best-effort string rendering of a constant-ish expression; unsupported
+// shapes collapse to "any".
+fn expr_to_str(e: &Expr) -> String {
+    match e {
+        Expr::Constant(c) => match &c.value {
+            Constant::Str(s) => s.clone(),
+            Constant::Int(i) => i.to_string(),
+            Constant::Float(f) => f.to_string(),
+            Constant::Bool(b) => b.to_string(),
+            Constant::None => "None".to_string(),
+            Constant::Ellipsis => "...".to_string(),
+            Constant::Bytes(b) => String::from_utf8_lossy(b).into_owned(),
+            Constant::Complex { real, imag } => format!("{real}+{imag}j"),
+            _ => "any".to_string(),
+        },
+
+        Expr::Name(n) => n.id.to_string(),
+
+        // Only negative numeric literals are folded; other unary ops -> "any".
+        Expr::UnaryOp(u) => {
+            if matches!(u.op, UnaryOp::USub) {
+                let inner = expr_to_str(&u.operand);
+                if inner.parse::().is_ok() || inner.chars().all(|c| c.is_ascii_digit()) {
+                    return format!("-{inner}");
+                }
+            }
+            "any".to_string()
+        }
+
+        Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect::>().join(","),
+
+        _ => "any".to_string(),
+    }
+}
+
+// Literal[...] members: a tuple yields each element, anything else one member.
+fn literal_members(e: &Expr) -> Vec {
+    match e {
+        Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect(),
+        _ => vec![expr_to_str(e)],
+    }
+}
+
+// Parses the Google-style "Args:" docstring section into
+// name -> (type, description). "(optional ...)" annotations map to '?'.
+fn parse_docstring_args(doc: &str) -> IndexMap {
+    let mut out = IndexMap::new();
+    let mut in_args = false;
+    for line in doc.lines() {
+        if !in_args {
+            if line.trim_start().starts_with("Args:") {
+                in_args = true;
+            }
+            continue;
+        }
+        // A dedented line terminates the Args section.
+        if !(line.starts_with(' ') || line.starts_with('\t')) {
+            break;
+        }
+        let s = line.trim();
+        if let Some((left, desc)) = s.split_once(':') {
+            let left = left.trim();
+            let mut name = left.to_string();
+            let mut ty = String::new();
+            if let Some((n, t)) = left.split_once(' ') {
+                name = n.trim().to_string();
+                ty = t.trim().to_string();
+                if ty.starts_with('(') && ty.ends_with(')') {
+                    let mut inner = ty[1..ty.len() - 1].to_string();
+                    if inner.to_lowercase().contains("optional") && !inner.ends_with('?') {
+                        inner.push('?');
+                    }
+                    ty = inner;
+                }
+            }
+            out.insert(name, (ty, desc.trim().to_string()));
+        }
+    }
+    out
+}
+
+// Lowercases and squashes non-alphanumerics to single underscores.
+fn underscore(s: &str) -> String {
+    s.chars()
+        .map(|c| {
+            if c.is_ascii_alphanumeric() {
+                c.to_ascii_lowercase()
+            } else {
+                '_'
+            }
+        })
+        .collect::()
+        .split('_')
+        .filter(|t| !t.is_empty())
+        .collect::>()
+        .join("_")
+}
+
+// Builds the top-level "object" parameters schema from the collected params.
+// The annotation type hint wins over the docstring type; default is "str".
+fn build_parameters_schema(params: &[Param], _description: &str) -> JsonSchema {
+    let mut props: IndexMap = IndexMap::new();
+    let mut req: Vec = Vec::new();
+
+    for p in params {
+        let name = p.name.replace('-', "_");
+        let mut schema = JsonSchema::default();
+
+        let ty = if !p.ty_hint.is_empty() {
+            p.ty_hint.as_str()
+        } else if let Some(t) = &p.doc_type {
+            t.as_str()
+        } else {
+            "str"
+        };
+
+        if let Some(d) = &p.doc_desc {
+            if !d.is_empty() {
+                schema.description = Some(d.clone());
+            }
+        }
+
+        apply_type_to_schema(ty, &mut schema);
+
+        if p.default.is_none() && p.required {
+            req.push(name.clone());
+        }
+
+        props.insert(name, schema);
+    }
+
+    JsonSchema {
+        type_value: Some("object".into()),
+        description: None,
+        properties: Some(props),
+        items: None,
+        any_of: None,
+        enum_value: None,
+        default: None,
+        required: if req.is_empty() { None } else { Some(req) },
+    }
+}
+
+// Translates an internal type-hint string into JSON-schema fields, recursing
+// into list element types and expanding "literal:" into a string enum.
+fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) {
+    let t = ty.trim_end_matches('?');
+    if let Some(rest) = t.strip_prefix("list[") {
+        s.type_value = Some("array".into());
+        let inner = rest.trim_end_matches(']');
+        let mut item = JsonSchema::default();
+
+        apply_type_to_schema(inner, &mut item);
+
+        if item.type_value.is_none() {
+            item.type_value = Some("string".into());
+        }
+        s.items = Some(Box::new(item));
+
+        return;
+    }
+
+    if let Some(rest) = t.strip_prefix("literal:") {
+        s.type_value = Some("string".into());
+        let vals = rest
+            .split('|')
+            .map(|x| x.trim().trim_matches('"').trim_matches('\'').to_string())
+            .collect::>();
+        if !vals.is_empty() {
+            s.enum_value = Some(vals);
+        }
+        return;
+    }
+
+    s.type_value = Some(
+        match t {
+            "bool" => "boolean",
+            "int" => "integer",
+            "float" => "number",
+            "str" | "any" | "" => "string",
+            _ => "string",
+        }
+        .into(),
+    );
+}
diff --git a/src/rag/mod.rs b/src/rag/mod.rs
new file mode 100644
index 0000000..be12040
--- /dev/null
+++ b/src/rag/mod.rs
@@ -0,0 +1,1013 @@
+use self::splitter::*;
+
+use crate::client::*;
+use crate::config::*;
+use crate::utils::*;
+
+mod serde_vectors;
+mod splitter;
+
+use anyhow::{anyhow, bail, Context, Result};
+use bm25::{Language, SearchEngine, SearchEngineBuilder};
+use hnsw_rs::prelude::*;
+use indexmap::{IndexMap, IndexSet};
+use inquire::{required,
validator::Validation, Confirm, Select, Text};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use std::{collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, time::Duration};
+use tokio::time::sleep;
+
+// A RAG store: the persisted data plus two in-memory indexes rebuilt from it
+// (HNSW for vector search, BM25 for keyword search).
+pub struct Rag {
+    config: GlobalConfig,
+    name: String,
+    path: String,
+    embedding_model: Model,
+    hnsw: Hnsw<'static, f32, DistCosine>,
+    bm25: SearchEngine,
+    data: RagData,
+    last_sources: RwLock>,
+}
+
+impl Debug for Rag {
+    // Omits the non-debuggable indexes and config handle.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Rag")
+            .field("name", &self.name)
+            .field("path", &self.path)
+            .field("embedding_model", &self.embedding_model)
+            .field("data", &self.data)
+            .finish()
+    }
+}
+
+impl Clone for Rag {
+    // Indexes are rebuilt from `data` rather than cloned; the last-sources
+    // cache is intentionally reset.
+    fn clone(&self) -> Self {
+        Self {
+            config: self.config.clone(),
+            name: self.name.clone(),
+            path: self.path.clone(),
+            embedding_model: self.embedding_model.clone(),
+            hnsw: self.data.build_hnsw(),
+            bm25: self.data.build_bm25(),
+            data: self.data.clone(),
+            last_sources: RwLock::new(None),
+        }
+    }
+}
+
+impl Rag {
+    // Interactively creates a new RAG store: picks models/chunking, loads and
+    // embeds the documents, and persists the result. Terminal-only.
+    pub async fn init(
+        config: &GlobalConfig,
+        name: &str,
+        save_path: &Path,
+        doc_paths: &[String],
+        abort_signal: AbortSignal,
+    ) -> Result {
+        if !*IS_STDOUT_TERMINAL {
+            bail!("Failed to init rag in non-interactive mode");
+        }
+        println!("⚙ Initializing RAG...");
+        let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(config)?;
+        let (reranker_model, top_k) = {
+            let config = config.read();
+            (config.rag_reranker_model.clone(), config.rag_top_k)
+        };
+        let data = RagData::new(
+            embedding_model.id(),
+            chunk_size,
+            chunk_overlap,
+            reranker_model,
+            top_k,
+            embedding_model.max_batch_size(),
+        );
+        let mut rag = Self::create(config, name, save_path, data)?;
+        let mut paths = doc_paths.to_vec();
+        if paths.is_empty() {
+            // No paths given: prompt the user for documents interactively.
+            paths = add_documents()?;
+        };
+        let loaders = config.read().document_loaders.clone();
+        let (spinner, spinner_rx) = Spinner::create("");
+        abortable_run_with_spinner_rx(
+            rag.sync_documents(&paths, true, loaders, Some(spinner)),
+            spinner_rx,
+            abort_signal,
+        )
+        .await?;
+        if rag.save()? {
+            println!("✓ Saved RAG to '{}'.", save_path.display());
+        }
+        Ok(rag)
+    }
+
+    // Loads persisted RagData (YAML) from disk and rebuilds the indexes.
+    pub fn load(config: &GlobalConfig, name: &str, path: &Path) -> Result {
+        let err = || format!("Failed to load rag '{name}' at '{}'", path.display());
+        let content = fs::read_to_string(path).with_context(err)?;
+        let data: RagData = serde_yaml::from_str(&content).with_context(err)?;
+        Self::create(config, name, path, data)
+    }
+
+    // Assembles a Rag from existing data: builds both indexes and resolves
+    // the embedding model recorded in the data.
+    pub fn create(config: &GlobalConfig, name: &str, path: &Path, data: RagData) -> Result {
+        let hnsw = data.build_hnsw();
+        let bm25 = data.build_bm25();
+        let embedding_model =
+            Model::retrieve_model(&config.read(), &data.embedding_model, ModelType::Embedding)?;
+        let rag = Rag {
+            config: config.clone(),
+            name: name.to_string(),
+            path: path.display().to_string(),
+            data,
+            embedding_model,
+            hnsw,
+            bm25,
+            last_sources: RwLock::new(None),
+        };
+        Ok(rag)
+    }
+
+    pub fn document_paths(&self) -> &[String] {
+        &self.data.document_paths
+    }
+
+    // Re-syncs the store against `document_paths` (full rebuild when
+    // `refresh` is true) and persists the updated data.
+    pub async fn refresh_document_paths(
+        &mut self,
+        document_paths: &[String],
+        refresh: bool,
+        config: &GlobalConfig,
+        abort_signal: AbortSignal,
+    ) -> Result<()> {
+        let loaders = config.read().document_loaders.clone();
+        let (spinner, spinner_rx) = Spinner::create("");
+        abortable_run_with_spinner_rx(
+            self.sync_documents(document_paths, refresh, loaders, Some(spinner)),
+            spinner_rx,
+            abort_signal,
+        )
+        .await?;
+        if self.save()? {
+            println!("✓ Saved rag to '{}'.", self.path);
+        }
+        Ok(())
+    }
+
+    // Resolves (embedding model, chunk size, chunk overlap), preferring
+    // configured values and falling back to interactive prompts.
+    pub fn create_config(config: &GlobalConfig) -> Result<(Model, usize, usize)> {
+        let (embedding_model_id, chunk_size, chunk_overlap) = {
+            let config = config.read();
+            (
+                config.rag_embedding_model.clone(),
+                config.rag_chunk_size,
+                config.rag_chunk_overlap,
+            )
+        };
+        let embedding_model_id = match embedding_model_id {
+            Some(value) => {
+                println!("Select embedding model: {value}");
+                value
+            }
+            None => {
+                let models = list_models(&config.read(), ModelType::Embedding);
+                if models.is_empty() {
+                    bail!("No available embedding model");
+                }
+                select_embedding_model(&models)?
+            }
+        };
+        let embedding_model =
+            Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;
+
+        let chunk_size = match chunk_size {
+            Some(value) => {
+                println!("Set chunk size: {value}");
+                value
+            }
+            None => set_chunk_size(&embedding_model)?,
+        };
+        let chunk_overlap = match chunk_overlap {
+            Some(value) => {
+                println!("Set chunk overlay: {value}");
+                value
+            }
+            None => {
+                // Default overlap suggestion: 5% of the chunk size.
+                let value = chunk_size / 20;
+                set_chunk_overlay(value)?
+ } + }; + + Ok((embedding_model, chunk_size, chunk_overlap)) + } + + pub fn get_config(&self) -> (Option, usize) { + (self.data.reranker_model.clone(), self.data.top_k) + } + + pub fn get_last_sources(&self) -> Option { + self.last_sources.read().clone() + } + + pub fn set_last_sources(&self, ids: &[DocumentId]) { + let mut sources: IndexMap> = IndexMap::new(); + for id in ids { + let (file_index, _) = id.split(); + if let Some(file) = self.data.files.get(&file_index) { + sources + .entry(file.path.clone()) + .or_default() + .push(format!("{id:?}")); + } + } + let sources = if sources.is_empty() { + None + } else { + Some( + sources + .into_iter() + .map(|(path, ids)| format!("{path} ({})", ids.join(","))) + .collect::>() + .join("\n"), + ) + }; + *self.last_sources.write() = sources; + } + + pub fn set_reranker_model(&mut self, reranker_model: Option) -> Result<()> { + self.data.reranker_model = reranker_model; + self.save()?; + Ok(()) + } + + pub fn set_top_k(&mut self, top_k: usize) -> Result<()> { + self.data.top_k = top_k; + self.save()?; + Ok(()) + } + + pub fn save(&self) -> Result { + if self.is_temp() { + return Ok(false); + } + let path = Path::new(&self.path); + ensure_parent_exists(path)?; + + let content = serde_yaml::to_string(&self.data) + .with_context(|| format!("Failed to serde rag '{}'", self.name))?; + fs::write(path, content).with_context(|| { + format!("Failed to save rag '{}' to '{}'", self.name, path.display()) + })?; + + Ok(true) + } + + pub fn export(&self) -> Result { + let files: Vec<_> = self + .data + .files + .iter() + .map(|(_, v)| { + json!({ + "path": v.path, + "num_chunks": v.documents.len(), + }) + }) + .collect(); + let data = json!({ + "path": self.path, + "embedding_model": self.embedding_model.id(), + "chunk_size": self.data.chunk_size, + "chunk_overlap": self.data.chunk_overlap, + "reranker_model": self.data.reranker_model, + "top_k": self.data.top_k, + "batch_size": self.data.batch_size, + "document_paths": 
self.data.document_paths, + "files": files, + }); + let output = serde_yaml::to_string(&data) + .with_context(|| format!("Unable to show info about rag '{}'", self.name))?; + Ok(output) + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn is_temp(&self) -> bool { + self.name == TEMP_RAG_NAME + } + + pub async fn search( + &self, + text: &str, + top_k: usize, + rerank_model: Option<&str>, + abort_signal: AbortSignal, + ) -> Result<(String, Vec)> { + let ret = abortable_run_with_spinner( + self.hybird_search(text, top_k, rerank_model), + "Searching", + abort_signal, + ) + .await; + let (ids, documents): (Vec<_>, Vec<_>) = ret?.into_iter().unzip(); + let embeddings = documents.join("\n\n"); + Ok((embeddings, ids)) + } + + pub async fn sync_documents( + &mut self, + paths: &[String], + refresh: bool, + loaders: HashMap, + spinner: Option, + ) -> Result<()> { + if let Some(spinner) = &spinner { + let _ = spinner.set_message(String::new()); + } + let (document_paths, mut recursive_urls, mut urls, mut protocol_paths, mut local_paths) = + resolve_paths(&loaders, paths).await?; + let mut to_deleted: IndexMap> = Default::default(); + if refresh { + for (file_id, file) in &self.data.files { + to_deleted + .entry(file.hash.clone()) + .or_default() + .push(*file_id); + } + } else { + let recursive_urls_cloned = recursive_urls.clone(); + let match_recursive_url = |v: &str| { + recursive_urls_cloned + .iter() + .any(|start_url| v.starts_with(start_url)) + }; + recursive_urls = recursive_urls + .into_iter() + .filter(|v| !self.data.document_paths.contains(&format!("{v}**"))) + .collect(); + let protocol_paths_cloned = protocol_paths.clone(); + let match_protocol_path = + |v: &str| protocol_paths_cloned.iter().any(|root| v.starts_with(root)); + protocol_paths = protocol_paths + .into_iter() + .filter(|v| !self.data.document_paths.contains(v)) + .collect(); + for (file_id, file) in &self.data.files { + if is_url(&file.path) { + if !urls.swap_remove(&file.path) && 
!match_recursive_url(&file.path) { + to_deleted + .entry(file.hash.clone()) + .or_default() + .push(*file_id); + } + } else if is_loader_protocol(&loaders, &file.path) { + if !match_protocol_path(&file.path) { + to_deleted + .entry(file.hash.clone()) + .or_default() + .push(*file_id); + } + } else if !local_paths.swap_remove(&file.path) { + to_deleted + .entry(file.hash.clone()) + .or_default() + .push(*file_id); + } + } + } + + let mut loaded_documents = vec![]; + let mut has_error = false; + let mut index = 0; + let total = recursive_urls.len() + urls.len() + protocol_paths.len() + local_paths.len(); + let handle_error = |error: anyhow::Error, has_error: &mut bool| { + println!("{}", warning_text(&format!("⚠️ {error}"))); + *has_error = true; + }; + for start_url in recursive_urls { + index += 1; + println!("Load {start_url}** [{index}/{total}]"); + match load_recursive_url(&loaders, &start_url).await { + Ok(v) => loaded_documents.extend(v), + Err(err) => handle_error(err, &mut has_error), + } + } + for url in urls { + index += 1; + println!("Load {url} [{index}/{total}]"); + match load_url(&loaders, &url).await { + Ok(v) => loaded_documents.push(v), + Err(err) => handle_error(err, &mut has_error), + } + } + for protocol_path in protocol_paths { + index += 1; + println!("Load {protocol_path} [{index}/{total}]"); + match load_protocol_path(&loaders, &protocol_path) { + Ok(v) => loaded_documents.extend(v), + Err(err) => handle_error(err, &mut has_error), + } + } + for local_path in local_paths { + index += 1; + println!("Load {local_path} [{index}/{total}]"); + match load_file(&loaders, &local_path).await { + Ok(v) => loaded_documents.push(v), + Err(err) => handle_error(err, &mut has_error), + } + } + + if has_error { + let mut aborted = true; + if *IS_STDOUT_TERMINAL && total > 0 { + let ans = Confirm::new("Some documents failed to load. 
Continue?")
+                    .with_default(false)
+                    .prompt()?;
+                aborted = !ans;
+            }
+            if aborted {
+                bail!("Aborted");
+            }
+        }
+
+        // Split each loaded document into chunks; unchanged content (same hash
+        // and path) cancels its pending deletion and is skipped entirely.
+        let mut rag_files = vec![];
+        for LoadedDocument {
+            path,
+            contents,
+            mut metadata,
+        } in loaded_documents
+        {
+            let hash = sha256(&contents);
+            if let Some(file_ids) = to_deleted.get_mut(&hash) {
+                if let Some((i, _)) = file_ids
+                    .iter()
+                    .enumerate()
+                    .find(|(_, v)| self.data.files[*v].path == path)
+                {
+                    if file_ids.len() == 1 {
+                        to_deleted.swap_remove(&hash);
+                    } else {
+                        file_ids.remove(i);
+                    }
+                    continue;
+                }
+            }
+            let extension = metadata
+                .swap_remove(EXTENSION_METADATA)
+                .unwrap_or_else(|| DEFAULT_EXTENSION.into());
+            let separator = get_separators(&extension);
+            let splitter = RecursiveCharacterTextSplitter::new(
+                self.data.chunk_size,
+                self.data.chunk_overlap,
+                &separator,
+            );
+
+            let split_options = SplitterChunkHeaderOptions::default();
+            let document = RagDocument::new(contents);
+            let split_documents = splitter.split_documents(&[document], &split_options);
+            rag_files.push(RagFile {
+                hash: hash.clone(),
+                path,
+                documents: split_documents,
+            });
+        }
+
+        let mut next_file_id = self.data.next_file_id;
+        let mut files = vec![];
+        let mut document_ids = vec![];
+        let mut embeddings = vec![];
+
+        // Embed every chunk of every new/changed file in one batched pass.
+        if !rag_files.is_empty() {
+            let mut texts = vec![];
+            for file in rag_files.into_iter() {
+                for (document_index, document) in file.documents.iter().enumerate() {
+                    document_ids.push(DocumentId::new(next_file_id, document_index));
+                    texts.push(document.page_content.clone())
+                }
+                files.push((next_file_id, file));
+                next_file_id += 1;
+            }
+
+            let embeddings_data = EmbeddingsData::new(texts, false);
+            embeddings = self
+                .create_embeddings(embeddings_data, spinner.clone())
+                .await?;
+        }
+
+        let to_delete_file_ids: Vec<_> = to_deleted.values().flatten().copied().collect();
+        self.data.del(to_delete_file_ids);
+        self.data.add(next_file_id, files, document_ids, embeddings);
+        self.data.document_paths = document_paths.into_iter().collect();
+
+        if self.data.files.is_empty() {
+            bail!("No RAG files");
+        }
+
+        progress(&spinner, "Building store".into());
+        self.hnsw = self.data.build_hnsw();
+        self.bm25 = self.data.build_bm25();
+
+        Ok(())
+    }
+
+    // Hybrid retrieval: runs vector and keyword search concurrently, then
+    // merges results via a reranker model (when given) or reciprocal-rank
+    // fusion. NOTE(review): "hybird" is a typo for "hybrid" in this private
+    // name; left as-is here because the call site lives in another hunk.
+    async fn hybird_search(
+        &self,
+        query: &str,
+        top_k: usize,
+        rerank_model: Option<&str>,
+    ) -> Result> {
+        let (vector_search_results, keyword_search_results) = tokio::join!(
+            self.vector_search(query, top_k, 0.0),
+            self.keyword_search(query, top_k, 0.0),
+        );
+
+        let vector_search_results = vector_search_results?;
+        debug!("vector_search_results: {vector_search_results:?}",);
+        let vector_search_ids: Vec =
+            vector_search_results.into_iter().map(|(v, _)| v).collect();
+
+        let keyword_search_results = keyword_search_results?;
+        debug!("keyword_search_results: {keyword_search_results:?}",);
+        let keyword_search_ids: Vec =
+            keyword_search_results.into_iter().map(|(v, _)| v).collect();
+
+        let ids = match rerank_model {
+            Some(model_id) => {
+                let model =
+                    Model::retrieve_model(&self.config.read(), model_id, ModelType::Reranker)?;
+                let client = init_client(&self.config, Some(model))?;
+                // Dedupe the union of both candidate lists before reranking.
+                let ids: IndexSet = [vector_search_ids, keyword_search_ids]
+                    .concat()
+                    .into_iter()
+                    .collect();
+                let mut documents = vec![];
+                let mut documents_ids = vec![];
+                for id in ids {
+                    if let Some(document) = self.data.get(id) {
+                        documents_ids.push(id);
+                        documents.push(document.page_content.to_string());
+                    }
+                }
+                let data = RerankData::new(query.to_string(), documents, top_k);
+                let list = client.rerank(&data).await.context("Failed to rerank")?;
+                let ids: Vec<_> = list
+                    .into_iter()
+                    .take(top_k)
+                    .filter_map(|item| documents_ids.get(item.index).cloned())
+                    .collect();
+                debug!("rerank_ids: {ids:?}");
+                ids
+            }
+            None => {
+                // Vector results get a slight weight edge (1.125 vs 1.0) in RRF.
+                let ids = reciprocal_rank_fusion(
+                    vec![vector_search_ids, keyword_search_ids],
+                    vec![1.125, 1.0],
+                    top_k,
+                );
+                debug!("rrf_ids: {ids:?}");
+                ids
+            }
+        };
+        let output = ids
+            .into_iter()
+            .filter_map(|id| {
+                let document = self.data.get(id)?;
+                Some((id, document.page_content.clone()))
+            })
+            .collect();
+        Ok(output)
+    }
+
+    // Embeds the (possibly split) query and searches HNSW; cosine distance is
+    // converted to a similarity score (1 - distance) and thresholded.
+    async fn vector_search(
+        &self,
+        query: &str,
+        top_k: usize,
+        min_score: f32,
+    ) -> Result> {
+        let splitter = RecursiveCharacterTextSplitter::new(
+            self.data.chunk_size,
+            self.data.chunk_overlap,
+            &DEFAULT_SEPARATORS,
+        );
+        let texts = splitter.split_text(query);
+        let embeddings_data = EmbeddingsData::new(texts, true);
+        let embeddings = self.create_embeddings(embeddings_data, None).await?;
+        let output = self
+            .hnsw
+            .parallel_search(&embeddings, top_k, 30)
+            .into_iter()
+            .flat_map(|list| {
+                list.into_iter()
+                    .filter_map(|v| {
+                        let score = 1.0 - v.distance;
+                        if score > min_score {
+                            Some((DocumentId(v.d_id), score))
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::>()
+            })
+            .collect();
+        Ok(output)
+    }
+
+    // BM25 keyword search over the same document ids, score-thresholded.
+    async fn keyword_search(
+        &self,
+        query: &str,
+        top_k: usize,
+        min_score: f32,
+    ) -> Result> {
+        let results = self.bm25.search(query, top_k);
+        let output: Vec<(DocumentId, f32)> = results
+            .into_iter()
+            .filter_map(|v| {
+                let score = v.score;
+                if score > min_score {
+                    Some((v.document.id, score))
+                } else {
+                    None
+                }
+            })
+            .collect();
+        Ok(output)
+    }
+
+    // Embeds texts in batches sized by the model's limits, with exponential-
+    // backoff retries per batch; progress is reported through the spinner.
+    async fn create_embeddings(
+        &self,
+        data: EmbeddingsData,
+        spinner: Option,
+    ) -> Result {
+        let embedding_client = init_client(&self.config, Some(self.embedding_model.clone()))?;
+        let EmbeddingsData { texts, query } = data;
+        let batch_size = self
+            .data
+            .batch_size
+            .or_else(|| self.embedding_model.max_batch_size());
+        // Cap the batch by how many chunks fit into the model's input window.
+        let batch_size = match self.embedding_model.max_input_tokens() {
+            Some(max_input_tokens) => {
+                let x = max_input_tokens / self.data.chunk_size;
+                match batch_size {
+                    Some(y) => x.min(y),
+                    None => x,
+                }
+            }
+            None => batch_size.unwrap_or(1),
+        };
+        let mut output = vec![];
+        let batch_chunks = texts.chunks(batch_size.max(1));
+        let batch_chunks_len = batch_chunks.len();
+        let retry_limit =
env::var(get_env_name("embeddings_retry_limit")) + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(2); + for (index, texts) in batch_chunks.enumerate() { + progress( + &spinner, + format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1), + ); + let chunk_data = EmbeddingsData { + texts: texts.to_vec(), + query, + }; + let mut retry = 0; + let chunk_output = loop { + retry += 1; + match embedding_client.embeddings(&chunk_data).await { + Ok(v) => break v, + Err(e) if retry < retry_limit => { + debug!("retry {retry} failed: {e}"); + sleep(Duration::from_secs(2u64.pow(retry - 1))).await; + continue; + } + Err(e) => { + return Err(e).with_context(|| { + format!("Failed to create embedding after {retry_limit} attempts") + })? + } + } + }; + output.extend(chunk_output); + } + Ok(output) + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct RagData { + pub embedding_model: String, + pub chunk_size: usize, + pub chunk_overlap: usize, + pub reranker_model: Option, + pub top_k: usize, + pub batch_size: Option, + pub next_file_id: FileId, + pub document_paths: Vec, + pub files: IndexMap, + #[serde(with = "serde_vectors")] + pub vectors: IndexMap>, +} + +impl Debug for RagData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RagData") + .field("embedding_model", &self.embedding_model) + .field("chunk_size", &self.chunk_size) + .field("chunk_overlap", &self.chunk_overlap) + .field("reranker_model", &self.reranker_model) + .field("top_k", &self.top_k) + .field("batch_size", &self.batch_size) + .field("next_file_id", &self.next_file_id) + .field("document_paths", &self.document_paths) + .field("files", &self.files) + .finish() + } +} + +impl RagData { + pub fn new( + embedding_model: String, + chunk_size: usize, + chunk_overlap: usize, + reranker_model: Option, + top_k: usize, + batch_size: Option, + ) -> Self { + Self { + embedding_model, + chunk_size, + chunk_overlap, + reranker_model, + top_k, + batch_size, + 
next_file_id: 0, + document_paths: Default::default(), + files: Default::default(), + vectors: Default::default(), + } + } + + pub fn get(&self, id: DocumentId) -> Option<&RagDocument> { + let (file_index, document_index) = id.split(); + let file = self.files.get(&file_index)?; + let document = file.documents.get(document_index)?; + Some(document) + } + + pub fn del(&mut self, file_ids: Vec) { + for file_id in file_ids { + if let Some(file) = self.files.swap_remove(&file_id) { + for (document_index, _) in file.documents.iter().enumerate() { + let document_id = DocumentId::new(file_id, document_index); + self.vectors.swap_remove(&document_id); + } + } + } + } + + pub fn add( + &mut self, + next_file_id: FileId, + files: Vec<(FileId, RagFile)>, + document_ids: Vec, + embeddings: EmbeddingsOutput, + ) { + self.next_file_id = next_file_id; + self.files.extend(files); + self.vectors + .extend(document_ids.into_iter().zip(embeddings)); + } + + pub fn build_hnsw(&self) -> Hnsw<'static, f32, DistCosine> { + let hnsw = Hnsw::new(32, self.vectors.len(), 16, 200, DistCosine {}); + let list: Vec<_> = self.vectors.iter().map(|(k, v)| (v, k.0)).collect(); + hnsw.parallel_insert(&list); + hnsw + } + + pub fn build_bm25(&self) -> SearchEngine { + let mut documents = vec![]; + for (file_index, file) in self.files.iter() { + for (document_index, document) in file.documents.iter().enumerate() { + let id = DocumentId::new(*file_index, document_index); + documents.push(bm25::Document::new(id, &document.page_content)) + } + } + SearchEngineBuilder::::with_documents(Language::English, documents) + .k1(1.5) + .b(0.75) + .build() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RagFile { + hash: String, + path: String, + documents: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RagDocument { + pub page_content: String, + pub metadata: DocumentMetadata, +} + +impl RagDocument { + pub fn new>(page_content: S) -> Self { + RagDocument 
{ + page_content: page_content.into(), + metadata: IndexMap::new(), + } + } +} + +impl Default for RagDocument { + fn default() -> Self { + RagDocument { + page_content: "".to_string(), + metadata: IndexMap::new(), + } + } +} + +pub type FileId = usize; + +#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub struct DocumentId(usize); + +impl Debug for DocumentId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (file_index, document_index) = self.split(); + f.write_fmt(format_args!("{file_index}-{document_index}")) + } +} + +impl DocumentId { + pub fn new(file_index: usize, document_index: usize) -> Self { + let value = (file_index << (usize::BITS / 2)) | document_index; + Self(value) + } + + pub fn split(self) -> (usize, usize) { + let value = self.0; + let low_mask = (1 << (usize::BITS / 2)) - 1; + let low = value & low_mask; + let high = value >> (usize::BITS / 2); + (high, low) + } +} + +fn select_embedding_model(models: &[&Model]) -> Result { + let models: Vec<_> = models + .iter() + .map(|v| SelectOption::new(v.id(), v.description())) + .collect(); + let result = Select::new("Select embedding model:", models).prompt()?; + Ok(result.value) +} + +#[derive(Debug)] +struct SelectOption { + pub value: String, + pub description: String, +} + +impl SelectOption { + pub fn new(value: String, description: String) -> Self { + Self { value, description } + } +} + +impl std::fmt::Display for SelectOption { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} ({})", self.value, self.description) + } +} + +fn set_chunk_size(model: &Model) -> Result { + let default_value = model.default_chunk_size().to_string(); + let help_message = model + .max_tokens_per_chunk() + .map(|v| format!("The model's max_tokens is {v}")); + + let mut text = Text::new("Set chunk size:") + .with_default(&default_value) + .with_validator(move |text: &str| { + let out = match text.parse::() { + Ok(_) => Validation::Valid, + 
Err(_) => Validation::Invalid("Must be a integer".into()), + }; + Ok(out) + }); + if let Some(help_message) = &help_message { + text = text.with_help_message(help_message); + } + let value = text.prompt()?; + value.parse().map_err(|_| anyhow!("Invalid chunk_size")) +} + +fn set_chunk_overlay(default_value: usize) -> Result { + let value = Text::new("Set chunk overlay:") + .with_default(&default_value.to_string()) + .with_validator(move |text: &str| { + let out = match text.parse::() { + Ok(_) => Validation::Valid, + Err(_) => Validation::Invalid("Must be a integer".into()), + }; + Ok(out) + }) + .prompt()?; + value.parse().map_err(|_| anyhow!("Invalid chunk_overlay")) +} + +fn add_documents() -> Result> { + let text = Text::new("Add documents:") + .with_validator(required!("This field is required")) + .with_help_message("e.g. file;dir/;dir/**/*.{md,mdx};loader:resource;url;website/**") + .prompt()?; + let paths = text + .split(';') + .filter_map(|v| { + let v = v.trim().to_string(); + if v.is_empty() { + None + } else { + Some(v) + } + }) + .collect(); + Ok(paths) +} + +async fn resolve_paths>( + loaders: &HashMap, + paths: &[T], +) -> Result<( + IndexSet, + IndexSet, + IndexSet, + IndexSet, + IndexSet, +)> { + let mut document_paths = IndexSet::new(); + let mut recursive_urls = IndexSet::new(); + let mut urls = IndexSet::new(); + let mut protocol_paths = IndexSet::new(); + let mut absolute_paths = vec![]; + for path in paths { + let path = path.as_ref().trim(); + if is_url(path) { + if let Some(start_url) = path.strip_suffix("**") { + recursive_urls.insert(start_url.to_string()); + } else { + urls.insert(path.to_string()); + } + document_paths.insert(path.to_string()); + } else if is_loader_protocol(loaders, path) { + protocol_paths.insert(path.to_string()); + document_paths.insert(path.to_string()); + } else { + let resolved_path = resolve_home_dir(path); + let absolute_path = to_absolute_path(&resolved_path) + .with_context(|| format!("Invalid path '{path}'"))?; 
+ absolute_paths.push(resolved_path); + document_paths.insert(absolute_path); + } + } + let local_paths = expand_glob_paths(&absolute_paths, false).await?; + Ok(( + document_paths, + recursive_urls, + urls, + protocol_paths, + local_paths, + )) +} + +fn progress(spinner: &Option, message: String) { + if let Some(spinner) = spinner { + let _ = spinner.set_message(message); + } +} + +fn reciprocal_rank_fusion( + list_of_document_ids: Vec>, + list_of_weights: Vec, + top_k: usize, +) -> Vec { + let rrf_k = top_k * 2; + let mut map: IndexMap = IndexMap::new(); + for (document_ids, weight) in list_of_document_ids + .into_iter() + .zip(list_of_weights.into_iter()) + { + for (index, &item) in document_ids.iter().enumerate() { + *map.entry(item).or_default() += (1.0 / ((rrf_k + index + 1) as f32)) * weight; + } + } + let mut sorted_items: Vec<(DocumentId, f32)> = map.into_iter().collect(); + sorted_items.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + sorted_items + .into_iter() + .take(top_k) + .map(|(v, _)| v) + .collect() +} diff --git a/src/rag/serde_vectors.rs b/src/rag/serde_vectors.rs new file mode 100644 index 0000000..1f66230 --- /dev/null +++ b/src/rag/serde_vectors.rs @@ -0,0 +1,66 @@ +use super::*; + +use base64::{engine::general_purpose::STANDARD, Engine}; +use serde::{de, Deserializer, Serializer}; + +pub fn serialize( + vectors: &IndexMap>, + serializer: S, +) -> Result +where + S: Serializer, +{ + let encoded_map: IndexMap = vectors + .iter() + .map(|(id, vec)| { + let (h, l) = id.split(); + let byte_slice = unsafe { + std::slice::from_raw_parts(vec.as_ptr() as *const u8, vec.len() * size_of::()) + }; + (format!("{h}-{l}"), STANDARD.encode(byte_slice)) + }) + .collect(); + + encoded_map.serialize(serializer) +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let encoded_map: IndexMap = + IndexMap::::deserialize(deserializer)?; + + let mut decoded_map = IndexMap::new(); + for (key, base64_str) in 
encoded_map { + let decoded_key: DocumentId = key + .split_once('-') + .and_then(|(h, l)| { + let h = h.parse::().ok()?; + let l = l.parse::().ok()?; + Some(DocumentId::new(h, l)) + }) + .ok_or_else(|| de::Error::custom(format!("Invalid key '{key}'")))?; + + let decoded_data = STANDARD.decode(&base64_str).map_err(de::Error::custom)?; + + if decoded_data.len() % size_of::() != 0 { + return Err(de::Error::custom(format!("Invalid vector at '{key}'"))); + } + + let num_f32s = decoded_data.len() / size_of::(); + + let mut vec_f32 = vec![0.0f32; num_f32s]; + unsafe { + std::ptr::copy_nonoverlapping( + decoded_data.as_ptr(), + vec_f32.as_mut_ptr() as *mut u8, + decoded_data.len(), + ); + } + + decoded_map.insert(decoded_key, vec_f32); + } + + Ok(decoded_map) +} diff --git a/src/rag/splitter/language.rs b/src/rag/splitter/language.rs new file mode 100644 index 0000000..20722cf --- /dev/null +++ b/src/rag/splitter/language.rs @@ -0,0 +1,235 @@ +#[derive(PartialEq, Eq, Hash)] +pub enum Language { + Cpp, + Go, + Java, + Js, + Php, + Proto, + Python, + Rst, + Ruby, + Rust, + Scala, + Swift, + Markdown, + Latex, + Html, + Sol, +} + +impl Language { + pub fn separators(&self) -> Vec<&str> { + match self { + Language::Cpp => vec![ + "\nclass ", + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Go => vec![ + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Java => vec![ + "\nclass ", + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Js => vec![ + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + "\nclass ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + "\n\n", + "\n", + " ", + "", + ], + Language::Php 
=> vec![ + "\nfunction ", + "\nclass ", + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Proto => vec![ + "\nmessage ", + "\nservice ", + "\nenum ", + "\noption ", + "\nimport ", + "\nsyntax ", + "\n\n", + "\n", + " ", + "", + ], + Language::Python => vec!["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""], + Language::Rst => vec![ + "\n===\n", "\n---\n", "\n***\n", "\n.. ", "\n\n", "\n", " ", "", + ], + Language::Ruby => vec![ + "\ndef ", + "\nclass ", + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + "\n\n", + "\n", + " ", + "", + ], + Language::Rust => vec![ + "\nfn ", "\nconst ", "\nlet ", "\nif ", "\nwhile ", "\nfor ", "\nloop ", + "\nmatch ", "\nconst ", "\n\n", "\n", " ", "", + ], + Language::Scala => vec![ + "\nclass ", + "\nobject ", + "\ndef ", + "\nval ", + "\nvar ", + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Swift => vec![ + "\nfunc ", + "\nclass ", + "\nstruct ", + "\nenum ", + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + "\n\n", + "\n", + " ", + "", + ], + Language::Markdown => vec![ + "\n## ", + "\n### ", + "\n#### ", + "\n##### ", + "\n###### ", + "```\n\n", + "\n\n***\n\n", + "\n\n---\n\n", + "\n\n___\n\n", + "\n\n", + "\n", + " ", + "", + ], + Language::Latex => vec![ + "\n\\chapter{", + "\n\\section{", + "\n\\subsection{", + "\n\\subsubsection{", + "\n\\begin{enumerate}", + "\n\\begin{itemize}", + "\n\\begin{description}", + "\n\\begin{list}", + "\n\\begin{quote}", + "\n\\begin{quotation}", + "\n\\begin{verse}", + "\n\\begin{verbatim}", + "\n\\begin{align}", + "$$", + "$", + "\n\n", + "\n", + " ", + "", + ], + Language::Html => vec![ + "", "
", "

", "
", "

  • ", "

    ", "

    ", "

    ", "

    ", "

    ", + "
    ", "", "", "", "
    ", "", "
      ", "
        ", "
        ", + "