From 25ad254e84aee6ee448ee011c36565c9d4c25ec6 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Thu, 16 Oct 2025 13:01:37 -0600 Subject: [PATCH] style: Applied formatting --- src/config/mod.rs | 5502 ++++++++++++++++++++++---------------------- src/function.rs | 2 +- src/mcp/mod.rs | 452 ++-- src/vault/mod.rs | 180 +- src/vault/utils.rs | 218 +- 5 files changed, 3177 insertions(+), 3177 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 741a3dd..6961424 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -6,14 +6,14 @@ mod session; pub use self::agent::{complete_agent_variables, list_agents, Agent, AgentVariables}; pub use self::input::Input; pub use self::role::{ - Role, RoleLike, CODE_ROLE, CREATE_TITLE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, + Role, RoleLike, CODE_ROLE, CREATE_TITLE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, }; use self::session::Session; use mem::take; use crate::client::{ - create_client_config, list_client_types, list_models, ClientConfig, MessageContentToolCalls, - Model, ModelType, ProviderModels, OPENAI_COMPATIBLE_PROVIDERS, + create_client_config, list_client_types, list_models, ClientConfig, MessageContentToolCalls, + Model, ModelType, ProviderModels, OPENAI_COMPATIBLE_PROVIDERS, }; use crate::function::{FunctionDeclaration, Functions, ToolResult}; use crate::rag::Rag; @@ -22,7 +22,7 @@ use crate::repl::{run_repl_command, split_args_text}; use crate::utils::*; use crate::mcp::{ - McpRegistry, MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX, + McpRegistry, MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX, }; use crate::vault::{interpolate_secrets, Vault}; use anyhow::{anyhow, bail, Context, Result}; @@ -37,15 +37,15 @@ use serde_json::json; use std::collections::{HashMap, HashSet}; use std::sync::LazyLock; use std::{ - env, - fs::{ - create_dir_all, read_dir, read_to_string, remove_dir_all, remove_file, File, OpenOptions, - }, - io::Write, - mem, - path::{Path, PathBuf}, - process, - sync::{Arc, OnceLock}, + env, + fs::{ + create_dir_all, read_dir, read_to_string, remove_dir_all, remove_file, File, OpenOptions, + }, + io::Write, + mem, + path::{Path, PathBuf}, + process, + sync::{Arc, OnceLock}, }; use syntect::highlighting::ThemeSet; use terminal_colorsaurus::{color_scheme, ColorScheme, QueryOptions}; @@ -56,7 +56,7 @@ pub const TEMP_RAG_NAME: &str = "temp"; pub const TEMP_SESSION_NAME: &str = "temp"; static PASSWORD_FILE_SECRET_RE: LazyLock = - LazyLock::new(|| Regex::new(r#"vault_password_file:.*['|"]?\{\{(.+)}}['|"]?"#).unwrap()); + LazyLock::new(|| Regex::new(r#"vault_password_file:.*['|"]?\{\{(.+)}}['|"]?"#).unwrap()); /// Monokai Extended const DARK_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bin"); @@ -81,10 +81,10 @@ const CLIENTS_FIELD: &str = "clients"; const SERVE_ADDR: &str = "127.0.0.1:8000"; const SYNC_MODELS_URL: &str = - "https://raw.githubusercontent.com/Dark-Alex-17/loki/refs/heads/main/models.yaml"; + "https://raw.githubusercontent.com/Dark-Alex-17/loki/refs/heads/main/models.yaml"; const SUMMARIZE_PROMPT: &str = - "Summarize the discussion briefly in 200 words or less to use as a prompt for future context."; + "Summarize the discussion briefly in 200 words or less to use as a prompt for future context."; const SUMMARY_PROMPT: &str = "This is a summary of the chat history as a recap: "; const RAG_TEMPLATE: &str = r#"Answer the query based on the context while respecting the rules. (user query, some textual context and rules, all inside xml tags) @@ -114,2022 +114,2022 @@ static EDITOR: OnceLock> = OnceLock::new(); #[derive(Debug, Clone, Deserialize)] #[serde(default)] pub struct Config { - #[serde(rename(serialize = "model", deserialize = "model"))] - #[serde(default)] - pub model_id: String, - pub temperature: Option, - pub top_p: Option, + #[serde(rename(serialize = "model", deserialize = "model"))] + #[serde(default)] + pub model_id: String, + pub temperature: Option, + pub top_p: Option, - pub dry_run: bool, - pub stream: bool, - pub save: bool, - pub keybindings: String, - pub editor: Option, - pub wrap: Option, - pub wrap_code: bool, - vault_password_file: Option, + pub dry_run: bool, + pub stream: bool, + pub save: bool, + pub keybindings: String, + pub editor: Option, + pub wrap: Option, + pub wrap_code: bool, + vault_password_file: Option, - pub function_calling: bool, - pub mapping_tools: IndexMap, - pub use_tools: Option, + pub function_calling: bool, + pub mapping_tools: IndexMap, + pub use_tools: Option, - pub mcp_servers: bool, - pub mapping_mcp_servers: IndexMap, - pub use_mcp_servers: Option, + pub mcp_servers: bool, + pub mapping_mcp_servers: IndexMap, + pub use_mcp_servers: Option, - pub repl_prelude: Option, - pub cmd_prelude: Option, - pub agent_session: Option, + pub repl_prelude: Option, + pub cmd_prelude: Option, + pub agent_session: Option, - pub save_session: Option, - pub compress_threshold: usize, - pub summarize_prompt: Option, - pub summary_prompt: Option, + pub save_session: Option, + pub compress_threshold: usize, + pub summarize_prompt: Option, + pub summary_prompt: Option, - pub rag_embedding_model: Option, - pub rag_reranker_model: Option, - pub rag_top_k: usize, - pub rag_chunk_size: Option, - pub rag_chunk_overlap: Option, - pub rag_template: Option, + pub rag_embedding_model: Option, + pub rag_reranker_model: Option, + pub rag_top_k: usize, + pub rag_chunk_size: Option, + pub rag_chunk_overlap: Option, + pub rag_template: Option, - #[serde(default)] - pub document_loaders: HashMap, + #[serde(default)] + pub document_loaders: HashMap, - pub highlight: bool, - pub theme: Option, - pub left_prompt: Option, - pub right_prompt: Option, + pub highlight: bool, + pub theme: Option, + pub left_prompt: Option, + pub right_prompt: Option, - pub serve_addr: Option, - pub user_agent: Option, - pub save_shell_history: bool, - pub sync_models_url: Option, + pub serve_addr: Option, + pub user_agent: Option, + pub save_shell_history: bool, + pub sync_models_url: Option, - pub clients: Vec, + pub clients: Vec, - #[serde(skip)] - pub vault: Vault, + #[serde(skip)] + pub vault: Vault, - #[serde(skip)] - pub macro_flag: bool, - #[serde(skip)] - pub info_flag: bool, - #[serde(skip)] - pub agent_variables: Option, + #[serde(skip)] + pub macro_flag: bool, + #[serde(skip)] + pub info_flag: bool, + #[serde(skip)] + pub agent_variables: Option, - #[serde(skip)] - pub model: Model, - #[serde(skip)] - pub functions: Functions, - #[serde(skip)] - pub mcp_registry: Option, - #[serde(skip)] - pub working_mode: WorkingMode, - #[serde(skip)] - pub last_message: Option, + #[serde(skip)] + pub model: Model, + #[serde(skip)] + pub functions: Functions, + #[serde(skip)] + pub mcp_registry: Option, + #[serde(skip)] + pub working_mode: WorkingMode, + #[serde(skip)] + pub last_message: Option, - #[serde(skip)] - pub role: Option, - #[serde(skip)] - pub session: Option, - #[serde(skip)] - pub rag: Option>, - #[serde(skip)] - pub agent: Option, + #[serde(skip)] + pub role: Option, + #[serde(skip)] + pub session: Option, + #[serde(skip)] + pub rag: Option>, + #[serde(skip)] + pub agent: Option, } impl Default for Config { - fn default() -> Self { - Self { - model_id: Default::default(), - temperature: None, - top_p: None, + fn default() -> Self { + Self { + model_id: Default::default(), + temperature: None, + top_p: None, - dry_run: false, - stream: true, - save: false, - keybindings: "emacs".into(), - editor: None, - wrap: None, - wrap_code: false, - vault_password_file: None, + dry_run: false, + stream: true, + save: false, + keybindings: "emacs".into(), + editor: None, + wrap: None, + wrap_code: false, + vault_password_file: None, - function_calling: true, - mapping_tools: Default::default(), - use_tools: None, + function_calling: true, + mapping_tools: Default::default(), + use_tools: None, - mcp_servers: true, - mapping_mcp_servers: Default::default(), - use_mcp_servers: None, + mcp_servers: true, + mapping_mcp_servers: Default::default(), + use_mcp_servers: None, - repl_prelude: None, - cmd_prelude: None, - agent_session: None, + repl_prelude: None, + cmd_prelude: None, + agent_session: None, - save_session: None, - compress_threshold: 4000, - summarize_prompt: None, - summary_prompt: None, + save_session: None, + compress_threshold: 4000, + summarize_prompt: None, + summary_prompt: None, - rag_embedding_model: None, - rag_reranker_model: None, - rag_top_k: 5, - rag_chunk_size: None, - rag_chunk_overlap: None, - rag_template: None, + rag_embedding_model: None, + rag_reranker_model: None, + rag_top_k: 5, + rag_chunk_size: None, + rag_chunk_overlap: None, + rag_template: None, - document_loaders: Default::default(), + document_loaders: Default::default(), - highlight: true, - theme: None, - left_prompt: None, - right_prompt: None, + highlight: true, + theme: None, + left_prompt: None, + right_prompt: None, - serve_addr: None, - user_agent: None, - save_shell_history: true, - sync_models_url: None, + serve_addr: None, + user_agent: None, + save_shell_history: true, + sync_models_url: None, - clients: vec![], + clients: vec![], - vault: Default::default(), + vault: Default::default(), - macro_flag: false, - info_flag: false, - agent_variables: None, + macro_flag: false, + info_flag: false, + agent_variables: None, - model: Default::default(), - functions: Default::default(), - mcp_registry: Default::default(), - working_mode: WorkingMode::Cmd, - last_message: None, + model: Default::default(), + functions: Default::default(), + mcp_registry: Default::default(), + working_mode: WorkingMode::Cmd, + last_message: None, - role: None, - session: None, - rag: None, - agent: None, - } - } + role: None, + session: None, + rag: None, + agent: None, + } + } } pub type GlobalConfig = Arc>; impl Config { - pub fn init_bare() -> Result { - let h = Handle::current(); - tokio::task::block_in_place(|| { - h.block_on(Self::init( - WorkingMode::Cmd, - true, - false, - None, - create_abort_signal(), - )) - }) - } + pub fn init_bare() -> Result { + let h = Handle::current(); + tokio::task::block_in_place(|| { + h.block_on(Self::init( + WorkingMode::Cmd, + true, + false, + None, + create_abort_signal(), + )) + }) + } - pub async fn init( - working_mode: WorkingMode, - info_flag: bool, - start_mcp_servers: bool, - log_path: Option, - abort_signal: AbortSignal, - ) -> Result { - let config_path = Self::config_file(); - let (mut config, content) = if !config_path.exists() { - match env::var(get_env_name("provider")) - .ok() - .or_else(|| env::var(get_env_name("platform")).ok()) - { - Some(v) => (Self::load_dynamic(&v)?, String::new()), - None => { - if *IS_STDOUT_TERMINAL { - create_config_file(&config_path).await?; - } - Self::load_from_file(&config_path)? - } - } - } else { - Self::load_from_file(&config_path)? - }; + pub async fn init( + working_mode: WorkingMode, + info_flag: bool, + start_mcp_servers: bool, + log_path: Option, + abort_signal: AbortSignal, + ) -> Result { + let config_path = Self::config_file(); + let (mut config, content) = if !config_path.exists() { + match env::var(get_env_name("provider")) + .ok() + .or_else(|| env::var(get_env_name("platform")).ok()) + { + Some(v) => (Self::load_dynamic(&v)?, String::new()), + None => { + if *IS_STDOUT_TERMINAL { + create_config_file(&config_path).await?; + } + Self::load_from_file(&config_path)? + } + } + } else { + Self::load_from_file(&config_path)? + }; - let setup = async |config: &mut Self| -> Result<()> { - let vault = Vault::init(config); + let setup = async |config: &mut Self| -> Result<()> { + let vault = Vault::init(config); - let (parsed_config, missing_secrets) = interpolate_secrets(&content, &vault); - if !missing_secrets.is_empty() && !info_flag { - debug!("Global config references secrets that are missing from the vault: {missing_secrets:?}"); - return Err(anyhow!(formatdoc!( + let (parsed_config, missing_secrets) = interpolate_secrets(&content, &vault); + if !missing_secrets.is_empty() && !info_flag { + debug!("Global config references secrets that are missing from the vault: {missing_secrets:?}"); + return Err(anyhow!(formatdoc!( " Global config file references secrets that are missing from the vault: {:?} Please add these secrets to the vault and try again.", missing_secrets ))); - } + } - if !parsed_config.is_empty() && !info_flag { - debug!("Global config is invalid once secrets are injected: {parsed_config}"); - let new_config = Self::load_from_str(&parsed_config).with_context(|| { - formatdoc!( + if !parsed_config.is_empty() && !info_flag { + debug!("Global config is invalid once secrets are injected: {parsed_config}"); + let new_config = Self::load_from_str(&parsed_config).with_context(|| { + formatdoc!( " Global config is invalid once secrets are injected. Double check the secret values and file syntax, then try again. " ) - })?; - *config = new_config.clone(); - } + })?; + *config = new_config.clone(); + } - config.working_mode = working_mode; - config.info_flag = info_flag; - config.vault = vault; + config.working_mode = working_mode; + config.info_flag = info_flag; + config.vault = vault; - Agent::install_builtin_agents()?; + Agent::install_builtin_agents()?; - config.load_envs(); + config.load_envs(); - if let Some(wrap) = config.wrap.clone() { - config.set_wrap(&wrap)?; - } + if let Some(wrap) = config.wrap.clone() { + config.set_wrap(&wrap)?; + } - config.load_functions()?; - config - .load_mcp_servers(log_path, start_mcp_servers, abort_signal) - .await?; + config.load_functions()?; + config + .load_mcp_servers(log_path, start_mcp_servers, abort_signal) + .await?; - config.setup_model()?; - config.setup_document_loaders(); - config.setup_user_agent(); - Ok(()) - }; - let ret = setup(&mut config).await; - if !info_flag { - ret?; - } - Ok(config) - } + config.setup_model()?; + config.setup_document_loaders(); + config.setup_user_agent(); + Ok(()) + }; + let ret = setup(&mut config).await; + if !info_flag { + ret?; + } + Ok(config) + } - pub fn config_dir() -> PathBuf { - if let Ok(v) = env::var(get_env_name("config_dir")) { - PathBuf::from(v) - } else if let Ok(v) = env::var("XDG_CONFIG_HOME") { - PathBuf::from(v).join(env!("CARGO_CRATE_NAME")) - } else { - let dir = dirs::config_dir().expect("No user's config directory"); - dir.join(env!("CARGO_CRATE_NAME")) - } - } + pub fn config_dir() -> PathBuf { + if let Ok(v) = env::var(get_env_name("config_dir")) { + PathBuf::from(v) + } else if let Ok(v) = env::var("XDG_CONFIG_HOME") { + PathBuf::from(v).join(env!("CARGO_CRATE_NAME")) + } else { + let dir = dirs::config_dir().expect("No user's config directory"); + dir.join(env!("CARGO_CRATE_NAME")) + } + } - pub fn local_path(name: &str) -> PathBuf { - Self::config_dir().join(name) - } + pub fn local_path(name: &str) -> PathBuf { + Self::config_dir().join(name) + } - pub fn cache_path() -> PathBuf { - let base_dir = dirs::cache_dir().unwrap_or_else(env::temp_dir); + pub fn cache_path() -> PathBuf { + let base_dir = dirs::cache_dir().unwrap_or_else(env::temp_dir); - base_dir.join(env!("CARGO_CRATE_NAME")) - } + base_dir.join(env!("CARGO_CRATE_NAME")) + } - pub fn log_path() -> PathBuf { - Config::cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME"))) - } + pub fn log_path() -> PathBuf { + Config::cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME"))) + } - pub fn config_file() -> PathBuf { - match env::var(get_env_name("config_file")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::local_path(CONFIG_FILE_NAME), - } - } + pub fn config_file() -> PathBuf { + match env::var(get_env_name("config_file")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(CONFIG_FILE_NAME), + } + } - pub fn vault_password_file(&self) -> PathBuf { - match &self.vault_password_file { - Some(path) => match path.exists() { - true => path.clone(), - false => gman::config::Config::local_provider_password_file(), - }, - None => gman::config::Config::local_provider_password_file(), - } - } + pub fn vault_password_file(&self) -> PathBuf { + match &self.vault_password_file { + Some(path) => match path.exists() { + true => path.clone(), + false => gman::config::Config::local_provider_password_file(), + }, + None => gman::config::Config::local_provider_password_file(), + } + } - pub fn roles_dir() -> PathBuf { - match env::var(get_env_name("roles_dir")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::local_path(ROLES_DIR_NAME), - } - } + pub fn roles_dir() -> PathBuf { + match env::var(get_env_name("roles_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(ROLES_DIR_NAME), + } + } - pub fn role_file(name: &str) -> PathBuf { - Self::roles_dir().join(format!("{name}.md")) - } + pub fn role_file(name: &str) -> PathBuf { + Self::roles_dir().join(format!("{name}.md")) + } - pub fn macros_dir() -> PathBuf { - match env::var(get_env_name("macros_dir")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::local_path(MACROS_DIR_NAME), - } - } + pub fn macros_dir() -> PathBuf { + match env::var(get_env_name("macros_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(MACROS_DIR_NAME), + } + } - pub fn macro_file(name: &str) -> PathBuf { - Self::macros_dir().join(format!("{name}.yaml")) - } + pub fn macro_file(name: &str) -> PathBuf { + Self::macros_dir().join(format!("{name}.yaml")) + } - pub fn env_file() -> PathBuf { - match env::var(get_env_name("env_file")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::local_path(ENV_FILE_NAME), - } - } + pub fn env_file() -> PathBuf { + match env::var(get_env_name("env_file")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(ENV_FILE_NAME), + } + } - pub fn messages_file(&self) -> PathBuf { - match &self.agent { - None => match env::var(get_env_name("messages_file")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::cache_path().join(MESSAGES_FILE_NAME), - }, - Some(agent) => Self::cache_path() - .join(AGENTS_DIR_NAME) - .join(agent.name()) - .join(MESSAGES_FILE_NAME), - } - } + pub fn messages_file(&self) -> PathBuf { + match &self.agent { + None => match env::var(get_env_name("messages_file")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::cache_path().join(MESSAGES_FILE_NAME), + }, + Some(agent) => Self::cache_path() + .join(AGENTS_DIR_NAME) + .join(agent.name()) + .join(MESSAGES_FILE_NAME), + } + } - pub fn sessions_dir(&self) -> PathBuf { - match &self.agent { - None => match env::var(get_env_name("sessions_dir")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::local_path(SESSIONS_DIR_NAME), - }, - Some(agent) => Self::agent_data_dir(agent.name()).join(SESSIONS_DIR_NAME), - } - } + pub fn sessions_dir(&self) -> PathBuf { + match &self.agent { + None => match env::var(get_env_name("sessions_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(SESSIONS_DIR_NAME), + }, + Some(agent) => Self::agent_data_dir(agent.name()).join(SESSIONS_DIR_NAME), + } + } - pub fn rags_dir() -> PathBuf { - match env::var(get_env_name("rags_dir")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::local_path(RAGS_DIR_NAME), - } - } + pub fn rags_dir() -> PathBuf { + match env::var(get_env_name("rags_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(RAGS_DIR_NAME), + } + } - pub fn functions_dir() -> PathBuf { - match env::var(get_env_name("functions_dir")) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::local_path(FUNCTIONS_DIR_NAME), - } - } + pub fn functions_dir() -> PathBuf { + match env::var(get_env_name("functions_dir")) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::local_path(FUNCTIONS_DIR_NAME), + } + } - pub fn functions_bin_dir() -> PathBuf { - Self::functions_dir().join(FUNCTIONS_BIN_DIR_NAME) - } + pub fn functions_bin_dir() -> PathBuf { + Self::functions_dir().join(FUNCTIONS_BIN_DIR_NAME) + } - pub fn mcp_config_file() -> PathBuf { - Self::functions_dir().join(MCP_FILE_NAME) - } + pub fn mcp_config_file() -> PathBuf { + Self::functions_dir().join(MCP_FILE_NAME) + } - pub fn global_tools_file() -> PathBuf { - Self::functions_dir().join(GLOBAL_TOOLS_FILE_NAME) - } + pub fn global_tools_file() -> PathBuf { + Self::functions_dir().join(GLOBAL_TOOLS_FILE_NAME) + } - pub fn global_tools_dir() -> PathBuf { - Self::functions_dir().join(GLOBAL_TOOLS_DIR_NAME) - } + pub fn global_tools_dir() -> PathBuf { + Self::functions_dir().join(GLOBAL_TOOLS_DIR_NAME) + } - pub fn session_file(&self, name: &str) -> PathBuf { - match name.split_once("/") { - Some((dir, name)) => self.sessions_dir().join(dir).join(format!("{name}.yaml")), - None => self.sessions_dir().join(format!("{name}.yaml")), - } - } + pub fn session_file(&self, name: &str) -> PathBuf { + match name.split_once("/") { + Some((dir, name)) => self.sessions_dir().join(dir).join(format!("{name}.yaml")), + None => self.sessions_dir().join(format!("{name}.yaml")), + } + } - pub fn rag_file(&self, name: &str) -> PathBuf { - match &self.agent { - Some(agent) => Self::agent_rag_file(agent.name(), name), - None => Self::rags_dir().join(format!("{name}.yaml")), - } - } + pub fn rag_file(&self, name: &str) -> PathBuf { + match &self.agent { + Some(agent) => Self::agent_rag_file(agent.name(), name), + None => Self::rags_dir().join(format!("{name}.yaml")), + } + } - pub fn agents_data_dir() -> PathBuf { - Self::local_path(AGENTS_DIR_NAME) - } + pub fn agents_data_dir() -> PathBuf { + Self::local_path(AGENTS_DIR_NAME) + } - pub fn agent_data_dir(name: &str) -> PathBuf { - match env::var(format!("{}_DATA_DIR", normalize_env_name(name))) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::agents_data_dir().join(name), - } - } + pub fn agent_data_dir(name: &str) -> PathBuf { + match env::var(format!("{}_DATA_DIR", normalize_env_name(name))) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::agents_data_dir().join(name), + } + } - pub fn agent_config_file(name: &str) -> PathBuf { - match env::var(format!("{}_CONFIG_FILE", normalize_env_name(name))) { - Ok(value) => PathBuf::from(value), - Err(_) => Self::agent_data_dir(name).join(CONFIG_FILE_NAME), - } - } + pub fn agent_config_file(name: &str) -> PathBuf { + match env::var(format!("{}_CONFIG_FILE", normalize_env_name(name))) { + Ok(value) => PathBuf::from(value), + Err(_) => Self::agent_data_dir(name).join(CONFIG_FILE_NAME), + } + } - pub fn agent_bin_dir(name: &str) -> PathBuf { - Self::agent_data_dir(name).join(FUNCTIONS_BIN_DIR_NAME) - } + pub fn agent_bin_dir(name: &str) -> PathBuf { + Self::agent_data_dir(name).join(FUNCTIONS_BIN_DIR_NAME) + } - pub fn agent_rag_file(agent_name: &str, rag_name: &str) -> PathBuf { - Self::agent_data_dir(agent_name).join(format!("{rag_name}.yaml")) - } + pub fn agent_rag_file(agent_name: &str, rag_name: &str) -> PathBuf { + Self::agent_data_dir(agent_name).join(format!("{rag_name}.yaml")) + } - pub fn agent_functions_file(name: &str) -> Result { - let allowed = ["tools.sh", "tools.py", "tools.js"]; + pub fn agent_functions_file(name: &str) -> Result { + let allowed = ["tools.sh", "tools.py", "tools.js"]; - for entry in read_dir(Self::agent_data_dir(name))? { - let entry = entry?; - if let Some(file) = entry.file_name().to_str() { - if allowed.contains(&file) { - return Ok(entry.path()); - } - } - } + for entry in read_dir(Self::agent_data_dir(name))? { + let entry = entry?; + if let Some(file) = entry.file_name().to_str() { + if allowed.contains(&file) { + return Ok(entry.path()); + } + } + } - Err(anyhow!( + Err(anyhow!( "No tools script found in agent functions directory" )) - } + } - pub fn models_override_file() -> PathBuf { - Self::local_path("models-override.yaml") - } + pub fn models_override_file() -> PathBuf { + Self::local_path("models-override.yaml") + } - pub fn state(&self) -> StateFlags { - let mut flags = StateFlags::empty(); - if let Some(session) = &self.session { - if session.is_empty() { - flags |= StateFlags::SESSION_EMPTY; - } else { - flags |= StateFlags::SESSION; - } - if session.role_name().is_some() { - flags |= StateFlags::ROLE; - } - } else if self.role.is_some() { - flags |= StateFlags::ROLE; - } - if self.agent.is_some() { - flags |= StateFlags::AGENT; - } - if self.rag.is_some() { - flags |= StateFlags::RAG; - } - flags - } + pub fn state(&self) -> StateFlags { + let mut flags = StateFlags::empty(); + if let Some(session) = &self.session { + if session.is_empty() { + flags |= StateFlags::SESSION_EMPTY; + } else { + flags |= StateFlags::SESSION; + } + if session.role_name().is_some() { + flags |= StateFlags::ROLE; + } + } else if self.role.is_some() { + flags |= StateFlags::ROLE; + } + if self.agent.is_some() { + flags |= StateFlags::AGENT; + } + if self.rag.is_some() { + flags |= StateFlags::RAG; + } + flags + } - pub fn serve_addr(&self) -> String { - self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into()) - } + pub fn serve_addr(&self) -> String { + self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into()) + } - pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option)> { - let log_level = env::var(get_env_name("log_level")) - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(match cfg!(debug_assertions) { - true => LevelFilter::Debug, - false => { - if is_serve { - LevelFilter::Off - } else { - LevelFilter::Info - } - } - }); - let log_path = match env::var(get_env_name("log_path")) { - Ok(v) => Some(PathBuf::from(v)), - Err(_) => match is_serve { - true => None, - false => Some(Config::log_path()), - }, - }; - Ok((log_level, log_path)) - } + pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option)> { + let log_level = env::var(get_env_name("log_level")) + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(match cfg!(debug_assertions) { + true => LevelFilter::Debug, + false => { + if is_serve { + LevelFilter::Off + } else { + LevelFilter::Info + } + } + }); + let log_path = match env::var(get_env_name("log_path")) { + Ok(v) => Some(PathBuf::from(v)), + Err(_) => match is_serve { + true => None, + false => Some(Config::log_path()), + }, + }; + Ok((log_level, log_path)) + } - pub fn edit_config(&self) -> Result<()> { - let config_path = Self::config_file(); - let editor = self.editor()?; - edit_file(&editor, &config_path)?; - println!( - "NOTE: Remember to restart {} if there are changes made to '{}'", - env!("CARGO_CRATE_NAME"), - config_path.display(), - ); - Ok(()) - } + pub fn edit_config(&self) -> Result<()> { + let config_path = Self::config_file(); + let editor = self.editor()?; + edit_file(&editor, &config_path)?; + println!( + "NOTE: Remember to restart {} if there are changes made to '{}'", + env!("CARGO_CRATE_NAME"), + config_path.display(), + ); + Ok(()) + } - pub fn current_model(&self) -> &Model { - if let Some(session) = self.session.as_ref() { - session.model() - } else if let Some(agent) = self.agent.as_ref() { - agent.model() - } else if let Some(role) = self.role.as_ref() { - role.model() - } else { - &self.model - } - } + pub fn current_model(&self) -> &Model { + if let Some(session) = self.session.as_ref() { + session.model() + } else if let Some(agent) = self.agent.as_ref() { + agent.model() + } else if let Some(role) = self.role.as_ref() { + role.model() + } else { + &self.model + } + } - pub fn role_like_mut(&mut self) -> Option<&mut dyn RoleLike> { - if let Some(session) = self.session.as_mut() { - Some(session) - } else if let Some(agent) = self.agent.as_mut() { - Some(agent) - } else if let Some(role) = self.role.as_mut() { - Some(role) - } else { - None - } - } + pub fn role_like_mut(&mut self) -> Option<&mut dyn RoleLike> { + if let Some(session) = self.session.as_mut() { + Some(session) + } else if let Some(agent) = self.agent.as_mut() { + Some(agent) + } else if let Some(role) = self.role.as_mut() { + Some(role) + } else { + None + } + } - pub fn extract_role(&self) -> Role { - if let Some(session) = self.session.as_ref() { - session.to_role() - } else if let Some(agent) = self.agent.as_ref() { - agent.to_role() - } else if let Some(role) = self.role.as_ref() { - role.clone() - } else { - let mut role = Role::default(); - role.batch_set( - &self.model, - self.temperature, - self.top_p, - self.use_tools.clone(), - self.use_mcp_servers.clone(), - ); - role - } - } + pub fn extract_role(&self) -> Role { + if let Some(session) = self.session.as_ref() { + session.to_role() + } else if let Some(agent) = self.agent.as_ref() { + agent.to_role() + } else if let Some(role) = self.role.as_ref() { + role.clone() + } else { + let mut role = Role::default(); + role.batch_set( + &self.model, + self.temperature, + self.top_p, + self.use_tools.clone(), + self.use_mcp_servers.clone(), + ); + role + } + } - pub fn info(&self) -> Result { - if let Some(agent) = &self.agent { - let output = agent.export()?; - if let Some(session) = &self.session { - let session = session - .export()? - .split('\n') - .map(|v| format!(" {v}")) - .collect::>() - .join("\n"); - Ok(format!("{output}session:\n{session}")) - } else { - Ok(output) - } - } else if let Some(session) = &self.session { - session.export() - } else if let Some(role) = &self.role { - Ok(role.export()) - } else if let Some(rag) = &self.rag { - rag.export() - } else { - self.sysinfo() - } - } + pub fn info(&self) -> Result { + if let Some(agent) = &self.agent { + let output = agent.export()?; + if let Some(session) = &self.session { + let session = session + .export()? + .split('\n') + .map(|v| format!(" {v}")) + .collect::>() + .join("\n"); + Ok(format!("{output}session:\n{session}")) + } else { + Ok(output) + } + } else if let Some(session) = &self.session { + session.export() + } else if let Some(role) = &self.role { + Ok(role.export()) + } else if let Some(rag) = &self.rag { + rag.export() + } else { + self.sysinfo() + } + } - pub fn sysinfo(&self) -> Result { - let display_path = |path: &Path| path.display().to_string(); - let wrap = self - .wrap - .clone() - .map_or_else(|| String::from("no"), |v| v.to_string()); - let (rag_reranker_model, rag_top_k) = match &self.rag { - Some(rag) => rag.get_config(), - None => (self.rag_reranker_model.clone(), self.rag_top_k), - }; - let role = self.extract_role(); - let mut items = vec![ - ("model", role.model().id()), - ("temperature", format_option_value(&role.temperature())), - ("top_p", format_option_value(&role.top_p())), - ("use_tools", format_option_value(&role.use_tools())), - ( - "use_mcp_servers", - format_option_value(&role.use_mcp_servers()), - ), - ( - "max_output_tokens", - role.model() - .max_tokens_param() - .map(|v| format!("{v} (current model)")) - .unwrap_or_else(|| "null".into()), - ), - ("save_session", format_option_value(&self.save_session)), - ("compress_threshold", self.compress_threshold.to_string()), - ( - "rag_reranker_model", - format_option_value(&rag_reranker_model), - ), - ("rag_top_k", rag_top_k.to_string()), - ("dry_run", self.dry_run.to_string()), - ("function_calling", self.function_calling.to_string()), - ("mcp_servers", self.mcp_servers.to_string()), - ("stream", self.stream.to_string()), - ("save", self.save.to_string()), - ("keybindings", self.keybindings.clone()), - ("wrap", wrap), - ("wrap_code", self.wrap_code.to_string()), - ("highlight", self.highlight.to_string()), - ("theme", format_option_value(&self.theme)), - ("config_file", display_path(&Self::config_file())), - ("env_file", display_path(&Self::env_file())), - ("roles_dir", display_path(&Self::roles_dir())), - ("sessions_dir", display_path(&self.sessions_dir())), - ("rags_dir", display_path(&Self::rags_dir())), - ("macros_dir", display_path(&Self::macros_dir())), - ("functions_dir", display_path(&Self::functions_dir())), - ("messages_file", display_path(&self.messages_file())), - ( - "vault_password_file", - display_path(&self.vault_password_file()), - ), - ]; - if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) { - items.push(("log_path", display_path(&log_path))); - } - let output = items - .iter() - .map(|(name, value)| format!("{name:<24}{value}\n")) - .collect::>() - .join(""); - Ok(output) - } + pub fn sysinfo(&self) -> Result { + let display_path = |path: &Path| path.display().to_string(); + let wrap = self + .wrap + .clone() + .map_or_else(|| String::from("no"), |v| v.to_string()); + let (rag_reranker_model, rag_top_k) = match &self.rag { + Some(rag) => rag.get_config(), + None => (self.rag_reranker_model.clone(), self.rag_top_k), + }; + let role = self.extract_role(); + let mut items = vec![ + ("model", role.model().id()), + ("temperature", format_option_value(&role.temperature())), + ("top_p", format_option_value(&role.top_p())), + ("use_tools", format_option_value(&role.use_tools())), + ( + "use_mcp_servers", + format_option_value(&role.use_mcp_servers()), + ), + ( + "max_output_tokens", + role.model() + .max_tokens_param() + .map(|v| format!("{v} (current model)")) + .unwrap_or_else(|| "null".into()), + ), + ("save_session", format_option_value(&self.save_session)), + ("compress_threshold", self.compress_threshold.to_string()), + ( + "rag_reranker_model", + format_option_value(&rag_reranker_model), + ), + ("rag_top_k", rag_top_k.to_string()), + ("dry_run", self.dry_run.to_string()), + ("function_calling", self.function_calling.to_string()), + ("mcp_servers", self.mcp_servers.to_string()), + ("stream", self.stream.to_string()), + ("save", self.save.to_string()), + ("keybindings", self.keybindings.clone()), + ("wrap", wrap), + ("wrap_code", self.wrap_code.to_string()), + ("highlight", self.highlight.to_string()), + ("theme", format_option_value(&self.theme)), + ("config_file", display_path(&Self::config_file())), + ("env_file", display_path(&Self::env_file())), + ("roles_dir", display_path(&Self::roles_dir())), + ("sessions_dir", display_path(&self.sessions_dir())), + ("rags_dir", display_path(&Self::rags_dir())), + ("macros_dir", display_path(&Self::macros_dir())), + ("functions_dir", display_path(&Self::functions_dir())), + ("messages_file", display_path(&self.messages_file())), + ( + "vault_password_file", + display_path(&self.vault_password_file()), + ), + ]; + if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) { + items.push(("log_path", display_path(&log_path))); + } + let output = items + .iter() + .map(|(name, value)| format!("{name:<24}{value}\n")) + .collect::>() + .join(""); + Ok(output) + } - pub async fn update( - config: &GlobalConfig, - data: &str, - abort_signal: AbortSignal, - ) -> Result<()> { - let parts: Vec<&str> = data.split_whitespace().collect(); - if parts.len() != 2 { - bail!("Usage: .set . If value is null, unset key."); - } - let key = parts[0]; - let value = parts[1]; - match key { - "temperature" => { - let value = parse_value(value)?; - config.write().set_temperature(value); - } - "top_p" => { - let value = parse_value(value)?; - config.write().set_top_p(value); - } - "use_tools" => { - let value = parse_value(value)?; - config.write().set_use_tools(value); - } - "use_mcp_servers" => { - let value: Option = parse_value(value)?; - if let Some(servers) = value.as_ref() { - if let Some(registry) = &config.read().mcp_registry { - if registry.list_configured_servers().is_empty() { - bail!("No MCP servers are configured. Please configure MCP servers first before setting 'use_mcp_servers'."); - } + pub async fn update( + config: &GlobalConfig, + data: &str, + abort_signal: AbortSignal, + ) -> Result<()> { + let parts: Vec<&str> = data.split_whitespace().collect(); + if parts.len() != 2 { + bail!("Usage: .set . If value is null, unset key."); + } + let key = parts[0]; + let value = parts[1]; + match key { + "temperature" => { + let value = parse_value(value)?; + config.write().set_temperature(value); + } + "top_p" => { + let value = parse_value(value)?; + config.write().set_top_p(value); + } + "use_tools" => { + let value = parse_value(value)?; + config.write().set_use_tools(value); + } + "use_mcp_servers" => { + let value: Option = parse_value(value)?; + if let Some(servers) = value.as_ref() { + if let Some(registry) = &config.read().mcp_registry { + if registry.list_configured_servers().is_empty() { + bail!("No MCP servers are configured. Please configure MCP servers first before setting 'use_mcp_servers'."); + } - if !servers.split(',').all(|s| { - registry - .list_configured_servers() - .contains(&s.trim().to_string()) - || s == "all" - }) { - bail!("Some of the specified MCP servers in 'use_mcp_servers' are configured. Please check your MCP server configuration."); - } - } - } - config.write().set_use_mcp_servers(value.clone()); - if config.read().mcp_servers { - config.write().functions.clear_mcp_meta_functions(); - let registry = config - .write() - .mcp_registry - .take() - .expect("MCP registry should be initialized"); - let new_mcp_registry = - McpRegistry::reinit(registry, value, abort_signal.clone()).await?; + if !servers.split(',').all(|s| { + registry + .list_configured_servers() + .contains(&s.trim().to_string()) + || s == "all" + }) { + bail!("Some of the specified MCP servers in 'use_mcp_servers' are configured. Please check your MCP server configuration."); + } + } + } + config.write().set_use_mcp_servers(value.clone()); + if config.read().mcp_servers { + config.write().functions.clear_mcp_meta_functions(); + let registry = config + .write() + .mcp_registry + .take() + .expect("MCP registry should be initialized"); + let new_mcp_registry = + McpRegistry::reinit(registry, value, abort_signal.clone()).await?; - if !new_mcp_registry.is_empty() { - config - .write() - .functions - .append_mcp_meta_functions(new_mcp_registry.list_started_servers()); - } + if !new_mcp_registry.is_empty() { + config + .write() + .functions + .append_mcp_meta_functions(new_mcp_registry.list_started_servers()); + } - config.write().mcp_registry = Some(new_mcp_registry); - } - } - "max_output_tokens" => { - let value = parse_value(value)?; - config.write().set_max_output_tokens(value); - } - "save_session" => { - let value = parse_value(value)?; - config.write().set_save_session(value); - } - "compress_threshold" => { - let value = parse_value(value)?; - config.write().set_compress_threshold(value); - } - "rag_reranker_model" => { - let value = parse_value(value)?; - Self::set_rag_reranker_model(config, value)?; - } - "rag_top_k" => { - let value = value.parse().with_context(|| "Invalid value")?; - Self::set_rag_top_k(config, value)?; - } - "dry_run" => { - let value = value.parse().with_context(|| "Invalid value")?; - config.write().dry_run = value; - } - "function_calling" => { - let value = value.parse().with_context(|| "Invalid value")?; - if value && config.write().functions.is_empty() { - bail!("Function calling cannot be enabled because no functions are installed.") - } - config.write().function_calling = value; - } - "mcp_servers" => { - let value = value.parse().with_context(|| "Invalid value")?; - config.write().functions.clear_mcp_meta_functions(); + config.write().mcp_registry = Some(new_mcp_registry); + } + } + "max_output_tokens" => { + let value = parse_value(value)?; + config.write().set_max_output_tokens(value); + } + "save_session" => { + let value = parse_value(value)?; + config.write().set_save_session(value); + } + "compress_threshold" => { + let value = parse_value(value)?; + config.write().set_compress_threshold(value); + } + "rag_reranker_model" => { + let value = parse_value(value)?; + Self::set_rag_reranker_model(config, value)?; + } + "rag_top_k" => { + let value = value.parse().with_context(|| "Invalid value")?; + Self::set_rag_top_k(config, value)?; + } + "dry_run" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().dry_run = value; + } + "function_calling" => { + let value = value.parse().with_context(|| "Invalid value")?; + if value && config.write().functions.is_empty() { + bail!("Function calling cannot be enabled because no functions are installed.") + } + config.write().function_calling = value; + } + "mcp_servers" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().functions.clear_mcp_meta_functions(); - let registry = config - .write() - .mcp_registry - .take() - .expect("MCP registry should be initialized"); - let use_mcp_servers = if value { - config.read().use_mcp_servers.clone() - } else { - None - }; - let new_registry = - McpRegistry::reinit(registry, use_mcp_servers, abort_signal.clone()).await?; - if !new_registry.is_empty() && value { - config - .write() - .functions - .append_mcp_meta_functions(new_registry.list_started_servers()); - } - config.write().mcp_registry = Some(new_registry); - config.write().mcp_servers = value; - } - "stream" => { - let value = value.parse().with_context(|| "Invalid value")?; - config.write().stream = value; - } - "save" => { - let value = value.parse().with_context(|| "Invalid value")?; - config.write().save = value; - } - "highlight" => { - let value = value.parse().with_context(|| "Invalid value")?; - config.write().highlight = value; - } - _ => bail!("Unknown key '{key}'"), - } - Ok(()) - } + let registry = config + .write() + .mcp_registry + .take() + .expect("MCP registry should be initialized"); + let use_mcp_servers = if value { + config.read().use_mcp_servers.clone() + } else { + None + }; + let new_registry = + McpRegistry::reinit(registry, use_mcp_servers, abort_signal.clone()).await?; + if !new_registry.is_empty() && value { + config + .write() + .functions + .append_mcp_meta_functions(new_registry.list_started_servers()); + } + config.write().mcp_registry = Some(new_registry); + config.write().mcp_servers = value; + } + "stream" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().stream = value; + } + "save" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().save = value; + } + "highlight" => { + let value = value.parse().with_context(|| "Invalid value")?; + config.write().highlight = value; + } + _ => bail!("Unknown key '{key}'"), + } + Ok(()) + } - pub fn delete(config: &GlobalConfig, kind: &str) -> Result<()> { - let (dir, file_ext) = match kind { - "role" => (Self::roles_dir(), Some(".md")), - "session" => (config.read().sessions_dir(), Some(".yaml")), - "rag" => (Self::rags_dir(), Some(".yaml")), - "macro" => (Self::macros_dir(), Some(".yaml")), - "agent-data" => (Self::agents_data_dir(), None), - _ => bail!("Unknown kind '{kind}'"), - }; - let names = match read_dir(&dir) { - Ok(rd) => { - let mut names = vec![]; - for entry in rd.flatten() { - let name = entry.file_name(); - match file_ext { - Some(file_ext) => { - if let Some(name) = name.to_string_lossy().strip_suffix(file_ext) { - names.push(name.to_string()); - } - } - None => { - if entry.path().is_dir() { - names.push(name.to_string_lossy().to_string()); - } - } - } - } - names.sort_unstable(); - names - } - Err(_) => vec![], - }; + pub fn delete(config: &GlobalConfig, kind: &str) -> Result<()> { + let (dir, file_ext) = match kind { + "role" => (Self::roles_dir(), Some(".md")), + "session" => (config.read().sessions_dir(), Some(".yaml")), + "rag" => (Self::rags_dir(), Some(".yaml")), + "macro" => (Self::macros_dir(), Some(".yaml")), + "agent-data" => (Self::agents_data_dir(), None), + _ => bail!("Unknown kind '{kind}'"), + }; + let names = match read_dir(&dir) { + Ok(rd) => { + let mut names = vec![]; + for entry in rd.flatten() { + let name = entry.file_name(); + match file_ext { + Some(file_ext) => { + if let Some(name) = name.to_string_lossy().strip_suffix(file_ext) { + names.push(name.to_string()); + } + } + None => { + if entry.path().is_dir() { + names.push(name.to_string_lossy().to_string()); + } + } + } + } + names.sort_unstable(); + names + } + Err(_) => vec![], + }; - if names.is_empty() { - bail!("No {kind} to delete") - } + if names.is_empty() { + bail!("No {kind} to delete") + } - let select_names = MultiSelect::new(&format!("Select {kind} to delete:"), names) - .with_validator(|list: &[ListOption<&String>]| { - if list.is_empty() { - Ok(Validation::Invalid( - "At least one item must be selected".into(), - )) - } else { - Ok(Validation::Valid) - } - }) - .prompt()?; + let select_names = MultiSelect::new(&format!("Select {kind} to delete:"), names) + .with_validator(|list: &[ListOption<&String>]| { + if list.is_empty() { + Ok(Validation::Invalid( + "At least one item must be selected".into(), + )) + } else { + Ok(Validation::Valid) + } + }) + .prompt()?; - for name in select_names { - match file_ext { - Some(ext) => { - let path = dir.join(format!("{name}{ext}")); - remove_file(&path).with_context(|| { - format!("Failed to delete {kind} at '{}'", path.display()) - })?; - } - None => { - let path = dir.join(name); - remove_dir_all(&path).with_context(|| { - format!("Failed to delete {kind} at '{}'", path.display()) - })?; - } - } - } - println!("✓ Successfully deleted {kind}."); - Ok(()) - } + for name in select_names { + match file_ext { + Some(ext) => { + let path = dir.join(format!("{name}{ext}")); + remove_file(&path).with_context(|| { + format!("Failed to delete {kind} at '{}'", path.display()) + })?; + } + None => { + let path = dir.join(name); + remove_dir_all(&path).with_context(|| { + format!("Failed to delete {kind} at '{}'", path.display()) + })?; + } + } + } + println!("✓ Successfully deleted {kind}."); + Ok(()) + } - pub fn set_temperature(&mut self, value: Option) { - match self.role_like_mut() { - Some(role_like) => role_like.set_temperature(value), - None => self.temperature = value, - } - } + pub fn set_temperature(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_temperature(value), + None => self.temperature = value, + } + } - pub fn set_top_p(&mut self, value: Option) { - match self.role_like_mut() { - Some(role_like) => role_like.set_top_p(value), - None => self.top_p = value, - } - } + pub fn set_top_p(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_top_p(value), + None => self.top_p = value, + } + } - pub fn set_use_tools(&mut self, value: Option) { - match self.role_like_mut() { - Some(role_like) => role_like.set_use_tools(value), - None => self.use_tools = value, - } - } + pub fn set_use_tools(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_use_tools(value), + None => self.use_tools = value, + } + } - pub fn set_use_mcp_servers(&mut self, value: Option) { - match self.role_like_mut() { - Some(role_like) => role_like.set_use_mcp_servers(value), - None => self.use_mcp_servers = value, - } - } + pub fn set_use_mcp_servers(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => role_like.set_use_mcp_servers(value), + None => self.use_mcp_servers = value, + } + } - pub fn set_save_session(&mut self, value: Option) { - if let Some(session) = self.session.as_mut() { - session.set_save_session(value); - } else { - self.save_session = value; - } - } + pub fn set_save_session(&mut self, value: Option) { + if let Some(session) = self.session.as_mut() { + session.set_save_session(value); + } else { + self.save_session = value; + } + } - pub fn set_compress_threshold(&mut self, value: Option) { - if let Some(session) = self.session.as_mut() { - session.set_compress_threshold(value); - } else { - self.compress_threshold = value.unwrap_or_default(); - } - } + pub fn set_compress_threshold(&mut self, value: Option) { + if let Some(session) = self.session.as_mut() { + session.set_compress_threshold(value); + } else { + self.compress_threshold = value.unwrap_or_default(); + } + } - pub fn set_rag_reranker_model(config: &GlobalConfig, value: Option) -> Result<()> { - if let Some(id) = &value { - Model::retrieve_model(&config.read(), id, ModelType::Reranker)?; - } - let has_rag = config.read().rag.is_some(); - match has_rag { - true => update_rag(config, |rag| { - rag.set_reranker_model(value)?; - Ok(()) - })?, - false => config.write().rag_reranker_model = value, - } - Ok(()) - } + pub fn set_rag_reranker_model(config: &GlobalConfig, value: Option) -> Result<()> { + if let Some(id) = &value { + Model::retrieve_model(&config.read(), id, ModelType::Reranker)?; + } + let has_rag = config.read().rag.is_some(); + match has_rag { + true => update_rag(config, |rag| { + rag.set_reranker_model(value)?; + Ok(()) + })?, + false => config.write().rag_reranker_model = value, + } + Ok(()) + } - pub fn set_rag_top_k(config: &GlobalConfig, value: usize) -> Result<()> { - let has_rag = config.read().rag.is_some(); - match has_rag { - true => update_rag(config, |rag| { - rag.set_top_k(value)?; - Ok(()) - })?, - false => config.write().rag_top_k = value, - } - Ok(()) - } + pub fn set_rag_top_k(config: &GlobalConfig, value: usize) -> Result<()> { + let has_rag = config.read().rag.is_some(); + match has_rag { + true => update_rag(config, |rag| { + rag.set_top_k(value)?; + Ok(()) + })?, + false => config.write().rag_top_k = value, + } + Ok(()) + } - pub fn set_wrap(&mut self, value: &str) -> Result<()> { - if value == "no" { - self.wrap = None; - } else if value == "auto" { - self.wrap = Some(value.into()); - } else { - value - .parse::() - .map_err(|_| anyhow!("Invalid wrap value"))?; - self.wrap = Some(value.into()) - } - Ok(()) - } + pub fn set_wrap(&mut self, value: &str) -> Result<()> { + if value == "no" { + self.wrap = None; + } else if value == "auto" { + self.wrap = Some(value.into()); + } else { + value + .parse::() + .map_err(|_| anyhow!("Invalid wrap value"))?; + self.wrap = Some(value.into()) + } + Ok(()) + } - pub fn set_max_output_tokens(&mut self, value: Option) { - match self.role_like_mut() { - Some(role_like) => { - let mut model = role_like.model().clone(); - model.set_max_tokens(value, true); - role_like.set_model(model); - } - None => { - self.model.set_max_tokens(value, true); - } - }; - } + pub fn set_max_output_tokens(&mut self, value: Option) { + match self.role_like_mut() { + Some(role_like) => { + let mut model = role_like.model().clone(); + model.set_max_tokens(value, true); + role_like.set_model(model); + } + None => { + self.model.set_max_tokens(value, true); + } + }; + } - pub fn set_model(&mut self, model_id: &str) -> Result<()> { - let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; - match self.role_like_mut() { - Some(role_like) => role_like.set_model(model), - None => { - self.model = model; - } - } - Ok(()) - } + pub fn set_model(&mut self, model_id: &str) -> Result<()> { + let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; + match self.role_like_mut() { + Some(role_like) => role_like.set_model(model), + None => { + self.model = model; + } + } + Ok(()) + } - pub fn use_prompt(&mut self, prompt: &str) -> Result<()> { - let mut role = Role::new(TEMP_ROLE_NAME, prompt); - role.set_model(self.current_model().clone()); - self.use_role_obj(role) - } + pub fn use_prompt(&mut self, prompt: &str) -> Result<()> { + let mut role = Role::new(TEMP_ROLE_NAME, prompt); + role.set_model(self.current_model().clone()); + self.use_role_obj(role) + } - pub async fn use_role_safely( - config: &GlobalConfig, - name: &str, - abort_signal: AbortSignal, - ) -> Result<()> { - let mut cfg = { - let mut guard = config.write(); - take(&mut *guard) - }; + pub async fn use_role_safely( + config: &GlobalConfig, + name: &str, + abort_signal: AbortSignal, + ) -> Result<()> { + let mut cfg = { + let mut guard = config.write(); + take(&mut *guard) + }; - cfg.use_role(name, abort_signal.clone()).await?; + cfg.use_role(name, abort_signal.clone()).await?; - { - let mut guard = config.write(); - *guard = cfg; - } + { + let mut guard = config.write(); + *guard = cfg; + } - Ok(()) - } + Ok(()) + } - pub async fn use_role(&mut self, name: &str, abort_signal: AbortSignal) -> Result<()> { - let role = self.retrieve_role(name)?; - let mcp_servers = if self.mcp_servers { - role.use_mcp_servers() - } else { - eprintln!( - "{}", - formatdoc!( + pub async fn use_role(&mut self, name: &str, abort_signal: AbortSignal) -> Result<()> { + let role = self.retrieve_role(name)?; + let mcp_servers = if self.mcp_servers { + role.use_mcp_servers() + } else { + eprintln!( + "{}", + formatdoc!( " This role uses MCP servers, but MCP support is disabled. To enable it, exit the role and set 'mcp_servers: true', then try again " ) - ); - None - }; - self.functions.clear_mcp_meta_functions(); - let registry = self - .mcp_registry - .take() - .with_context(|| "MCP registry should be populated")?; - let new_mcp_registry = - McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?; + ); + None + }; + self.functions.clear_mcp_meta_functions(); + let registry = self + .mcp_registry + .take() + .with_context(|| "MCP registry should be populated")?; + let new_mcp_registry = + McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?; - if !new_mcp_registry.is_empty() { - self.functions - .append_mcp_meta_functions(new_mcp_registry.list_started_servers()); - } + if !new_mcp_registry.is_empty() { + self.functions + .append_mcp_meta_functions(new_mcp_registry.list_started_servers()); + } - self.mcp_registry = Some(new_mcp_registry); - self.use_role_obj(role) - } + self.mcp_registry = Some(new_mcp_registry); + self.use_role_obj(role) + } - pub fn use_role_obj(&mut self, role: Role) -> Result<()> { - if self.agent.is_some() { - bail!("Cannot perform this operation because you are using a agent") - } - if let Some(session) = self.session.as_mut() { - session.guard_empty()?; - session.set_role(role); - } else { - self.role = Some(role); - } - Ok(()) - } + pub fn use_role_obj(&mut self, role: Role) -> Result<()> { + if self.agent.is_some() { + bail!("Cannot perform this operation because you are using a agent") + } + if let Some(session) = self.session.as_mut() { + session.guard_empty()?; + session.set_role(role); + } else { + self.role = Some(role); + } + Ok(()) + } - pub fn role_info(&self) -> Result { - if let Some(session) = &self.session { - if session.role_name().is_some() { - let role = session.to_role(); - Ok(role.export()) - } else { - bail!("No session role") - } - } else if let Some(role) = &self.role { - Ok(role.export()) - } else { - bail!("No role") - } - } + pub fn role_info(&self) -> Result { + if let Some(session) = &self.session { + if session.role_name().is_some() { + let role = session.to_role(); + Ok(role.export()) + } else { + bail!("No session role") + } + } else if let Some(role) = &self.role { + Ok(role.export()) + } else { + bail!("No role") + } + } - pub fn exit_role(&mut self) -> Result<()> { - if let Some(session) = self.session.as_mut() { - session.guard_empty()?; - session.clear_role(); - } else if self.role.is_some() { - self.role = None; - } - Ok(()) - } + pub fn exit_role(&mut self) -> Result<()> { + if let Some(session) = self.session.as_mut() { + session.guard_empty()?; + session.clear_role(); + } else if self.role.is_some() { + self.role = None; + } + Ok(()) + } - pub fn retrieve_role(&self, name: &str) -> Result { - let names = Self::list_roles(false); - let mut role = if names.contains(&name.to_string()) { - let path = Self::role_file(name); - let content = read_to_string(&path)?; - Role::new(name, &content) - } else { - Role::builtin(name)? - }; - let current_model = self.current_model().clone(); - match role.model_id() { - Some(model_id) => { - if current_model.id() != model_id { - let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; - role.set_model(model); - } else { - role.set_model(current_model); - } - } - None => { - role.set_model(current_model); - if role.temperature().is_none() { - role.set_temperature(self.temperature); - } - if role.top_p().is_none() { - role.set_top_p(self.top_p); - } - } - } - Ok(role) - } + pub fn retrieve_role(&self, name: &str) -> Result { + let names = Self::list_roles(false); + let mut role = if names.contains(&name.to_string()) { + let path = Self::role_file(name); + let content = read_to_string(&path)?; + Role::new(name, &content) + } else { + Role::builtin(name)? + }; + let current_model = self.current_model().clone(); + match role.model_id() { + Some(model_id) => { + if current_model.id() != model_id { + let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; + role.set_model(model); + } else { + role.set_model(current_model); + } + } + None => { + role.set_model(current_model); + if role.temperature().is_none() { + role.set_temperature(self.temperature); + } + if role.top_p().is_none() { + role.set_top_p(self.top_p); + } + } + } + Ok(role) + } - pub fn new_role(&mut self, name: &str) -> Result<()> { - if self.macro_flag { - bail!("No role"); - } - let ans = Confirm::new("Create a new role?") - .with_default(true) - .prompt()?; - if ans { - self.upsert_role(name)?; - } else { - bail!("No role"); - } - Ok(()) - } + pub fn new_role(&mut self, name: &str) -> Result<()> { + if self.macro_flag { + bail!("No role"); + } + let ans = Confirm::new("Create a new role?") + .with_default(true) + .prompt()?; + if ans { + self.upsert_role(name)?; + } else { + bail!("No role"); + } + Ok(()) + } - pub async fn edit_role(&mut self, abort_signal: AbortSignal) -> Result<()> { - let role_name; - if let Some(session) = self.session.as_ref() { - if let Some(name) = session.role_name().map(|v| v.to_string()) { - if session.is_empty() { - role_name = Some(name); - } else { - bail!("Cannot perform this operation because you are in a non-empty session") - } - } else { - bail!("No role") - } - } else { - role_name = self.role.as_ref().map(|v| v.name().to_string()); - } - let name = role_name.ok_or_else(|| anyhow!("No role"))?; - self.upsert_role(&name)?; - self.use_role(&name, abort_signal.clone()).await - } + pub async fn edit_role(&mut self, abort_signal: AbortSignal) -> Result<()> { + let role_name; + if let Some(session) = self.session.as_ref() { + if let Some(name) = session.role_name().map(|v| v.to_string()) { + if session.is_empty() { + role_name = Some(name); + } else { + bail!("Cannot perform this operation because you are in a non-empty session") + } + } else { + bail!("No role") + } + } else { + role_name = self.role.as_ref().map(|v| v.name().to_string()); + } + let name = role_name.ok_or_else(|| anyhow!("No role"))?; + self.upsert_role(&name)?; + self.use_role(&name, abort_signal.clone()).await + } - pub fn upsert_role(&mut self, name: &str) -> Result<()> { - let role_path = Self::role_file(name); - ensure_parent_exists(&role_path)?; - let editor = self.editor()?; - edit_file(&editor, &role_path)?; - if self.working_mode.is_repl() { - println!("✓ Saved the role to '{}'.", role_path.display()); - } - Ok(()) - } + pub fn upsert_role(&mut self, name: &str) -> Result<()> { + let role_path = Self::role_file(name); + ensure_parent_exists(&role_path)?; + let editor = self.editor()?; + edit_file(&editor, &role_path)?; + if self.working_mode.is_repl() { + println!("✓ Saved the role to '{}'.", role_path.display()); + } + Ok(()) + } - pub fn save_role(&mut self, name: Option<&str>) -> Result<()> { - let mut role_name = match &self.role { - Some(role) => { - if role.has_args() { - bail!("Unable to save the role with arguments (whose name contains '#')") - } - match name { - Some(v) => v.to_string(), - None => role.name().to_string(), - } - } - None => bail!("No role"), - }; - if role_name == TEMP_ROLE_NAME { - role_name = Text::new("Role name:") - .with_validator(|input: &str| { - let input = input.trim(); - if input.is_empty() { - Ok(Validation::Invalid("This name is required".into())) - } else if input == TEMP_ROLE_NAME { - Ok(Validation::Invalid("This name is reserved".into())) - } else { - Ok(Validation::Valid) - } - }) - .prompt()?; - } - let role_path = Self::role_file(&role_name); - if let Some(role) = self.role.as_mut() { - role.save(&role_name, &role_path, self.working_mode.is_repl())?; - } + pub fn save_role(&mut self, name: Option<&str>) -> Result<()> { + let mut role_name = match &self.role { + Some(role) => { + if role.has_args() { + bail!("Unable to save the role with arguments (whose name contains '#')") + } + match name { + Some(v) => v.to_string(), + None => role.name().to_string(), + } + } + None => bail!("No role"), + }; + if role_name == TEMP_ROLE_NAME { + role_name = Text::new("Role name:") + .with_validator(|input: &str| { + let input = input.trim(); + if input.is_empty() { + Ok(Validation::Invalid("This name is required".into())) + } else if input == TEMP_ROLE_NAME { + Ok(Validation::Invalid("This name is reserved".into())) + } else { + Ok(Validation::Valid) + } + }) + .prompt()?; + } + let role_path = Self::role_file(&role_name); + if let Some(role) = self.role.as_mut() { + role.save(&role_name, &role_path, self.working_mode.is_repl())?; + } - Ok(()) - } + Ok(()) + } - pub fn all_roles() -> Vec { - let mut roles: HashMap = Role::list_builtin_roles() - .iter() - .map(|v| (v.name().to_string(), v.clone())) - .collect(); - let names = Self::list_roles(false); - for name in names { - if let Ok(content) = read_to_string(Self::role_file(&name)) { - let role = Role::new(&name, &content); - roles.insert(name, role); - } - } - let mut roles: Vec<_> = roles.into_values().collect(); - roles.sort_unstable_by(|a, b| a.name().cmp(b.name())); - roles - } + pub fn all_roles() -> Vec { + let mut roles: HashMap = Role::list_builtin_roles() + .iter() + .map(|v| (v.name().to_string(), v.clone())) + .collect(); + let names = Self::list_roles(false); + for name in names { + if let Ok(content) = read_to_string(Self::role_file(&name)) { + let role = Role::new(&name, &content); + roles.insert(name, role); + } + } + let mut roles: Vec<_> = roles.into_values().collect(); + roles.sort_unstable_by(|a, b| a.name().cmp(b.name())); + roles + } - pub fn list_roles(with_builtin: bool) -> Vec { - let mut names = HashSet::new(); - if let Ok(rd) = read_dir(Self::roles_dir()) { - for entry in rd.flatten() { - if let Some(name) = entry - .file_name() - .to_str() - .and_then(|v| v.strip_suffix(".md")) - { - names.insert(name.to_string()); - } - } - } - if with_builtin { - names.extend(Role::list_builtin_role_names()); - } - let mut names: Vec<_> = names.into_iter().collect(); - names.sort_unstable(); - names - } + pub fn list_roles(with_builtin: bool) -> Vec { + let mut names = HashSet::new(); + if let Ok(rd) = read_dir(Self::roles_dir()) { + for entry in rd.flatten() { + if let Some(name) = entry + .file_name() + .to_str() + .and_then(|v| v.strip_suffix(".md")) + { + names.insert(name.to_string()); + } + } + } + if with_builtin { + names.extend(Role::list_builtin_role_names()); + } + let mut names: Vec<_> = names.into_iter().collect(); + names.sort_unstable(); + names + } - pub fn has_role(name: &str) -> bool { - let names = Self::list_roles(true); - names.contains(&name.to_string()) - } + pub fn has_role(name: &str) -> bool { + let names = Self::list_roles(true); + names.contains(&name.to_string()) + } - pub async fn use_session_safely( - config: &GlobalConfig, - session_name: Option<&str>, - abort_signal: AbortSignal, - ) -> Result<()> { - let mut cfg = { - let mut guard = config.write(); - take(&mut *guard) - }; + pub async fn use_session_safely( + config: &GlobalConfig, + session_name: Option<&str>, + abort_signal: AbortSignal, + ) -> Result<()> { + let mut cfg = { + let mut guard = config.write(); + take(&mut *guard) + }; - cfg.use_session(session_name, abort_signal.clone()).await?; + cfg.use_session(session_name, abort_signal.clone()).await?; - { - let mut guard = config.write(); - *guard = cfg; - } + { + let mut guard = config.write(); + *guard = cfg; + } - Ok(()) - } + Ok(()) + } - pub async fn use_session( - &mut self, - session_name: Option<&str>, - abort_signal: AbortSignal, - ) -> Result<()> { - if self.session.is_some() { - bail!( + pub async fn use_session( + &mut self, + session_name: Option<&str>, + abort_signal: AbortSignal, + ) -> Result<()> { + if self.session.is_some() { + bail!( "Already in a session, please run '.exit session' first to exit the current session." ); - } - let mut session; - match session_name { - None | Some(TEMP_SESSION_NAME) => { - let session_file = self.session_file(TEMP_SESSION_NAME); - if session_file.exists() { - remove_file(session_file).with_context(|| { - format!("Failed to cleanup previous '{TEMP_SESSION_NAME}' session") - })?; - } - session = Some(Session::new(self, TEMP_SESSION_NAME)); - } - Some(name) => { - let session_path = self.session_file(name); - if !session_path.exists() { - session = Some(Session::new(self, name)); - } else { - session = Some(Session::load(self, name, &session_path)?); - } - } - } - let mut new_session = false; - if let Some(session) = session.as_mut() { - let mcp_servers = if self.mcp_servers { - session.use_mcp_servers() - } else { - eprintln!( - "{}", - formatdoc!( + } + let mut session; + match session_name { + None | Some(TEMP_SESSION_NAME) => { + let session_file = self.session_file(TEMP_SESSION_NAME); + if session_file.exists() { + remove_file(session_file).with_context(|| { + format!("Failed to cleanup previous '{TEMP_SESSION_NAME}' session") + })?; + } + session = Some(Session::new(self, TEMP_SESSION_NAME)); + } + Some(name) => { + let session_path = self.session_file(name); + if !session_path.exists() { + session = Some(Session::new(self, name)); + } else { + session = Some(Session::load(self, name, &session_path)?); + } + } + } + let mut new_session = false; + if let Some(session) = session.as_mut() { + let mcp_servers = if self.mcp_servers { + session.use_mcp_servers() + } else { + eprintln!( + "{}", + formatdoc!( " This session uses MCP servers, but MCP support is disabled. To enable it, exit the session and set 'mcp_servers: true', then try again " ) - ); - None - }; - self.functions.clear_mcp_meta_functions(); - let registry = self - .mcp_registry - .take() - .with_context(|| "MCP registry should be populated")?; - let new_mcp_registry = - McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?; + ); + None + }; + self.functions.clear_mcp_meta_functions(); + let registry = self + .mcp_registry + .take() + .with_context(|| "MCP registry should be populated")?; + let new_mcp_registry = + McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?; - if !new_mcp_registry.is_empty() { - self.functions - .append_mcp_meta_functions(new_mcp_registry.list_started_servers()); - } + if !new_mcp_registry.is_empty() { + self.functions + .append_mcp_meta_functions(new_mcp_registry.list_started_servers()); + } - self.mcp_registry = Some(new_mcp_registry); - if session.is_empty() { - new_session = true; - if let Some(LastMessage { - input, - output, - continuous, - }) = &self.last_message - { - if (*continuous && !output.is_empty()) - && self.agent.is_some() == input.with_agent() - { - let ans = Confirm::new( - "Start a session that incorporates the last question and answer?", - ) - .with_default(false) - .prompt()?; - if ans { - session.add_message(input, output)?; - } - } - } - } - } - self.session = session; - self.init_agent_session_variables(new_session)?; - Ok(()) - } + self.mcp_registry = Some(new_mcp_registry); + if session.is_empty() { + new_session = true; + if let Some(LastMessage { + input, + output, + continuous, + }) = &self.last_message + { + if (*continuous && !output.is_empty()) + && self.agent.is_some() == input.with_agent() + { + let ans = Confirm::new( + "Start a session that incorporates the last question and answer?", + ) + .with_default(false) + .prompt()?; + if ans { + session.add_message(input, output)?; + } + } + } + } + } + self.session = session; + self.init_agent_session_variables(new_session)?; + Ok(()) + } - pub fn session_info(&self) -> Result { - if let Some(session) = &self.session { - let render_options = self.render_options()?; - let mut markdown_render = MarkdownRender::init(render_options)?; - let agent_info: Option<(String, Vec)> = self.agent.as_ref().map(|agent| { - let functions = agent - .functions() - .declarations() - .iter() - .filter_map(|v| if v.agent { Some(v.name.clone()) } else { None }) - .collect(); - (agent.name().to_string(), functions) - }); - session.render(&mut markdown_render, &agent_info) - } else { - bail!("No session") - } - } + pub fn session_info(&self) -> Result { + if let Some(session) = &self.session { + let render_options = self.render_options()?; + let mut markdown_render = MarkdownRender::init(render_options)?; + let agent_info: Option<(String, Vec)> = self.agent.as_ref().map(|agent| { + let functions = agent + .functions() + .declarations() + .iter() + .filter_map(|v| if v.agent { Some(v.name.clone()) } else { None }) + .collect(); + (agent.name().to_string(), functions) + }); + session.render(&mut markdown_render, &agent_info) + } else { + bail!("No session") + } + } - pub fn exit_session(&mut self) -> Result<()> { - if let Some(mut session) = self.session.take() { - let sessions_dir = self.sessions_dir(); - session.exit(&sessions_dir, self.working_mode.is_repl())?; - self.discontinuous_last_message(); - } - Ok(()) - } + pub fn exit_session(&mut self) -> Result<()> { + if let Some(mut session) = self.session.take() { + let sessions_dir = self.sessions_dir(); + session.exit(&sessions_dir, self.working_mode.is_repl())?; + self.discontinuous_last_message(); + } + Ok(()) + } - pub fn save_session(&mut self, name: Option<&str>) -> Result<()> { - let session_name = match &self.session { - Some(session) => match name { - Some(v) => v.to_string(), - None => session - .autoname() - .unwrap_or_else(|| session.name()) - .to_string(), - }, - None => bail!("No session"), - }; - let session_path = self.session_file(&session_name); - if let Some(session) = self.session.as_mut() { - session.save(&session_name, &session_path, self.working_mode.is_repl())?; - } - Ok(()) - } + pub fn save_session(&mut self, name: Option<&str>) -> Result<()> { + let session_name = match &self.session { + Some(session) => match name { + Some(v) => v.to_string(), + None => session + .autoname() + .unwrap_or_else(|| session.name()) + .to_string(), + }, + None => bail!("No session"), + }; + let session_path = self.session_file(&session_name); + if let Some(session) = self.session.as_mut() { + session.save(&session_name, &session_path, self.working_mode.is_repl())?; + } + Ok(()) + } - pub fn edit_session(&mut self) -> Result<()> { - let name = match &self.session { - Some(session) => session.name().to_string(), - None => bail!("No session"), - }; - let session_path = self.session_file(&name); - self.save_session(Some(&name))?; - let editor = self.editor()?; - edit_file(&editor, &session_path).with_context(|| { - format!( - "Failed to edit '{}' with '{editor}'", - session_path.display() - ) - })?; - self.session = Some(Session::load(self, &name, &session_path)?); - self.discontinuous_last_message(); - Ok(()) - } + pub fn edit_session(&mut self) -> Result<()> { + let name = match &self.session { + Some(session) => session.name().to_string(), + None => bail!("No session"), + }; + let session_path = self.session_file(&name); + self.save_session(Some(&name))?; + let editor = self.editor()?; + edit_file(&editor, &session_path).with_context(|| { + format!( + "Failed to edit '{}' with '{editor}'", + session_path.display() + ) + })?; + self.session = Some(Session::load(self, &name, &session_path)?); + self.discontinuous_last_message(); + Ok(()) + } - pub fn empty_session(&mut self) -> Result<()> { - if let Some(session) = self.session.as_mut() { - if let Some(agent) = self.agent.as_ref() { - session.sync_agent(agent); - } - session.clear_messages(); - } else { - bail!("No session") - } - self.discontinuous_last_message(); - Ok(()) - } + pub fn empty_session(&mut self) -> Result<()> { + if let Some(session) = self.session.as_mut() { + if let Some(agent) = self.agent.as_ref() { + session.sync_agent(agent); + } + session.clear_messages(); + } else { + bail!("No session") + } + self.discontinuous_last_message(); + Ok(()) + } - pub fn set_save_session_this_time(&mut self) -> Result<()> { - if let Some(session) = self.session.as_mut() { - session.set_save_session_this_time(); - } else { - bail!("No session") - } - Ok(()) - } + pub fn set_save_session_this_time(&mut self) -> Result<()> { + if let Some(session) = self.session.as_mut() { + session.set_save_session_this_time(); + } else { + bail!("No session") + } + Ok(()) + } - pub fn list_sessions(&self) -> Vec { - list_file_names(self.sessions_dir(), ".yaml") - } + pub fn list_sessions(&self) -> Vec { + list_file_names(self.sessions_dir(), ".yaml") + } - pub fn list_autoname_sessions(&self) -> Vec { - list_file_names(self.sessions_dir().join("_"), ".yaml") - } + pub fn list_autoname_sessions(&self) -> Vec { + list_file_names(self.sessions_dir().join("_"), ".yaml") + } - pub fn maybe_compress_session(config: GlobalConfig) { - let mut need_compress = false; - { - let mut config = config.write(); - let compress_threshold = config.compress_threshold; - if let Some(session) = config.session.as_mut() { - if session.need_compress(compress_threshold) { - session.set_compressing(true); - need_compress = true; - } - } - }; - if !need_compress { - return; - } - let color = if config.read().light_theme() { - nu_ansi_term::Color::LightGray - } else { - nu_ansi_term::Color::DarkGray - }; - print!( - "\n📢 {}\n", - color.italic().paint("Compressing the session."), - ); - tokio::spawn(async move { - if let Err(err) = Config::compress_session(&config).await { - warn!("Failed to compress the session: {err}"); - } - if let Some(session) = config.write().session.as_mut() { - session.set_compressing(false); - } - }); - } + pub fn maybe_compress_session(config: GlobalConfig) { + let mut need_compress = false; + { + let mut config = config.write(); + let compress_threshold = config.compress_threshold; + if let Some(session) = config.session.as_mut() { + if session.need_compress(compress_threshold) { + session.set_compressing(true); + need_compress = true; + } + } + }; + if !need_compress { + return; + } + let color = if config.read().light_theme() { + nu_ansi_term::Color::LightGray + } else { + nu_ansi_term::Color::DarkGray + }; + print!( + "\n📢 {}\n", + color.italic().paint("Compressing the session."), + ); + tokio::spawn(async move { + if let Err(err) = Config::compress_session(&config).await { + warn!("Failed to compress the session: {err}"); + } + if let Some(session) = config.write().session.as_mut() { + session.set_compressing(false); + } + }); + } - pub async fn compress_session(config: &GlobalConfig) -> Result<()> { - match config.read().session.as_ref() { - Some(session) => { - if !session.has_user_messages() { - bail!("No need to compress since there are no messages in the session") - } - } - None => bail!("No session"), - } + pub async fn compress_session(config: &GlobalConfig) -> Result<()> { + match config.read().session.as_ref() { + Some(session) => { + if !session.has_user_messages() { + bail!("No need to compress since there are no messages in the session") + } + } + None => bail!("No session"), + } - let prompt = config - .read() - .summarize_prompt - .clone() - .unwrap_or_else(|| SUMMARIZE_PROMPT.into()); - let input = Input::from_str(config, &prompt, None); - let summary = input.fetch_chat_text().await?; - let summary_prompt = config - .read() - .summary_prompt - .clone() - .unwrap_or_else(|| SUMMARY_PROMPT.into()); - if let Some(session) = config.write().session.as_mut() { - session.compress(format!("{summary_prompt}{summary}")); - } - config.write().discontinuous_last_message(); - Ok(()) - } + let prompt = config + .read() + .summarize_prompt + .clone() + .unwrap_or_else(|| SUMMARIZE_PROMPT.into()); + let input = Input::from_str(config, &prompt, None); + let summary = input.fetch_chat_text().await?; + let summary_prompt = config + .read() + .summary_prompt + .clone() + .unwrap_or_else(|| SUMMARY_PROMPT.into()); + if let Some(session) = config.write().session.as_mut() { + session.compress(format!("{summary_prompt}{summary}")); + } + config.write().discontinuous_last_message(); + Ok(()) + } - pub fn is_compressing_session(&self) -> bool { - self.session - .as_ref() - .map(|v| v.compressing()) - .unwrap_or_default() - } + pub fn is_compressing_session(&self) -> bool { + self.session + .as_ref() + .map(|v| v.compressing()) + .unwrap_or_default() + } - pub fn maybe_autoname_session(config: GlobalConfig) { - let mut need_autoname = false; - if let Some(session) = config.write().session.as_mut() { - if session.need_autoname() { - session.set_autonaming(true); - need_autoname = true; - } - } - if !need_autoname { - return; - } - let color = if config.read().light_theme() { - nu_ansi_term::Color::LightGray - } else { - nu_ansi_term::Color::DarkGray - }; - print!("\n📢 {}\n", color.italic().paint("Autonaming the session."),); - tokio::spawn(async move { - if let Err(err) = Config::autoname_session(&config).await { - warn!("Failed to autonaming the session: {err}"); - } - if let Some(session) = config.write().session.as_mut() { - session.set_autonaming(false); - } - }); - } + pub fn maybe_autoname_session(config: GlobalConfig) { + let mut need_autoname = false; + if let Some(session) = config.write().session.as_mut() { + if session.need_autoname() { + session.set_autonaming(true); + need_autoname = true; + } + } + if !need_autoname { + return; + } + let color = if config.read().light_theme() { + nu_ansi_term::Color::LightGray + } else { + nu_ansi_term::Color::DarkGray + }; + print!("\n📢 {}\n", color.italic().paint("Autonaming the session."),); + tokio::spawn(async move { + if let Err(err) = Config::autoname_session(&config).await { + warn!("Failed to autonaming the session: {err}"); + } + if let Some(session) = config.write().session.as_mut() { + session.set_autonaming(false); + } + }); + } - pub async fn autoname_session(config: &GlobalConfig) -> Result<()> { - let text = match config - .read() - .session - .as_ref() - .and_then(|v| v.chat_history_for_autonaming()) - { - Some(v) => v, - None => bail!("No chat history"), - }; - let role = config.read().retrieve_role(CREATE_TITLE_ROLE)?; - let input = Input::from_str(config, &text, Some(role)); - let text = input.fetch_chat_text().await?; - if let Some(session) = config.write().session.as_mut() { - session.set_autoname(&text); - } - Ok(()) - } + pub async fn autoname_session(config: &GlobalConfig) -> Result<()> { + let text = match config + .read() + .session + .as_ref() + .and_then(|v| v.chat_history_for_autonaming()) + { + Some(v) => v, + None => bail!("No chat history"), + }; + let role = config.read().retrieve_role(CREATE_TITLE_ROLE)?; + let input = Input::from_str(config, &text, Some(role)); + let text = input.fetch_chat_text().await?; + if let Some(session) = config.write().session.as_mut() { + session.set_autoname(&text); + } + Ok(()) + } - pub async fn use_rag( - config: &GlobalConfig, - rag: Option<&str>, - abort_signal: AbortSignal, - ) -> Result<()> { - if config.read().agent.is_some() { - bail!("Cannot perform this operation because you are using a agent") - } - let rag = match rag { - None => { - let rag_path = config.read().rag_file(TEMP_RAG_NAME); - if rag_path.exists() { - remove_file(&rag_path).with_context(|| { - format!("Failed to cleanup previous '{TEMP_RAG_NAME}' rag") - })?; - } - Rag::init(config, TEMP_RAG_NAME, &rag_path, &[], abort_signal).await? - } - Some(name) => { - let rag_path = config.read().rag_file(name); - if !rag_path.exists() { - if config.read().working_mode.is_cmd() { - bail!("Unknown RAG '{name}'") - } - Rag::init(config, name, &rag_path, &[], abort_signal).await? - } else { - Rag::load(config, name, &rag_path)? - } - } - }; - config.write().rag = Some(Arc::new(rag)); - Ok(()) - } + pub async fn use_rag( + config: &GlobalConfig, + rag: Option<&str>, + abort_signal: AbortSignal, + ) -> Result<()> { + if config.read().agent.is_some() { + bail!("Cannot perform this operation because you are using a agent") + } + let rag = match rag { + None => { + let rag_path = config.read().rag_file(TEMP_RAG_NAME); + if rag_path.exists() { + remove_file(&rag_path).with_context(|| { + format!("Failed to cleanup previous '{TEMP_RAG_NAME}' rag") + })?; + } + Rag::init(config, TEMP_RAG_NAME, &rag_path, &[], abort_signal).await? + } + Some(name) => { + let rag_path = config.read().rag_file(name); + if !rag_path.exists() { + if config.read().working_mode.is_cmd() { + bail!("Unknown RAG '{name}'") + } + Rag::init(config, name, &rag_path, &[], abort_signal).await? + } else { + Rag::load(config, name, &rag_path)? + } + } + }; + config.write().rag = Some(Arc::new(rag)); + Ok(()) + } - pub async fn edit_rag_docs(config: &GlobalConfig, abort_signal: AbortSignal) -> Result<()> { - let mut rag = match config.read().rag.clone() { - Some(v) => v.as_ref().clone(), - None => bail!("No RAG"), - }; + pub async fn edit_rag_docs(config: &GlobalConfig, abort_signal: AbortSignal) -> Result<()> { + let mut rag = match config.read().rag.clone() { + Some(v) => v.as_ref().clone(), + None => bail!("No RAG"), + }; - let document_paths = rag.document_paths(); - let temp_file = temp_file(&format!("-rag-{}", rag.name()), ".txt"); - tokio::fs::write(&temp_file, &document_paths.join("\n")) - .await - .with_context(|| format!("Failed to write to '{}'", temp_file.display()))?; - let editor = config.read().editor()?; - edit_file(&editor, &temp_file)?; - let new_document_paths = tokio::fs::read_to_string(&temp_file) - .await - .with_context(|| format!("Failed to read '{}'", temp_file.display()))?; - let new_document_paths = new_document_paths - .split('\n') - .filter_map(|v| { - let v = v.trim(); - if v.is_empty() { - None - } else { - Some(v.to_string()) - } - }) - .collect::>(); - if new_document_paths.is_empty() || new_document_paths == document_paths { - bail!("No changes") - } - rag.refresh_document_paths(&new_document_paths, false, config, abort_signal) - .await?; - config.write().rag = Some(Arc::new(rag)); - Ok(()) - } + let document_paths = rag.document_paths(); + let temp_file = temp_file(&format!("-rag-{}", rag.name()), ".txt"); + tokio::fs::write(&temp_file, &document_paths.join("\n")) + .await + .with_context(|| format!("Failed to write to '{}'", temp_file.display()))?; + let editor = config.read().editor()?; + edit_file(&editor, &temp_file)?; + let new_document_paths = tokio::fs::read_to_string(&temp_file) + .await + .with_context(|| format!("Failed to read '{}'", temp_file.display()))?; + let new_document_paths = new_document_paths + .split('\n') + .filter_map(|v| { + let v = v.trim(); + if v.is_empty() { + None + } else { + Some(v.to_string()) + } + }) + .collect::>(); + if new_document_paths.is_empty() || new_document_paths == document_paths { + bail!("No changes") + } + rag.refresh_document_paths(&new_document_paths, false, config, abort_signal) + .await?; + config.write().rag = Some(Arc::new(rag)); + Ok(()) + } - pub async fn rebuild_rag(config: &GlobalConfig, abort_signal: AbortSignal) -> Result<()> { - let mut rag = match config.read().rag.clone() { - Some(v) => v.as_ref().clone(), - None => bail!("No RAG"), - }; - let document_paths = rag.document_paths().to_vec(); - rag.refresh_document_paths(&document_paths, true, config, abort_signal) - .await?; - config.write().rag = Some(Arc::new(rag)); - Ok(()) - } + pub async fn rebuild_rag(config: &GlobalConfig, abort_signal: AbortSignal) -> Result<()> { + let mut rag = match config.read().rag.clone() { + Some(v) => v.as_ref().clone(), + None => bail!("No RAG"), + }; + let document_paths = rag.document_paths().to_vec(); + rag.refresh_document_paths(&document_paths, true, config, abort_signal) + .await?; + config.write().rag = Some(Arc::new(rag)); + Ok(()) + } - pub fn rag_sources(config: &GlobalConfig) -> Result { - match config.read().rag.as_ref() { - Some(rag) => match rag.get_last_sources() { - Some(v) => Ok(v), - None => bail!("No sources"), - }, - None => bail!("No RAG"), - } - } + pub fn rag_sources(config: &GlobalConfig) -> Result { + match config.read().rag.as_ref() { + Some(rag) => match rag.get_last_sources() { + Some(v) => Ok(v), + None => bail!("No sources"), + }, + None => bail!("No RAG"), + } + } - pub fn rag_info(&self) -> Result { - if let Some(rag) = &self.rag { - rag.export() - } else { - bail!("No RAG") - } - } + pub fn rag_info(&self) -> Result { + if let Some(rag) = &self.rag { + rag.export() + } else { + bail!("No RAG") + } + } - pub fn exit_rag(&mut self) -> Result<()> { - self.rag.take(); - Ok(()) - } + pub fn exit_rag(&mut self) -> Result<()> { + self.rag.take(); + Ok(()) + } - pub async fn search_rag( - config: &GlobalConfig, - rag: &Rag, - text: &str, - abort_signal: AbortSignal, - ) -> Result { - let (reranker_model, top_k) = rag.get_config(); - let (embeddings, ids) = rag - .search(text, top_k, reranker_model.as_deref(), abort_signal) - .await?; - let text = config.read().rag_template(&embeddings, text); - rag.set_last_sources(&ids); - Ok(text) - } + pub async fn search_rag( + config: &GlobalConfig, + rag: &Rag, + text: &str, + abort_signal: AbortSignal, + ) -> Result { + let (reranker_model, top_k) = rag.get_config(); + let (embeddings, ids) = rag + .search(text, top_k, reranker_model.as_deref(), abort_signal) + .await?; + let text = config.read().rag_template(&embeddings, text); + rag.set_last_sources(&ids); + Ok(text) + } - pub fn list_rags() -> Vec { - match read_dir(Self::rags_dir()) { - Ok(rd) => { - let mut names = vec![]; - for entry in rd.flatten() { - let name = entry.file_name(); - if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") { - names.push(name.to_string()); - } - } - names.sort_unstable(); - names - } - Err(_) => vec![], - } - } + pub fn list_rags() -> Vec { + match read_dir(Self::rags_dir()) { + Ok(rd) => { + let mut names = vec![]; + for entry in rd.flatten() { + let name = entry.file_name(); + if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") { + names.push(name.to_string()); + } + } + names.sort_unstable(); + names + } + Err(_) => vec![], + } + } - pub fn rag_template(&self, embeddings: &str, text: &str) -> String { - if embeddings.is_empty() { - return text.to_string(); - } - self.rag_template - .as_deref() - .unwrap_or(RAG_TEMPLATE) - .replace("__CONTEXT__", embeddings) - .replace("__INPUT__", text) - } + pub fn rag_template(&self, embeddings: &str, text: &str) -> String { + if embeddings.is_empty() { + return text.to_string(); + } + self.rag_template + .as_deref() + .unwrap_or(RAG_TEMPLATE) + .replace("__CONTEXT__", embeddings) + .replace("__INPUT__", text) + } - pub async fn use_agent( - config: &GlobalConfig, - agent_name: &str, - session_name: Option<&str>, - abort_signal: AbortSignal, - ) -> Result<()> { - if !config.read().function_calling { - bail!("Please enable function calling before using the agent."); - } - if config.read().agent.is_some() { - bail!("Already in an agent, please run '.exit agent' first to exit the current agent."); - } - let agent = Agent::init(config, agent_name, abort_signal.clone()).await?; - let session = session_name.map(|v| v.to_string()).or_else(|| { - if config.read().macro_flag { - None - } else { - agent.agent_session().map(|v| v.to_string()) - } - }); - config.write().rag = agent.rag(); - config.write().agent = Some(agent); - if let Some(session) = session { - Config::use_session_safely(config, Some(&session), abort_signal).await?; - } else { - config.write().init_agent_shared_variables()?; - } - Ok(()) - } + pub async fn use_agent( + config: &GlobalConfig, + agent_name: &str, + session_name: Option<&str>, + abort_signal: AbortSignal, + ) -> Result<()> { + if !config.read().function_calling { + bail!("Please enable function calling before using the agent."); + } + if config.read().agent.is_some() { + bail!("Already in an agent, please run '.exit agent' first to exit the current agent."); + } + let agent = Agent::init(config, agent_name, abort_signal.clone()).await?; + let session = session_name.map(|v| v.to_string()).or_else(|| { + if config.read().macro_flag { + None + } else { + agent.agent_session().map(|v| v.to_string()) + } + }); + config.write().rag = agent.rag(); + config.write().agent = Some(agent); + if let Some(session) = session { + Config::use_session_safely(config, Some(&session), abort_signal).await?; + } else { + config.write().init_agent_shared_variables()?; + } + Ok(()) + } - pub fn agent_info(&self) -> Result { - if let Some(agent) = &self.agent { - agent.export() - } else { - bail!("No agent") - } - } + pub fn agent_info(&self) -> Result { + if let Some(agent) = &self.agent { + agent.export() + } else { + bail!("No agent") + } + } - pub fn agent_banner(&self) -> Result { - if let Some(agent) = &self.agent { - Ok(agent.banner()) - } else { - bail!("No agent") - } - } + pub fn agent_banner(&self) -> Result { + if let Some(agent) = &self.agent { + Ok(agent.banner()) + } else { + bail!("No agent") + } + } - pub fn edit_agent_config(&self) -> Result<()> { - let agent_name = match &self.agent { - Some(agent) => agent.name(), - None => bail!("No agent"), - }; - let agent_config_path = Config::agent_config_file(agent_name); - ensure_parent_exists(&agent_config_path)?; - if !agent_config_path.exists() { - std::fs::write( - &agent_config_path, - "# see https://github.com/Dark-Alex-17/loki/blob/main/config.agent.example.yaml\n", - ) - .with_context(|| format!("Failed to write to '{}'", agent_config_path.display()))?; - } - let editor = self.editor()?; - edit_file(&editor, &agent_config_path)?; - println!( - "NOTE: Remember to reload the agent if there are changes made to '{}'", - agent_config_path.display() - ); - Ok(()) - } + pub fn edit_agent_config(&self) -> Result<()> { + let agent_name = match &self.agent { + Some(agent) => agent.name(), + None => bail!("No agent"), + }; + let agent_config_path = Config::agent_config_file(agent_name); + ensure_parent_exists(&agent_config_path)?; + if !agent_config_path.exists() { + std::fs::write( + &agent_config_path, + "# see https://github.com/Dark-Alex-17/loki/blob/main/config.agent.example.yaml\n", + ) + .with_context(|| format!("Failed to write to '{}'", agent_config_path.display()))?; + } + let editor = self.editor()?; + edit_file(&editor, &agent_config_path)?; + println!( + "NOTE: Remember to reload the agent if there are changes made to '{}'", + agent_config_path.display() + ); + Ok(()) + } - pub fn exit_agent(&mut self) -> Result<()> { - self.exit_session()?; - self.load_functions()?; - if self.agent.take().is_some() { - self.rag.take(); - self.discontinuous_last_message(); - } - Ok(()) - } + pub fn exit_agent(&mut self) -> Result<()> { + self.exit_session()?; + self.load_functions()?; + if self.agent.take().is_some() { + self.rag.take(); + self.discontinuous_last_message(); + } + Ok(()) + } - pub fn exit_agent_session(&mut self) -> Result<()> { - self.exit_session()?; - if let Some(agent) = self.agent.as_mut() { - agent.exit_session(); - if self.working_mode.is_repl() { - self.init_agent_shared_variables()?; - } - } - Ok(()) - } + pub fn exit_agent_session(&mut self) -> Result<()> { + self.exit_session()?; + if let Some(agent) = self.agent.as_mut() { + agent.exit_session(); + if self.working_mode.is_repl() { + self.init_agent_shared_variables()?; + } + } + Ok(()) + } - pub fn list_macros() -> Vec { - list_file_names(Self::macros_dir(), ".yaml") - } + pub fn list_macros() -> Vec { + list_file_names(Self::macros_dir(), ".yaml") + } - pub fn load_macro(name: &str) -> Result { - let path = Self::macro_file(name); - let err = || format!("Failed to load macro '{name}' at '{}'", path.display()); - let content = read_to_string(&path).with_context(err)?; - let value: Macro = serde_yaml::from_str(&content).with_context(err)?; - Ok(value) - } + pub fn load_macro(name: &str) -> Result { + let path = Self::macro_file(name); + let err = || format!("Failed to load macro '{name}' at '{}'", path.display()); + let content = read_to_string(&path).with_context(err)?; + let value: Macro = serde_yaml::from_str(&content).with_context(err)?; + Ok(value) + } - pub fn has_macro(name: &str) -> bool { - let names = Self::list_macros(); - names.contains(&name.to_string()) - } + pub fn has_macro(name: &str) -> bool { + let names = Self::list_macros(); + names.contains(&name.to_string()) + } - pub fn new_macro(&mut self, name: &str) -> Result<()> { - if self.macro_flag { - bail!("No macro"); - } - let ans = Confirm::new("Create a new macro?") - .with_default(true) - .prompt()?; - if ans { - let macro_path = Self::macro_file(name); - ensure_parent_exists(¯o_path)?; - let editor = self.editor()?; - edit_file(&editor, ¯o_path)?; - } else { - bail!("No macro"); - } - Ok(()) - } + pub fn new_macro(&mut self, name: &str) -> Result<()> { + if self.macro_flag { + bail!("No macro"); + } + let ans = Confirm::new("Create a new macro?") + .with_default(true) + .prompt()?; + if ans { + let macro_path = Self::macro_file(name); + ensure_parent_exists(¯o_path)?; + let editor = self.editor()?; + edit_file(&editor, ¯o_path)?; + } else { + bail!("No macro"); + } + Ok(()) + } - pub async fn apply_prelude(&mut self, abort_signal: AbortSignal) -> Result<()> { - if self.macro_flag || !self.state().is_empty() { - return Ok(()); - } - let prelude = match self.working_mode { - WorkingMode::Repl => self.repl_prelude.as_ref(), - WorkingMode::Cmd => self.cmd_prelude.as_ref(), - WorkingMode::Serve => return Ok(()), - }; - let prelude = match prelude { - Some(v) => { - if v.is_empty() { - return Ok(()); - } - v.to_string() - } - None => return Ok(()), - }; + pub async fn apply_prelude(&mut self, abort_signal: AbortSignal) -> Result<()> { + if self.macro_flag || !self.state().is_empty() { + return Ok(()); + } + let prelude = match self.working_mode { + WorkingMode::Repl => self.repl_prelude.as_ref(), + WorkingMode::Cmd => self.cmd_prelude.as_ref(), + WorkingMode::Serve => return Ok(()), + }; + let prelude = match prelude { + Some(v) => { + if v.is_empty() { + return Ok(()); + } + v.to_string() + } + None => return Ok(()), + }; - let err_msg = || format!("Invalid prelude '{prelude}"); - match prelude.split_once(':') { - Some(("role", name)) => { - self.use_role(name, abort_signal) - .await - .with_context(err_msg)?; - } - Some(("session", name)) => { - self.use_session(Some(name), abort_signal) - .await - .with_context(err_msg)?; - } - Some((session_name, role_name)) => { - self.use_session(Some(session_name), abort_signal.clone()) - .await - .with_context(err_msg)?; - if let Some(true) = self.session.as_ref().map(|v| v.is_empty()) { - self.use_role(role_name, abort_signal) - .await - .with_context(err_msg)?; - } - } - _ => { - bail!("{}", err_msg()) - } - } - Ok(()) - } + let err_msg = || format!("Invalid prelude '{prelude}"); + match prelude.split_once(':') { + Some(("role", name)) => { + self.use_role(name, abort_signal) + .await + .with_context(err_msg)?; + } + Some(("session", name)) => { + self.use_session(Some(name), abort_signal) + .await + .with_context(err_msg)?; + } + Some((session_name, role_name)) => { + self.use_session(Some(session_name), abort_signal.clone()) + .await + .with_context(err_msg)?; + if let Some(true) = self.session.as_ref().map(|v| v.is_empty()) { + self.use_role(role_name, abort_signal) + .await + .with_context(err_msg)?; + } + } + _ => { + bail!("{}", err_msg()) + } + } + Ok(()) + } - pub fn select_functions(&self, role: &Role) -> Option> { - let mut functions = vec![]; - functions.extend(self.select_enabled_functions(role)); - functions.extend(self.select_enabled_mcp_servers(role)); + pub fn select_functions(&self, role: &Role) -> Option> { + let mut functions = vec![]; + functions.extend(self.select_enabled_functions(role)); + functions.extend(self.select_enabled_mcp_servers(role)); - if functions.is_empty() { - None - } else { - Some(functions) - } - } + if functions.is_empty() { + None + } else { + Some(functions) + } + } - fn select_enabled_functions(&self, role: &Role) -> Vec { - let mut functions = vec![]; - if self.function_calling { - if let Some(use_tools) = role.use_tools() { - let mut tool_names: HashSet = Default::default(); - let declaration_names: HashSet = self - .functions - .declarations() - .iter() - .map(|v| v.name.to_string()) - .collect(); - if use_tools == "all" { - tool_names.extend(declaration_names); - } else { - for item in use_tools.split(',') { - let item = item.trim(); - if let Some(values) = self.mapping_tools.get(item) { - tool_names.extend( - values - .split(',') - .map(|v| v.to_string()) - .filter(|v| declaration_names.contains(v)), - ) - } else if declaration_names.contains(item) { - tool_names.insert(item.to_string()); - } - } - } - functions = self - .functions - .declarations() - .iter() - .filter_map(|v| { - if tool_names.contains(&v.name) { - Some(v.clone()) - } else { - None - } - }) - .collect(); - } + fn select_enabled_functions(&self, role: &Role) -> Vec { + let mut functions = vec![]; + if self.function_calling { + if let Some(use_tools) = role.use_tools() { + let mut tool_names: HashSet = Default::default(); + let declaration_names: HashSet = self + .functions + .declarations() + .iter() + .map(|v| v.name.to_string()) + .collect(); + if use_tools == "all" { + tool_names.extend(declaration_names); + } else { + for item in use_tools.split(',') { + let item = item.trim(); + if let Some(values) = self.mapping_tools.get(item) { + tool_names.extend( + values + .split(',') + .map(|v| v.to_string()) + .filter(|v| declaration_names.contains(v)), + ) + } else if declaration_names.contains(item) { + tool_names.insert(item.to_string()); + } + } + } + functions = self + .functions + .declarations() + .iter() + .filter_map(|v| { + if tool_names.contains(&v.name) { + Some(v.clone()) + } else { + None + } + }) + .collect(); + } - if let Some(agent) = &self.agent { - let mut agent_functions = agent.functions().declarations().to_vec(); - let tool_names: HashSet = agent_functions - .iter() - .filter_map(|v| { - if v.agent { - None - } else { - Some(v.name.to_string()) - } - }) - .collect(); - agent_functions.extend( - functions - .into_iter() - .filter(|v| !tool_names.contains(&v.name)), - ); - functions = agent_functions; - } - } + if let Some(agent) = &self.agent { + let mut agent_functions = agent.functions().declarations().to_vec(); + let tool_names: HashSet = agent_functions + .iter() + .filter_map(|v| { + if v.agent { + None + } else { + Some(v.name.to_string()) + } + }) + .collect(); + agent_functions.extend( + functions + .into_iter() + .filter(|v| !tool_names.contains(&v.name)), + ); + functions = agent_functions; + } + } - functions - } + functions + } - fn select_enabled_mcp_servers(&self, role: &Role) -> Vec { - let mut mcp_functions = vec![]; - if self.mcp_servers { - if let Some(use_mcp_servers) = role.use_mcp_servers() { - let mut server_names: HashSet = Default::default(); - let mcp_declaration_names: HashSet = self - .functions - .declarations() - .iter() - .filter(|v| { - v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) - || v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) - }) - .map(|v| v.name.to_string()) - .collect(); - if use_mcp_servers == "all" { - server_names.extend(mcp_declaration_names); - } else { - for item in use_mcp_servers.split(',') { - let item = item.trim(); - let item_invoke_name = - format!("{}_{item}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX); - let item_list_name = - format!("{}_{item}", MCP_LIST_META_FUNCTION_NAME_PREFIX); - if let Some(values) = self.mapping_tools.get(item) { - server_names.extend( - values - .split(',') - .flat_map(|v| { - vec![ - format!( - "{}_{}", - MCP_INVOKE_META_FUNCTION_NAME_PREFIX, - v.to_string() - ), - format!( - "{}_{}", - MCP_LIST_META_FUNCTION_NAME_PREFIX, - v.to_string() - ), - ] - }) - .filter(|v| mcp_declaration_names.contains(v)), - ) - } else if mcp_declaration_names.contains(&item_invoke_name) { - server_names.insert(item_invoke_name); - server_names.insert(item_list_name); - } - } - } - mcp_functions = self - .functions - .declarations() - .iter() - .filter_map(|v| { - if server_names.contains(&v.name) { - Some(v.clone()) - } else { - None - } - }) - .collect(); - } + fn select_enabled_mcp_servers(&self, role: &Role) -> Vec { + let mut mcp_functions = vec![]; + if self.mcp_servers { + if let Some(use_mcp_servers) = role.use_mcp_servers() { + let mut server_names: HashSet = Default::default(); + let mcp_declaration_names: HashSet = self + .functions + .declarations() + .iter() + .filter(|v| { + v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) + || v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + }) + .map(|v| v.name.to_string()) + .collect(); + if use_mcp_servers == "all" { + server_names.extend(mcp_declaration_names); + } else { + for item in use_mcp_servers.split(',') { + let item = item.trim(); + let item_invoke_name = + format!("{}_{item}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX); + let item_list_name = + format!("{}_{item}", MCP_LIST_META_FUNCTION_NAME_PREFIX); + if let Some(values) = self.mapping_tools.get(item) { + server_names.extend( + values + .split(',') + .flat_map(|v| { + vec![ + format!( + "{}_{}", + MCP_INVOKE_META_FUNCTION_NAME_PREFIX, + v.to_string() + ), + format!( + "{}_{}", + MCP_LIST_META_FUNCTION_NAME_PREFIX, + v.to_string() + ), + ] + }) + .filter(|v| mcp_declaration_names.contains(v)), + ) + } else if mcp_declaration_names.contains(&item_invoke_name) { + server_names.insert(item_invoke_name); + server_names.insert(item_list_name); + } + } + } + mcp_functions = self + .functions + .declarations() + .iter() + .filter_map(|v| { + if server_names.contains(&v.name) { + Some(v.clone()) + } else { + None + } + }) + .collect(); + } - if let Some(agent) = &self.agent { - let mut agent_functions = agent.functions().declarations().to_vec(); - let tool_names: HashSet = agent_functions - .iter() - .filter_map(|v| { - if v.agent { - None - } else { - Some(v.name.to_string()) - } - }) - .collect(); - agent_functions.extend( - mcp_functions - .into_iter() - .filter(|v| !tool_names.contains(&v.name)), - ); - mcp_functions = agent_functions; - } - } + if let Some(agent) = &self.agent { + let mut agent_functions = agent.functions().declarations().to_vec(); + let tool_names: HashSet = agent_functions + .iter() + .filter_map(|v| { + if v.agent { + None + } else { + Some(v.name.to_string()) + } + }) + .collect(); + agent_functions.extend( + mcp_functions + .into_iter() + .filter(|v| !tool_names.contains(&v.name)), + ); + mcp_functions = agent_functions; + } + } - mcp_functions - } + mcp_functions + } - pub fn editor(&self) -> Result { - EDITOR.get_or_init(move || { + pub fn editor(&self) -> Result { + EDITOR.get_or_init(move || { let editor = self.editor.clone() .or_else(|| env::var("VISUAL").ok().or_else(|| env::var("EDITOR").ok())) .unwrap_or_else(|| { @@ -2143,925 +2143,925 @@ impl Config { }) .clone() .ok_or_else(|| anyhow!("Editor not found. Please add the `editor` configuration or set the $EDITOR or $VISUAL environment variable.")) - } + } - pub fn repl_complete( - &self, - cmd: &str, - args: &[&str], - _line: &str, - ) -> Vec<(String, Option)> { - let mut values: Vec<(String, Option)> = vec![]; - let filter = args.last().unwrap_or(&""); - if args.len() == 1 { - values = match cmd { - ".role" => map_completion_values(Self::list_roles(true)), - ".model" => list_models(self, ModelType::Chat) - .into_iter() - .map(|v| (v.id(), Some(v.description()))) - .collect(), - ".session" => { - if args[0].starts_with("_/") { - map_completion_values( - self.list_autoname_sessions() - .iter() - .rev() - .map(|v| format!("_/{v}")) - .collect::>(), - ) - } else { - map_completion_values(self.list_sessions()) - } - } - ".rag" => map_completion_values(Self::list_rags()), - ".agent" => map_completion_values(list_agents()), - ".macro" => map_completion_values(Self::list_macros()), - ".starter" => match &self.agent { - Some(agent) => agent - .conversation_starters() - .iter() - .enumerate() - .map(|(i, v)| ((i + 1).to_string(), Some(v.to_string()))) - .collect(), - None => vec![], - }, - ".set" => { - let mut values = vec![ - "temperature", - "top_p", - "use_tools", - "use_mcp_servers", - "save_session", - "compress_threshold", - "rag_reranker_model", - "rag_top_k", - "max_output_tokens", - "dry_run", - "function_calling", - "mcp_servers", - "stream", - "save", - "highlight", - ]; - values.sort_unstable(); - values - .into_iter() - .map(|v| (format!("{v} "), None)) - .collect() - } - ".delete" => { - map_completion_values(vec!["role", "session", "rag", "macro", "agent-data"]) - } - ".vault" => { - let mut values = vec!["add", "get", "update", "delete", "list"]; - values.sort_unstable(); - values - .into_iter() - .map(|v| (format!("{v} "), None)) - .collect() - } - _ => vec![], - }; - } else if cmd == ".set" && args.len() == 2 { - let candidates = match args[0] { - "max_output_tokens" => match self.current_model().max_output_tokens() { - Some(v) => vec![v.to_string()], - None => vec![], - }, - "dry_run" => complete_bool(self.dry_run), - "stream" => complete_bool(self.stream), - "save" => complete_bool(self.save), - "function_calling" => complete_bool(self.function_calling), - "use_tools" => { - let mut prefix = String::new(); - let mut ignores = HashSet::new(); - if let Some((v, _)) = args[1].rsplit_once(',') { - ignores = v.split(',').collect(); - prefix = format!("{v},"); - } - let mut values = vec![]; - if prefix.is_empty() { - values.push("all".to_string()); - } - values.extend(self.functions.declarations().iter().map(|v| v.name.clone())); - values.extend(self.mapping_tools.keys().map(|v| v.to_string())); - values - .into_iter() - .filter(|v| !ignores.contains(v.as_str())) - .map(|v| format!("{prefix}{v}")) - .collect() - } - "mcp_servers" => complete_bool(self.mcp_servers), - "use_mcp_servers" => { - let mut prefix = String::new(); - let mut ignores = HashSet::new(); - if let Some((v, _)) = args[1].rsplit_once(',') { - ignores = v.split(',').collect(); - prefix = format!("{v},"); - } - let mut values = vec![]; - if prefix.is_empty() { - values.push("all".to_string()); - } + pub fn repl_complete( + &self, + cmd: &str, + args: &[&str], + _line: &str, + ) -> Vec<(String, Option)> { + let mut values: Vec<(String, Option)> = vec![]; + let filter = args.last().unwrap_or(&""); + if args.len() == 1 { + values = match cmd { + ".role" => map_completion_values(Self::list_roles(true)), + ".model" => list_models(self, ModelType::Chat) + .into_iter() + .map(|v| (v.id(), Some(v.description()))) + .collect(), + ".session" => { + if args[0].starts_with("_/") { + map_completion_values( + self.list_autoname_sessions() + .iter() + .rev() + .map(|v| format!("_/{v}")) + .collect::>(), + ) + } else { + map_completion_values(self.list_sessions()) + } + } + ".rag" => map_completion_values(Self::list_rags()), + ".agent" => map_completion_values(list_agents()), + ".macro" => map_completion_values(Self::list_macros()), + ".starter" => match &self.agent { + Some(agent) => agent + .conversation_starters() + .iter() + .enumerate() + .map(|(i, v)| ((i + 1).to_string(), Some(v.to_string()))) + .collect(), + None => vec![], + }, + ".set" => { + let mut values = vec![ + "temperature", + "top_p", + "use_tools", + "use_mcp_servers", + "save_session", + "compress_threshold", + "rag_reranker_model", + "rag_top_k", + "max_output_tokens", + "dry_run", + "function_calling", + "mcp_servers", + "stream", + "save", + "highlight", + ]; + values.sort_unstable(); + values + .into_iter() + .map(|v| (format!("{v} "), None)) + .collect() + } + ".delete" => { + map_completion_values(vec!["role", "session", "rag", "macro", "agent-data"]) + } + ".vault" => { + let mut values = vec!["add", "get", "update", "delete", "list"]; + values.sort_unstable(); + values + .into_iter() + .map(|v| (format!("{v} "), None)) + .collect() + } + _ => vec![], + }; + } else if cmd == ".set" && args.len() == 2 { + let candidates = match args[0] { + "max_output_tokens" => match self.current_model().max_output_tokens() { + Some(v) => vec![v.to_string()], + None => vec![], + }, + "dry_run" => complete_bool(self.dry_run), + "stream" => complete_bool(self.stream), + "save" => complete_bool(self.save), + "function_calling" => complete_bool(self.function_calling), + "use_tools" => { + let mut prefix = String::new(); + let mut ignores = HashSet::new(); + if let Some((v, _)) = args[1].rsplit_once(',') { + ignores = v.split(',').collect(); + prefix = format!("{v},"); + } + let mut values = vec![]; + if prefix.is_empty() { + values.push("all".to_string()); + } + values.extend(self.functions.declarations().iter().map(|v| v.name.clone())); + values.extend(self.mapping_tools.keys().map(|v| v.to_string())); + values + .into_iter() + .filter(|v| !ignores.contains(v.as_str())) + .map(|v| format!("{prefix}{v}")) + .collect() + } + "mcp_servers" => complete_bool(self.mcp_servers), + "use_mcp_servers" => { + let mut prefix = String::new(); + let mut ignores = HashSet::new(); + if let Some((v, _)) = args[1].rsplit_once(',') { + ignores = v.split(',').collect(); + prefix = format!("{v},"); + } + let mut values = vec![]; + if prefix.is_empty() { + values.push("all".to_string()); + } - if let Some(registry) = &self.mcp_registry { - values.extend(registry.list_configured_servers()); - } + if let Some(registry) = &self.mcp_registry { + values.extend(registry.list_configured_servers()); + } - values.extend(self.mapping_mcp_servers.keys().map(|v| v.to_string())); - values - .into_iter() - .filter(|v| !ignores.contains(v.as_str())) - .map(|v| format!("{prefix}{v}")) - .collect() - } - "save_session" => { - let save_session = if let Some(session) = &self.session { - session.save_session() - } else { - self.save_session - }; - complete_option_bool(save_session) - } - "rag_reranker_model" => list_models(self, ModelType::Reranker) - .iter() - .map(|v| v.id()) - .collect(), - "highlight" => complete_bool(self.highlight), - _ => vec![], - }; - values = candidates.into_iter().map(|v| (v, None)).collect(); - } else if cmd == ".vault" && args.len() == 2 { - values = self - .vault - .list_secrets(false) - .unwrap_or_default() - .into_iter() - .map(|v| (v, None)) - .collect(); - } else if cmd == ".agent" { - if args.len() == 2 { - let dir = Self::agent_data_dir(args[0]).join(SESSIONS_DIR_NAME); - values = list_file_names(dir, ".yaml") - .into_iter() - .map(|v| (v, None)) - .collect(); - } - values.extend(complete_agent_variables(args[0])); - }; - fuzzy_filter(values, |v| v.0.as_str(), filter) - } + values.extend(self.mapping_mcp_servers.keys().map(|v| v.to_string())); + values + .into_iter() + .filter(|v| !ignores.contains(v.as_str())) + .map(|v| format!("{prefix}{v}")) + .collect() + } + "save_session" => { + let save_session = if let Some(session) = &self.session { + session.save_session() + } else { + self.save_session + }; + complete_option_bool(save_session) + } + "rag_reranker_model" => list_models(self, ModelType::Reranker) + .iter() + .map(|v| v.id()) + .collect(), + "highlight" => complete_bool(self.highlight), + _ => vec![], + }; + values = candidates.into_iter().map(|v| (v, None)).collect(); + } else if cmd == ".vault" && args.len() == 2 { + values = self + .vault + .list_secrets(false) + .unwrap_or_default() + .into_iter() + .map(|v| (v, None)) + .collect(); + } else if cmd == ".agent" { + if args.len() == 2 { + let dir = Self::agent_data_dir(args[0]).join(SESSIONS_DIR_NAME); + values = list_file_names(dir, ".yaml") + .into_iter() + .map(|v| (v, None)) + .collect(); + } + values.extend(complete_agent_variables(args[0])); + }; + fuzzy_filter(values, |v| v.0.as_str(), filter) + } - pub fn sync_models_url(&self) -> String { - self.sync_models_url - .clone() - .unwrap_or_else(|| SYNC_MODELS_URL.into()) - } + pub fn sync_models_url(&self) -> String { + self.sync_models_url + .clone() + .unwrap_or_else(|| SYNC_MODELS_URL.into()) + } - pub async fn sync_models(url: &str, abort_signal: AbortSignal) -> Result<()> { - let content = abortable_run_with_spinner(fetch(url), "Fetching models.yaml", abort_signal) - .await - .with_context(|| format!("Failed to fetch '{url}'"))?; - println!("✓ Fetched '{url}'"); - let list = serde_yaml::from_str::>(&content) - .with_context(|| "Failed to parse models.yaml")?; - let models_override = ModelsOverride { - version: env!("CARGO_PKG_VERSION").to_string(), - list, - }; - let models_override_data = - serde_yaml::to_string(&models_override).with_context(|| "Failed to serde {}")?; + pub async fn sync_models(url: &str, abort_signal: AbortSignal) -> Result<()> { + let content = abortable_run_with_spinner(fetch(url), "Fetching models.yaml", abort_signal) + .await + .with_context(|| format!("Failed to fetch '{url}'"))?; + println!("✓ Fetched '{url}'"); + let list = serde_yaml::from_str::>(&content) + .with_context(|| "Failed to parse models.yaml")?; + let models_override = ModelsOverride { + version: env!("CARGO_PKG_VERSION").to_string(), + list, + }; + let models_override_data = + serde_yaml::to_string(&models_override).with_context(|| "Failed to serde {}")?; - let model_override_path = Self::models_override_file(); - ensure_parent_exists(&model_override_path)?; - std::fs::write(&model_override_path, models_override_data) - .with_context(|| format!("Failed to write to '{}'", model_override_path.display()))?; - println!("✓ Updated '{}'", model_override_path.display()); - Ok(()) - } + let model_override_path = Self::models_override_file(); + ensure_parent_exists(&model_override_path)?; + std::fs::write(&model_override_path, models_override_data) + .with_context(|| format!("Failed to write to '{}'", model_override_path.display()))?; + println!("✓ Updated '{}'", model_override_path.display()); + Ok(()) + } - pub fn local_models_override() -> Result> { - let model_override_path = Self::models_override_file(); - let err = || { - format!( - "Failed to load models at '{}'", - model_override_path.display() - ) - }; - let content = read_to_string(&model_override_path).with_context(err)?; - let models_override: ModelsOverride = serde_yaml::from_str(&content).with_context(err)?; - if models_override.version != env!("CARGO_PKG_VERSION") { - bail!("Incompatible version") - } - Ok(models_override.list) - } + pub fn local_models_override() -> Result> { + let model_override_path = Self::models_override_file(); + let err = || { + format!( + "Failed to load models at '{}'", + model_override_path.display() + ) + }; + let content = read_to_string(&model_override_path).with_context(err)?; + let models_override: ModelsOverride = serde_yaml::from_str(&content).with_context(err)?; + if models_override.version != env!("CARGO_PKG_VERSION") { + bail!("Incompatible version") + } + Ok(models_override.list) + } - pub fn light_theme(&self) -> bool { - matches!(self.theme.as_deref(), Some("light")) - } + pub fn light_theme(&self) -> bool { + matches!(self.theme.as_deref(), Some("light")) + } - pub fn render_options(&self) -> Result { - let theme = if self.highlight { - let theme_mode = if self.light_theme() { "light" } else { "dark" }; - let theme_filename = format!("{theme_mode}.tmTheme"); - let theme_path = Self::local_path(&theme_filename); - if theme_path.exists() { - let theme = ThemeSet::get_theme(&theme_path) - .with_context(|| format!("Invalid theme at '{}'", theme_path.display()))?; - Some(theme) - } else { - let theme = if self.light_theme() { - decode_bin(LIGHT_THEME).context("Invalid builtin light theme")? - } else { - decode_bin(DARK_THEME).context("Invalid builtin dark theme")? - }; - Some(theme) - } - } else { - None - }; - let wrap = if *IS_STDOUT_TERMINAL { - self.wrap.clone() - } else { - None - }; - let truecolor = matches!( + pub fn render_options(&self) -> Result { + let theme = if self.highlight { + let theme_mode = if self.light_theme() { "light" } else { "dark" }; + let theme_filename = format!("{theme_mode}.tmTheme"); + let theme_path = Self::local_path(&theme_filename); + if theme_path.exists() { + let theme = ThemeSet::get_theme(&theme_path) + .with_context(|| format!("Invalid theme at '{}'", theme_path.display()))?; + Some(theme) + } else { + let theme = if self.light_theme() { + decode_bin(LIGHT_THEME).context("Invalid builtin light theme")? + } else { + decode_bin(DARK_THEME).context("Invalid builtin dark theme")? + }; + Some(theme) + } + } else { + None + }; + let wrap = if *IS_STDOUT_TERMINAL { + self.wrap.clone() + } else { + None + }; + let truecolor = matches!( env::var("COLORTERM").as_ref().map(|v| v.as_str()), Ok("truecolor") ); - Ok(RenderOptions::new(theme, wrap, self.wrap_code, truecolor)) - } + Ok(RenderOptions::new(theme, wrap, self.wrap_code, truecolor)) + } - pub fn render_prompt_left(&self) -> String { - let variables = self.generate_prompt_context(); - let left_prompt = self.left_prompt.as_deref().unwrap_or(LEFT_PROMPT); - render_prompt(left_prompt, &variables) - } + pub fn render_prompt_left(&self) -> String { + let variables = self.generate_prompt_context(); + let left_prompt = self.left_prompt.as_deref().unwrap_or(LEFT_PROMPT); + render_prompt(left_prompt, &variables) + } - pub fn render_prompt_right(&self) -> String { - let variables = self.generate_prompt_context(); - let right_prompt = self.right_prompt.as_deref().unwrap_or(RIGHT_PROMPT); - render_prompt(right_prompt, &variables) - } + pub fn render_prompt_right(&self) -> String { + let variables = self.generate_prompt_context(); + let right_prompt = self.right_prompt.as_deref().unwrap_or(RIGHT_PROMPT); + render_prompt(right_prompt, &variables) + } - pub fn print_markdown(&self, text: &str) -> Result<()> { - if *IS_STDOUT_TERMINAL { - let render_options = self.render_options()?; - let mut markdown_render = MarkdownRender::init(render_options)?; - println!("{}", markdown_render.render(text)); - } else { - println!("{text}"); - } - Ok(()) - } + pub fn print_markdown(&self, text: &str) -> Result<()> { + if *IS_STDOUT_TERMINAL { + let render_options = self.render_options()?; + let mut markdown_render = MarkdownRender::init(render_options)?; + println!("{}", markdown_render.render(text)); + } else { + println!("{text}"); + } + Ok(()) + } - fn generate_prompt_context(&self) -> HashMap<&str, String> { - let mut output = HashMap::new(); - let role = self.extract_role(); - output.insert("model", role.model().id()); - output.insert("client_name", role.model().client_name().to_string()); - output.insert("model_name", role.model().name().to_string()); - output.insert( - "max_input_tokens", - role.model() - .max_input_tokens() - .unwrap_or_default() - .to_string(), - ); - if let Some(temperature) = role.temperature() { - if temperature != 0.0 { - output.insert("temperature", temperature.to_string()); - } - } - if let Some(top_p) = role.top_p() { - if top_p != 0.0 { - output.insert("top_p", top_p.to_string()); - } - } - if self.dry_run { - output.insert("dry_run", "true".to_string()); - } - if self.stream { - output.insert("stream", "true".to_string()); - } - if self.save { - output.insert("save", "true".to_string()); - } - if let Some(wrap) = &self.wrap { - if wrap != "no" { - output.insert("wrap", wrap.clone()); - } - } - if !role.is_derived() { - output.insert("role", role.name().to_string()); - } - if let Some(session) = &self.session { - output.insert("session", session.name().to_string()); - if let Some(autoname) = session.autoname() { - output.insert("session_autoname", autoname.to_string()); - } - output.insert("dirty", session.dirty().to_string()); - let (tokens, percent) = session.tokens_usage(); - output.insert("consume_tokens", tokens.to_string()); - output.insert("consume_percent", percent.to_string()); - output.insert("user_messages_len", session.user_messages_len().to_string()); - } - if let Some(rag) = &self.rag { - output.insert("rag", rag.name().to_string()); - } - if let Some(agent) = &self.agent { - output.insert("agent", agent.name().to_string()); - } + fn generate_prompt_context(&self) -> HashMap<&str, String> { + let mut output = HashMap::new(); + let role = self.extract_role(); + output.insert("model", role.model().id()); + output.insert("client_name", role.model().client_name().to_string()); + output.insert("model_name", role.model().name().to_string()); + output.insert( + "max_input_tokens", + role.model() + .max_input_tokens() + .unwrap_or_default() + .to_string(), + ); + if let Some(temperature) = role.temperature() { + if temperature != 0.0 { + output.insert("temperature", temperature.to_string()); + } + } + if let Some(top_p) = role.top_p() { + if top_p != 0.0 { + output.insert("top_p", top_p.to_string()); + } + } + if self.dry_run { + output.insert("dry_run", "true".to_string()); + } + if self.stream { + output.insert("stream", "true".to_string()); + } + if self.save { + output.insert("save", "true".to_string()); + } + if let Some(wrap) = &self.wrap { + if wrap != "no" { + output.insert("wrap", wrap.clone()); + } + } + if !role.is_derived() { + output.insert("role", role.name().to_string()); + } + if let Some(session) = &self.session { + output.insert("session", session.name().to_string()); + if let Some(autoname) = session.autoname() { + output.insert("session_autoname", autoname.to_string()); + } + output.insert("dirty", session.dirty().to_string()); + let (tokens, percent) = session.tokens_usage(); + output.insert("consume_tokens", tokens.to_string()); + output.insert("consume_percent", percent.to_string()); + output.insert("user_messages_len", session.user_messages_len().to_string()); + } + if let Some(rag) = &self.rag { + output.insert("rag", rag.name().to_string()); + } + if let Some(agent) = &self.agent { + output.insert("agent", agent.name().to_string()); + } - if self.highlight { - output.insert("color.reset", "\u{1b}[0m".to_string()); - output.insert("color.black", "\u{1b}[30m".to_string()); - output.insert("color.dark_gray", "\u{1b}[90m".to_string()); - output.insert("color.red", "\u{1b}[31m".to_string()); - output.insert("color.light_red", "\u{1b}[91m".to_string()); - output.insert("color.green", "\u{1b}[32m".to_string()); - output.insert("color.light_green", "\u{1b}[92m".to_string()); - output.insert("color.yellow", "\u{1b}[33m".to_string()); - output.insert("color.light_yellow", "\u{1b}[93m".to_string()); - output.insert("color.blue", "\u{1b}[34m".to_string()); - output.insert("color.light_blue", "\u{1b}[94m".to_string()); - output.insert("color.purple", "\u{1b}[35m".to_string()); - output.insert("color.light_purple", "\u{1b}[95m".to_string()); - output.insert("color.magenta", "\u{1b}[35m".to_string()); - output.insert("color.light_magenta", "\u{1b}[95m".to_string()); - output.insert("color.cyan", "\u{1b}[36m".to_string()); - output.insert("color.light_cyan", "\u{1b}[96m".to_string()); - output.insert("color.white", "\u{1b}[37m".to_string()); - output.insert("color.light_gray", "\u{1b}[97m".to_string()); - } + if self.highlight { + output.insert("color.reset", "\u{1b}[0m".to_string()); + output.insert("color.black", "\u{1b}[30m".to_string()); + output.insert("color.dark_gray", "\u{1b}[90m".to_string()); + output.insert("color.red", "\u{1b}[31m".to_string()); + output.insert("color.light_red", "\u{1b}[91m".to_string()); + output.insert("color.green", "\u{1b}[32m".to_string()); + output.insert("color.light_green", "\u{1b}[92m".to_string()); + output.insert("color.yellow", "\u{1b}[33m".to_string()); + output.insert("color.light_yellow", "\u{1b}[93m".to_string()); + output.insert("color.blue", "\u{1b}[34m".to_string()); + output.insert("color.light_blue", "\u{1b}[94m".to_string()); + output.insert("color.purple", "\u{1b}[35m".to_string()); + output.insert("color.light_purple", "\u{1b}[95m".to_string()); + output.insert("color.magenta", "\u{1b}[35m".to_string()); + output.insert("color.light_magenta", "\u{1b}[95m".to_string()); + output.insert("color.cyan", "\u{1b}[36m".to_string()); + output.insert("color.light_cyan", "\u{1b}[96m".to_string()); + output.insert("color.white", "\u{1b}[37m".to_string()); + output.insert("color.light_gray", "\u{1b}[97m".to_string()); + } - output - } + output + } - pub fn before_chat_completion(&mut self, input: &Input) -> Result<()> { - self.last_message = Some(LastMessage::new(input.clone(), String::new())); - Ok(()) - } + pub fn before_chat_completion(&mut self, input: &Input) -> Result<()> { + self.last_message = Some(LastMessage::new(input.clone(), String::new())); + Ok(()) + } - pub fn after_chat_completion( - &mut self, - input: &Input, - output: &str, - tool_results: &[ToolResult], - ) -> Result<()> { - if !tool_results.is_empty() { - return Ok(()); - } - self.last_message = Some(LastMessage::new(input.clone(), output.to_string())); - if !self.dry_run { - self.save_message(input, output)?; - } - Ok(()) - } + pub fn after_chat_completion( + &mut self, + input: &Input, + output: &str, + tool_results: &[ToolResult], + ) -> Result<()> { + if !tool_results.is_empty() { + return Ok(()); + } + self.last_message = Some(LastMessage::new(input.clone(), output.to_string())); + if !self.dry_run { + self.save_message(input, output)?; + } + Ok(()) + } - fn discontinuous_last_message(&mut self) { - if let Some(last_message) = self.last_message.as_mut() { - last_message.continuous = false; - } - } + fn discontinuous_last_message(&mut self) { + if let Some(last_message) = self.last_message.as_mut() { + last_message.continuous = false; + } + } - fn save_message(&mut self, input: &Input, output: &str) -> Result<()> { - let mut input = input.clone(); - input.clear_patch(); - if let Some(session) = input.session_mut(&mut self.session) { - session.add_message(&input, output)?; - return Ok(()); - } + fn save_message(&mut self, input: &Input, output: &str) -> Result<()> { + let mut input = input.clone(); + input.clear_patch(); + if let Some(session) = input.session_mut(&mut self.session) { + session.add_message(&input, output)?; + return Ok(()); + } - if !self.save { - return Ok(()); - } - let mut file = self.open_message_file()?; - if output.is_empty() && input.tool_calls().is_none() { - return Ok(()); - } - let now = now(); - let summary = input.summary(); - let raw_input = input.raw(); - let scope = if self.agent.is_none() { - let role_name = if input.role().is_derived() { - None - } else { - Some(input.role().name()) - }; - match (role_name, input.rag_name()) { - (Some(role), Some(rag_name)) => format!(" ({role}#{rag_name})"), - (Some(role), _) => format!(" ({role})"), - (None, Some(rag_name)) => format!(" (#{rag_name})"), - _ => String::new(), - } - } else { - String::new() - }; - let tool_calls = match input.tool_calls() { - Some(MessageContentToolCalls { - tool_results, text, .. - }) => { - let mut lines = vec!["".to_string()]; - if !text.is_empty() { - lines.push(text.clone()); - } - lines.push(serde_json::to_string(&tool_results).unwrap_or_default()); - lines.push("\n".to_string()); - lines.join("\n") - } - None => String::new(), - }; - let output = format!( + if !self.save { + return Ok(()); + } + let mut file = self.open_message_file()?; + if output.is_empty() && input.tool_calls().is_none() { + return Ok(()); + } + let now = now(); + let summary = input.summary(); + let raw_input = input.raw(); + let scope = if self.agent.is_none() { + let role_name = if input.role().is_derived() { + None + } else { + Some(input.role().name()) + }; + match (role_name, input.rag_name()) { + (Some(role), Some(rag_name)) => format!(" ({role}#{rag_name})"), + (Some(role), _) => format!(" ({role})"), + (None, Some(rag_name)) => format!(" (#{rag_name})"), + _ => String::new(), + } + } else { + String::new() + }; + let tool_calls = match input.tool_calls() { + Some(MessageContentToolCalls { + tool_results, text, .. + }) => { + let mut lines = vec!["".to_string()]; + if !text.is_empty() { + lines.push(text.clone()); + } + lines.push(serde_json::to_string(&tool_results).unwrap_or_default()); + lines.push("\n".to_string()); + lines.join("\n") + } + None => String::new(), + }; + let output = format!( "# CHAT: {summary} [{now}]{scope}\n{raw_input}\n--------\n{tool_calls}{output}\n--------\n\n", ); - file.write_all(output.as_bytes()) - .with_context(|| "Failed to save message") - } + file.write_all(output.as_bytes()) + .with_context(|| "Failed to save message") + } - fn init_agent_shared_variables(&mut self) -> Result<()> { - let agent = match self.agent.as_mut() { - Some(v) => v, - None => return Ok(()), - }; - if !agent.defined_variables().is_empty() && agent.shared_variables().is_empty() { - let new_variables = - Agent::init_agent_variables(agent.defined_variables(), self.info_flag)?; - agent.set_shared_variables(new_variables); - } - if !self.info_flag { - agent.update_shared_dynamic_instructions(false)?; - } - Ok(()) - } + fn init_agent_shared_variables(&mut self) -> Result<()> { + let agent = match self.agent.as_mut() { + Some(v) => v, + None => return Ok(()), + }; + if !agent.defined_variables().is_empty() && agent.shared_variables().is_empty() { + let new_variables = + Agent::init_agent_variables(agent.defined_variables(), self.info_flag)?; + agent.set_shared_variables(new_variables); + } + if !self.info_flag { + agent.update_shared_dynamic_instructions(false)?; + } + Ok(()) + } - fn init_agent_session_variables(&mut self, new_session: bool) -> Result<()> { - let (agent, session) = match (self.agent.as_mut(), self.session.as_mut()) { - (Some(agent), Some(session)) => (agent, session), - _ => return Ok(()), - }; - if new_session { - let shared_variables = agent.shared_variables().clone(); - let session_variables = - if !agent.defined_variables().is_empty() && shared_variables.is_empty() { - let new_variables = - Agent::init_agent_variables(agent.defined_variables(), self.info_flag)?; - agent.set_shared_variables(new_variables.clone()); - new_variables - } else { - shared_variables - }; - agent.set_session_variables(session_variables); - if !self.info_flag { - agent.update_session_dynamic_instructions(None)?; - } - session.sync_agent(agent); - } else { - let variables = session.agent_variables(); - agent.set_session_variables(variables.clone()); - agent.update_session_dynamic_instructions(Some( - session.agent_instructions().to_string(), - ))?; - } - Ok(()) - } + fn init_agent_session_variables(&mut self, new_session: bool) -> Result<()> { + let (agent, session) = match (self.agent.as_mut(), self.session.as_mut()) { + (Some(agent), Some(session)) => (agent, session), + _ => return Ok(()), + }; + if new_session { + let shared_variables = agent.shared_variables().clone(); + let session_variables = + if !agent.defined_variables().is_empty() && shared_variables.is_empty() { + let new_variables = + Agent::init_agent_variables(agent.defined_variables(), self.info_flag)?; + agent.set_shared_variables(new_variables.clone()); + new_variables + } else { + shared_variables + }; + agent.set_session_variables(session_variables); + if !self.info_flag { + agent.update_session_dynamic_instructions(None)?; + } + session.sync_agent(agent); + } else { + let variables = session.agent_variables(); + agent.set_session_variables(variables.clone()); + agent.update_session_dynamic_instructions(Some( + session.agent_instructions().to_string(), + ))?; + } + Ok(()) + } - fn open_message_file(&self) -> Result { - let path = self.messages_file(); - ensure_parent_exists(&path)?; - OpenOptions::new() - .create(true) - .append(true) - .open(&path) - .with_context(|| format!("Failed to create/append {}", path.display())) - } + fn open_message_file(&self) -> Result { + let path = self.messages_file(); + ensure_parent_exists(&path)?; + OpenOptions::new() + .create(true) + .append(true) + .open(&path) + .with_context(|| format!("Failed to create/append {}", path.display())) + } - fn load_from_file(config_path: &Path) -> Result<(Self, String)> { - let err = || format!("Failed to load config at '{}'", config_path.display()); - let content = read_to_string(config_path).with_context(err)?; - let config = Self::load_from_str(&content).with_context(err)?; + fn load_from_file(config_path: &Path) -> Result<(Self, String)> { + let err = || format!("Failed to load config at '{}'", config_path.display()); + let content = read_to_string(config_path).with_context(err)?; + let config = Self::load_from_str(&content).with_context(err)?; - Ok((config, content)) - } + Ok((config, content)) + } - fn load_from_str(content: &str) -> Result { - if PASSWORD_FILE_SECRET_RE.is_match(content)? { - bail!("secret injection cannot be done on the vault_password_file property"); - } + fn load_from_str(content: &str) -> Result { + if PASSWORD_FILE_SECRET_RE.is_match(content)? { + bail!("secret injection cannot be done on the vault_password_file property"); + } - let config: Self = serde_yaml::from_str(content) - .map_err(|err| { - let err_msg = err.to_string(); - let err_msg = if err_msg.starts_with(&format!("{CLIENTS_FIELD}: ")) { - // location is incorrect, get rid of it - err_msg - .split_once(" at line") - .map(|(v, _)| { - format!("{v} (Sorry for being unable to provide an exact location)") - }) - .unwrap_or_else(|| "clients: invalid value".into()) - } else { - err_msg - }; - anyhow!("{err_msg}") - }) - .with_context(|| "Failed to load config from str")?; + let config: Self = serde_yaml::from_str(content) + .map_err(|err| { + let err_msg = err.to_string(); + let err_msg = if err_msg.starts_with(&format!("{CLIENTS_FIELD}: ")) { + // location is incorrect, get rid of it + err_msg + .split_once(" at line") + .map(|(v, _)| { + format!("{v} (Sorry for being unable to provide an exact location)") + }) + .unwrap_or_else(|| "clients: invalid value".into()) + } else { + err_msg + }; + anyhow!("{err_msg}") + }) + .with_context(|| "Failed to load config from str")?; - Ok(config) - } + Ok(config) + } - fn load_dynamic(model_id: &str) -> Result { - let provider = match model_id.split_once(':') { - Some((v, _)) => v, - _ => model_id, - }; - let is_openai_compatible = OPENAI_COMPATIBLE_PROVIDERS - .into_iter() - .any(|(name, _)| provider == name); - let client = if is_openai_compatible { - json!({ "type": "openai-compatible", "name": provider }) - } else { - json!({ "type": provider }) - }; - let config = json!({ + fn load_dynamic(model_id: &str) -> Result { + let provider = match model_id.split_once(':') { + Some((v, _)) => v, + _ => model_id, + }; + let is_openai_compatible = OPENAI_COMPATIBLE_PROVIDERS + .into_iter() + .any(|(name, _)| provider == name); + let client = if is_openai_compatible { + json!({ "type": "openai-compatible", "name": provider }) + } else { + json!({ "type": provider }) + }; + let config = json!({ "model": model_id.to_string(), "save": false, "clients": vec![client], }); - let config = - serde_json::from_value(config).with_context(|| "Failed to load config from env")?; - Ok(config) - } + let config = + serde_json::from_value(config).with_context(|| "Failed to load config from env")?; + Ok(config) + } - fn load_envs(&mut self) { - if let Ok(v) = env::var(get_env_name("model")) { - self.model_id = v; - } - if let Some(v) = read_env_value::(&get_env_name("temperature")) { - self.temperature = v; - } - if let Some(v) = read_env_value::(&get_env_name("top_p")) { - self.top_p = v; - } + fn load_envs(&mut self) { + if let Ok(v) = env::var(get_env_name("model")) { + self.model_id = v; + } + if let Some(v) = read_env_value::(&get_env_name("temperature")) { + self.temperature = v; + } + if let Some(v) = read_env_value::(&get_env_name("top_p")) { + self.top_p = v; + } - if let Some(Some(v)) = read_env_bool(&get_env_name("dry_run")) { - self.dry_run = v; - } - if let Some(Some(v)) = read_env_bool(&get_env_name("stream")) { - self.stream = v; - } - if let Some(Some(v)) = read_env_bool(&get_env_name("save")) { - self.save = v; - } - if let Ok(v) = env::var(get_env_name("keybindings")) { - if v == "vi" { - self.keybindings = v; - } - } - if let Some(v) = read_env_value::(&get_env_name("editor")) { - self.editor = v; - } - if let Some(v) = read_env_value::(&get_env_name("wrap")) { - self.wrap = v; - } - if let Some(Some(v)) = read_env_bool(&get_env_name("wrap_code")) { - self.wrap_code = v; - } + if let Some(Some(v)) = read_env_bool(&get_env_name("dry_run")) { + self.dry_run = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("stream")) { + self.stream = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("save")) { + self.save = v; + } + if let Ok(v) = env::var(get_env_name("keybindings")) { + if v == "vi" { + self.keybindings = v; + } + } + if let Some(v) = read_env_value::(&get_env_name("editor")) { + self.editor = v; + } + if let Some(v) = read_env_value::(&get_env_name("wrap")) { + self.wrap = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("wrap_code")) { + self.wrap_code = v; + } - if let Some(Some(v)) = read_env_bool(&get_env_name("function_calling")) { - self.function_calling = v; - } - if let Ok(v) = env::var(get_env_name("mapping_tools")) { - if let Ok(v) = serde_json::from_str(&v) { - self.mapping_tools = v; - } - } - if let Some(v) = read_env_value::(&get_env_name("use_tools")) { - self.use_tools = v; - } + if let Some(Some(v)) = read_env_bool(&get_env_name("function_calling")) { + self.function_calling = v; + } + if let Ok(v) = env::var(get_env_name("mapping_tools")) { + if let Ok(v) = serde_json::from_str(&v) { + self.mapping_tools = v; + } + } + if let Some(v) = read_env_value::(&get_env_name("use_tools")) { + self.use_tools = v; + } - if let Some(v) = read_env_value::(&get_env_name("repl_prelude")) { - self.repl_prelude = v; - } - if let Some(v) = read_env_value::(&get_env_name("cmd_prelude")) { - self.cmd_prelude = v; - } - if let Some(v) = read_env_value::(&get_env_name("agent_session")) { - self.agent_session = v; - } + if let Some(v) = read_env_value::(&get_env_name("repl_prelude")) { + self.repl_prelude = v; + } + if let Some(v) = read_env_value::(&get_env_name("cmd_prelude")) { + self.cmd_prelude = v; + } + if let Some(v) = read_env_value::(&get_env_name("agent_session")) { + self.agent_session = v; + } - if let Some(v) = read_env_bool(&get_env_name("save_session")) { - self.save_session = v; - } - if let Some(Some(v)) = read_env_value::(&get_env_name("compress_threshold")) { - self.compress_threshold = v; - } - if let Some(v) = read_env_value::(&get_env_name("summarize_prompt")) { - self.summarize_prompt = v; - } - if let Some(v) = read_env_value::(&get_env_name("summary_prompt")) { - self.summary_prompt = v; - } + if let Some(v) = read_env_bool(&get_env_name("save_session")) { + self.save_session = v; + } + if let Some(Some(v)) = read_env_value::(&get_env_name("compress_threshold")) { + self.compress_threshold = v; + } + if let Some(v) = read_env_value::(&get_env_name("summarize_prompt")) { + self.summarize_prompt = v; + } + if let Some(v) = read_env_value::(&get_env_name("summary_prompt")) { + self.summary_prompt = v; + } - if let Some(v) = read_env_value::(&get_env_name("rag_embedding_model")) { - self.rag_embedding_model = v; - } - if let Some(v) = read_env_value::(&get_env_name("rag_reranker_model")) { - self.rag_reranker_model = v; - } - if let Some(Some(v)) = read_env_value::(&get_env_name("rag_top_k")) { - self.rag_top_k = v; - } - if let Some(v) = read_env_value::(&get_env_name("rag_chunk_size")) { - self.rag_chunk_size = v; - } - if let Some(v) = read_env_value::(&get_env_name("rag_chunk_overlap")) { - self.rag_chunk_overlap = v; - } - if let Some(v) = read_env_value::(&get_env_name("rag_template")) { - self.rag_template = v; - } + if let Some(v) = read_env_value::(&get_env_name("rag_embedding_model")) { + self.rag_embedding_model = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_reranker_model")) { + self.rag_reranker_model = v; + } + if let Some(Some(v)) = read_env_value::(&get_env_name("rag_top_k")) { + self.rag_top_k = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_chunk_size")) { + self.rag_chunk_size = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_chunk_overlap")) { + self.rag_chunk_overlap = v; + } + if let Some(v) = read_env_value::(&get_env_name("rag_template")) { + self.rag_template = v; + } - if let Ok(v) = env::var(get_env_name("document_loaders")) { - if let Ok(v) = serde_json::from_str(&v) { - self.document_loaders = v; - } - } + if let Ok(v) = env::var(get_env_name("document_loaders")) { + if let Ok(v) = serde_json::from_str(&v) { + self.document_loaders = v; + } + } - if let Some(Some(v)) = read_env_bool(&get_env_name("highlight")) { - self.highlight = v; - } - if *NO_COLOR { - self.highlight = false; - } - if self.highlight && self.theme.is_none() { - if let Some(v) = read_env_value::(&get_env_name("theme")) { - self.theme = v; - } else if *IS_STDOUT_TERMINAL { - if let Ok(color_scheme) = color_scheme(QueryOptions::default()) { - let theme = match color_scheme { - ColorScheme::Dark => "dark", - ColorScheme::Light => "light", - }; - self.theme = Some(theme.into()); - } - } - } - if let Some(v) = read_env_value::(&get_env_name("left_prompt")) { - self.left_prompt = v; - } - if let Some(v) = read_env_value::(&get_env_name("right_prompt")) { - self.right_prompt = v; - } + if let Some(Some(v)) = read_env_bool(&get_env_name("highlight")) { + self.highlight = v; + } + if *NO_COLOR { + self.highlight = false; + } + if self.highlight && self.theme.is_none() { + if let Some(v) = read_env_value::(&get_env_name("theme")) { + self.theme = v; + } else if *IS_STDOUT_TERMINAL { + if let Ok(color_scheme) = color_scheme(QueryOptions::default()) { + let theme = match color_scheme { + ColorScheme::Dark => "dark", + ColorScheme::Light => "light", + }; + self.theme = Some(theme.into()); + } + } + } + if let Some(v) = read_env_value::(&get_env_name("left_prompt")) { + self.left_prompt = v; + } + if let Some(v) = read_env_value::(&get_env_name("right_prompt")) { + self.right_prompt = v; + } - if let Some(v) = read_env_value::(&get_env_name("serve_addr")) { - self.serve_addr = v; - } - if let Some(v) = read_env_value::(&get_env_name("user_agent")) { - self.user_agent = v; - } - if let Some(Some(v)) = read_env_bool(&get_env_name("save_shell_history")) { - self.save_shell_history = v; - } - if let Some(v) = read_env_value::(&get_env_name("sync_models_url")) { - self.sync_models_url = v; - } - } + if let Some(v) = read_env_value::(&get_env_name("serve_addr")) { + self.serve_addr = v; + } + if let Some(v) = read_env_value::(&get_env_name("user_agent")) { + self.user_agent = v; + } + if let Some(Some(v)) = read_env_bool(&get_env_name("save_shell_history")) { + self.save_shell_history = v; + } + if let Some(v) = read_env_value::(&get_env_name("sync_models_url")) { + self.sync_models_url = v; + } + } - fn load_functions(&mut self) -> Result<()> { - self.functions = Functions::init()?; - Ok(()) - } + fn load_functions(&mut self) -> Result<()> { + self.functions = Functions::init()?; + Ok(()) + } - async fn load_mcp_servers( - &mut self, - log_path: Option, - start_mcp_servers: bool, - abort_signal: AbortSignal, - ) -> Result<()> { - let mcp_registry = McpRegistry::init( - log_path, - start_mcp_servers, - self.use_mcp_servers.clone(), - abort_signal.clone(), - self, - ) - .await?; - match mcp_registry.is_empty() { - false => { - if self.mcp_servers { - self.functions - .append_mcp_meta_functions(mcp_registry.list_started_servers()); - } else { - debug!( + async fn load_mcp_servers( + &mut self, + log_path: Option, + start_mcp_servers: bool, + abort_signal: AbortSignal, + ) -> Result<()> { + let mcp_registry = McpRegistry::init( + log_path, + start_mcp_servers, + self.use_mcp_servers.clone(), + abort_signal.clone(), + self, + ) + .await?; + match mcp_registry.is_empty() { + false => { + if self.mcp_servers { + self.functions + .append_mcp_meta_functions(mcp_registry.list_started_servers()); + } else { + debug!( "Skipping global MCP functions registration since mcp_servers was 'false'" ); - } - } - _ => debug!( + } + } + _ => debug!( "Skipping global MCP functions registration since start_mcp_servers was 'false'" ), - } - self.mcp_registry = Some(mcp_registry); + } + self.mcp_registry = Some(mcp_registry); - Ok(()) - } + Ok(()) + } - fn setup_model(&mut self) -> Result<()> { - let mut model_id = self.model_id.clone(); - if model_id.is_empty() { - let models = list_models(self, ModelType::Chat); - if models.is_empty() { - bail!("No available model"); - } - model_id = models[0].id() - } - self.set_model(&model_id)?; - self.model_id = model_id; + fn setup_model(&mut self) -> Result<()> { + let mut model_id = self.model_id.clone(); + if model_id.is_empty() { + let models = list_models(self, ModelType::Chat); + if models.is_empty() { + bail!("No available model"); + } + model_id = models[0].id() + } + self.set_model(&model_id)?; + self.model_id = model_id; - Ok(()) - } + Ok(()) + } - fn setup_document_loaders(&mut self) { - [("pdf", "pdftotext $1 -"), ("docx", "pandoc --to plain $1")] - .into_iter() - .for_each(|(k, v)| { - let (k, v) = (k.to_string(), v.to_string()); - self.document_loaders.entry(k).or_insert(v); - }); - } + fn setup_document_loaders(&mut self) { + [("pdf", "pdftotext $1 -"), ("docx", "pandoc --to plain $1")] + .into_iter() + .for_each(|(k, v)| { + let (k, v) = (k.to_string(), v.to_string()); + self.document_loaders.entry(k).or_insert(v); + }); + } - fn setup_user_agent(&mut self) { - if let Some("auto") = self.user_agent.as_deref() { - self.user_agent = Some(format!( - "{}/{}", - env!("CARGO_CRATE_NAME"), - env!("CARGO_PKG_VERSION") - )); - } - } + fn setup_user_agent(&mut self) { + if let Some("auto") = self.user_agent.as_deref() { + self.user_agent = Some(format!( + "{}/{}", + env!("CARGO_CRATE_NAME"), + env!("CARGO_PKG_VERSION") + )); + } + } } pub fn load_env_file() -> Result<()> { - let env_file_path = Config::env_file(); - let contents = match read_to_string(&env_file_path) { - Ok(v) => v, - Err(_) => return Ok(()), - }; - debug!("Use env file '{}'", env_file_path.display()); - for line in contents.lines() { - let line = line.trim(); - if line.starts_with('#') || line.is_empty() { - continue; - } - if let Some((key, value)) = line.split_once('=') { - env::set_var(key.trim(), value.trim()); - } - } - Ok(()) + let env_file_path = Config::env_file(); + let contents = match read_to_string(&env_file_path) { + Ok(v) => v, + Err(_) => return Ok(()), + }; + debug!("Use env file '{}'", env_file_path.display()); + for line in contents.lines() { + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { + continue; + } + if let Some((key, value)) = line.split_once('=') { + env::set_var(key.trim(), value.trim()); + } + } + Ok(()) } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum WorkingMode { - Cmd, - Repl, - Serve, + Cmd, + Repl, + Serve, } impl WorkingMode { - pub fn is_cmd(&self) -> bool { - *self == WorkingMode::Cmd - } - pub fn is_repl(&self) -> bool { - *self == WorkingMode::Repl - } - pub fn is_serve(&self) -> bool { - *self == WorkingMode::Serve - } + pub fn is_cmd(&self) -> bool { + *self == WorkingMode::Cmd + } + pub fn is_repl(&self) -> bool { + *self == WorkingMode::Repl + } + pub fn is_serve(&self) -> bool { + *self == WorkingMode::Serve + } } #[async_recursion::async_recursion] pub async fn macro_execute( - config: &GlobalConfig, - name: &str, - args: Option<&str>, - abort_signal: AbortSignal, + config: &GlobalConfig, + name: &str, + args: Option<&str>, + abort_signal: AbortSignal, ) -> Result<()> { - let macro_value = Config::load_macro(name)?; - let (mut new_args, text) = split_args_text(args.unwrap_or_default(), cfg!(windows)); - if !text.is_empty() { - new_args.push(text.to_string()); - } - let variables = macro_value - .resolve_variables(&new_args) - .map_err(|err| anyhow!("{err}. Usage: {}", macro_value.usage(name)))?; - let role = config.read().extract_role(); - let mut config = config.read().clone(); - config.temperature = role.temperature(); - config.top_p = role.top_p(); - config.use_tools = role.use_tools().clone(); - config.use_mcp_servers = role.use_mcp_servers().clone(); - config.macro_flag = true; - config.model = role.model().clone(); - config.role = None; - config.session = None; - config.rag = None; - config.agent = None; - config.discontinuous_last_message(); - let config = Arc::new(RwLock::new(config)); - config.write().macro_flag = true; - for step in ¯o_value.steps { - let command = Macro::interpolate_command(step, &variables); - println!(">> {}", multiline_text(&command)); - run_repl_command(&config, abort_signal.clone(), &command).await?; - } - Ok(()) + let macro_value = Config::load_macro(name)?; + let (mut new_args, text) = split_args_text(args.unwrap_or_default(), cfg!(windows)); + if !text.is_empty() { + new_args.push(text.to_string()); + } + let variables = macro_value + .resolve_variables(&new_args) + .map_err(|err| anyhow!("{err}. Usage: {}", macro_value.usage(name)))?; + let role = config.read().extract_role(); + let mut config = config.read().clone(); + config.temperature = role.temperature(); + config.top_p = role.top_p(); + config.use_tools = role.use_tools().clone(); + config.use_mcp_servers = role.use_mcp_servers().clone(); + config.macro_flag = true; + config.model = role.model().clone(); + config.role = None; + config.session = None; + config.rag = None; + config.agent = None; + config.discontinuous_last_message(); + let config = Arc::new(RwLock::new(config)); + config.write().macro_flag = true; + for step in ¯o_value.steps { + let command = Macro::interpolate_command(step, &variables); + println!(">> {}", multiline_text(&command)); + run_repl_command(&config, abort_signal.clone(), &command).await?; + } + Ok(()) } #[derive(Debug, Clone, Deserialize)] pub struct Macro { - #[serde(default)] - pub variables: Vec, - pub steps: Vec, + #[serde(default)] + pub variables: Vec, + pub steps: Vec, } impl Macro { - pub fn resolve_variables(&self, args: &[String]) -> Result> { - let mut output = IndexMap::new(); - for (i, variable) in self.variables.iter().enumerate() { - let value = if variable.rest && i == self.variables.len() - 1 { - if args.len() > i { - Some(args[i..].join(" ")) - } else { - variable.default.clone() - } - } else { - args.get(i) - .map(|v| v.to_string()) - .or_else(|| variable.default.clone()) - }; - let value = - value.ok_or_else(|| anyhow!("Missing value for variable '{}'", variable.name))?; - output.insert(variable.name.clone(), value); - } - Ok(output) - } + pub fn resolve_variables(&self, args: &[String]) -> Result> { + let mut output = IndexMap::new(); + for (i, variable) in self.variables.iter().enumerate() { + let value = if variable.rest && i == self.variables.len() - 1 { + if args.len() > i { + Some(args[i..].join(" ")) + } else { + variable.default.clone() + } + } else { + args.get(i) + .map(|v| v.to_string()) + .or_else(|| variable.default.clone()) + }; + let value = + value.ok_or_else(|| anyhow!("Missing value for variable '{}'", variable.name))?; + output.insert(variable.name.clone(), value); + } + Ok(output) + } - pub fn usage(&self, name: &str) -> String { - let mut parts = vec![name.to_string()]; - for (i, variable) in self.variables.iter().enumerate() { - let part = match ( - variable.rest && i == self.variables.len() - 1, - variable.default.is_some(), - ) { - (true, true) => format!("[{}]...", variable.name), - (true, false) => format!("<{}>...", variable.name), - (false, true) => format!("[{}]", variable.name), - (false, false) => format!("<{}>", variable.name), - }; - parts.push(part); - } - parts.join(" ") - } + pub fn usage(&self, name: &str) -> String { + let mut parts = vec![name.to_string()]; + for (i, variable) in self.variables.iter().enumerate() { + let part = match ( + variable.rest && i == self.variables.len() - 1, + variable.default.is_some(), + ) { + (true, true) => format!("[{}]...", variable.name), + (true, false) => format!("<{}>...", variable.name), + (false, true) => format!("[{}]", variable.name), + (false, false) => format!("<{}>", variable.name), + }; + parts.push(part); + } + parts.join(" ") + } - pub fn interpolate_command(command: &str, variables: &IndexMap) -> String { - let mut output = command.to_string(); - for (key, value) in variables { - output = output.replace(&format!("{{{{{key}}}}}"), value); - } - output - } + pub fn interpolate_command(command: &str, variables: &IndexMap) -> String { + let mut output = command.to_string(); + for (key, value) in variables { + output = output.replace(&format!("{{{{{key}}}}}"), value); + } + output + } } #[derive(Debug, Clone, Deserialize)] pub struct MacroVariable { - pub name: String, - #[serde(default)] - pub rest: bool, - pub default: Option, + pub name: String, + #[serde(default)] + pub rest: bool, + pub default: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelsOverride { - pub version: String, - pub list: Vec, + pub version: String, + pub list: Vec, } #[derive(Debug, Clone)] pub struct LastMessage { - pub input: Input, - pub output: String, - pub continuous: bool, + pub input: Input, + pub output: String, + pub continuous: bool, } impl LastMessage { - pub fn new(input: Input, output: String) -> Self { - Self { - input, - output, - continuous: true, - } - } + pub fn new(input: Input, output: String) -> Self { + Self { + input, + output, + continuous: true, + } + } } bitflags::bitflags! { @@ -3077,152 +3077,152 @@ bitflags::bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum AssertState { - True(StateFlags), - False(StateFlags), - TrueFalse(StateFlags, StateFlags), - Equal(StateFlags), + True(StateFlags), + False(StateFlags), + TrueFalse(StateFlags, StateFlags), + Equal(StateFlags), } impl AssertState { - pub fn pass() -> Self { - AssertState::False(StateFlags::empty()) - } + pub fn pass() -> Self { + AssertState::False(StateFlags::empty()) + } - pub fn bare() -> Self { - AssertState::Equal(StateFlags::empty()) - } + pub fn bare() -> Self { + AssertState::Equal(StateFlags::empty()) + } - pub fn assert(self, flags: StateFlags) -> bool { - match self { - AssertState::True(true_flags) => true_flags & flags != StateFlags::empty(), - AssertState::False(false_flags) => false_flags & flags == StateFlags::empty(), - AssertState::TrueFalse(true_flags, false_flags) => { - (true_flags & flags != StateFlags::empty()) - && (false_flags & flags == StateFlags::empty()) - } - AssertState::Equal(check_flags) => check_flags == flags, - } - } + pub fn assert(self, flags: StateFlags) -> bool { + match self { + AssertState::True(true_flags) => true_flags & flags != StateFlags::empty(), + AssertState::False(false_flags) => false_flags & flags == StateFlags::empty(), + AssertState::TrueFalse(true_flags, false_flags) => { + (true_flags & flags != StateFlags::empty()) + && (false_flags & flags == StateFlags::empty()) + } + AssertState::Equal(check_flags) => check_flags == flags, + } + } } async fn create_config_file(config_path: &Path) -> Result<()> { - let ans = Confirm::new("No config file, create a new one?") - .with_default(true) - .prompt()?; - if !ans { - process::exit(0); - } + let ans = Confirm::new("No config file, create a new one?") + .with_default(true) + .prompt()?; + if !ans { + process::exit(0); + } - let client = Select::new("API Provider (required):", list_client_types()).prompt()?; + let client = Select::new("API Provider (required):", list_client_types()).prompt()?; - let mut config = json!({}); - let (model, clients_config) = create_client_config(client).await?; - config["model"] = model.into(); - config[CLIENTS_FIELD] = clients_config; + let mut config = json!({}); + let (model, clients_config) = create_client_config(client).await?; + config["model"] = model.into(); + config[CLIENTS_FIELD] = clients_config; - let config_data = serde_yaml::to_string(&config).with_context(|| "Failed to create config")?; - let config_data = format!( - "# see https://github.com/Dark-Alex-17/loki/blob/main/config.example.yaml\n\n{config_data}" - ); + let config_data = serde_yaml::to_string(&config).with_context(|| "Failed to create config")?; + let config_data = format!( + "# see https://github.com/Dark-Alex-17/loki/blob/main/config.example.yaml\n\n{config_data}" + ); - ensure_parent_exists(config_path)?; - std::fs::write(config_path, config_data) - .with_context(|| format!("Failed to write to '{}'", config_path.display()))?; - #[cfg(unix)] - { - use std::os::unix::prelude::PermissionsExt; - let perms = std::fs::Permissions::from_mode(0o600); - std::fs::set_permissions(config_path, perms)?; - } + ensure_parent_exists(config_path)?; + std::fs::write(config_path, config_data) + .with_context(|| format!("Failed to write to '{}'", config_path.display()))?; + #[cfg(unix)] + { + use std::os::unix::prelude::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + std::fs::set_permissions(config_path, perms)?; + } - println!("✓ Saved the config file to '{}'.\n", config_path.display()); + println!("✓ Saved the config file to '{}'.\n", config_path.display()); - Ok(()) + Ok(()) } pub(crate) fn ensure_parent_exists(path: &Path) -> Result<()> { - if path.exists() { - return Ok(()); - } - let parent = path - .parent() - .ok_or_else(|| anyhow!("Failed to write to '{}', No parent path", path.display()))?; - if !parent.exists() { - create_dir_all(parent).with_context(|| { - format!( - "Failed to write to '{}', Cannot create parent directory", - path.display() - ) - })?; - } - Ok(()) + if path.exists() { + return Ok(()); + } + let parent = path + .parent() + .ok_or_else(|| anyhow!("Failed to write to '{}', No parent path", path.display()))?; + if !parent.exists() { + create_dir_all(parent).with_context(|| { + format!( + "Failed to write to '{}', Cannot create parent directory", + path.display() + ) + })?; + } + Ok(()) } fn read_env_value(key: &str) -> Option> where - T: std::str::FromStr, + T: std::str::FromStr, { - let value = env::var(key).ok()?; - let value = parse_value(&value).ok()?; - Some(value) + let value = env::var(key).ok()?; + let value = parse_value(&value).ok()?; + Some(value) } fn parse_value(value: &str) -> Result> where - T: std::str::FromStr, + T: std::str::FromStr, { - let value = if value == "null" { - None - } else { - let value = match value.parse() { - Ok(value) => value, - Err(_) => bail!("Invalid value '{}'", value), - }; - Some(value) - }; - Ok(value) + let value = if value == "null" { + None + } else { + let value = match value.parse() { + Ok(value) => value, + Err(_) => bail!("Invalid value '{}'", value), + }; + Some(value) + }; + Ok(value) } fn read_env_bool(key: &str) -> Option> { - let value = env::var(key).ok()?; - Some(parse_bool(&value)) + let value = env::var(key).ok()?; + Some(parse_bool(&value)) } fn complete_bool(value: bool) -> Vec { - vec![(!value).to_string()] + vec![(!value).to_string()] } fn complete_option_bool(value: Option) -> Vec { - match value { - Some(true) => vec!["false".to_string(), "null".to_string()], - Some(false) => vec!["true".to_string(), "null".to_string()], - None => vec!["true".to_string(), "false".to_string()], - } + match value { + Some(true) => vec!["false".to_string(), "null".to_string()], + Some(false) => vec!["true".to_string(), "null".to_string()], + None => vec!["true".to_string(), "false".to_string()], + } } fn map_completion_values(value: Vec) -> Vec<(String, Option)> { - value.into_iter().map(|v| (v.to_string(), None)).collect() + value.into_iter().map(|v| (v.to_string(), None)).collect() } fn update_rag(config: &GlobalConfig, f: F) -> Result<()> where - F: FnOnce(&mut Rag) -> Result<()>, + F: FnOnce(&mut Rag) -> Result<()>, { - let mut rag = match config.read().rag.clone() { - Some(v) => v.as_ref().clone(), - None => bail!("No RAG"), - }; - f(&mut rag)?; - config.write().rag = Some(Arc::new(rag)); - Ok(()) + let mut rag = match config.read().rag.clone() { + Some(v) => v.as_ref().clone(), + None => bail!("No RAG"), + }; + f(&mut rag)?; + config.write().rag = Some(Arc::new(rag)); + Ok(()) } fn format_option_value(value: &Option) -> String where - T: std::fmt::Display, + T: std::fmt::Display, { - match value { - Some(value) => value.to_string(), - None => "null".to_string(), - } + match value { + Some(value) => value.to_string(), + None => "null".to_string(), + } } diff --git a/src/function.rs b/src/function.rs index a2d71a3..712dc88 100644 --- a/src/function.rs +++ b/src/function.rs @@ -141,7 +141,7 @@ impl Functions { .extension() .and_then(OsStr::to_str) .map(|s| s.to_lowercase()); - #[cfg_attr(not(unix), expect(dead_code))] + #[cfg_attr(not(unix), expect(dead_code))] let is_script = matches!(file_extension.as_deref(), Some("sh") | Some("py")); if file_path.exists() { diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 05a79d4..30b8521 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -26,293 +26,293 @@ type ConnectedServer = RunningService; #[derive(Debug, Clone, Deserialize)] struct McpServersConfig { - #[serde(rename = "mcpServers")] - mcp_servers: HashMap, + #[serde(rename = "mcpServers")] + mcp_servers: HashMap, } #[derive(Debug, Clone, Deserialize)] struct McpServer { - command: String, - args: Option>, - env: Option>, - cwd: Option, + command: String, + args: Option>, + env: Option>, + cwd: Option, } #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] enum JsonField { - Str(String), - Bool(bool), - Int(i64), + Str(String), + Bool(bool), + Int(i64), } #[derive(Debug, Clone, Default)] pub struct McpRegistry { - log_path: Option, - config: Option, - servers: HashMap>>, + log_path: Option, + config: Option, + servers: HashMap>>, } impl McpRegistry { - pub async fn init( - log_path: Option, - start_mcp_servers: bool, - use_mcp_servers: Option, - abort_signal: AbortSignal, - config: &Config, - ) -> Result { - let mut registry = Self { - log_path, - ..Default::default() - }; - if !Config::mcp_config_file().try_exists().with_context(|| { - format!( - "Failed to check MCP config file at {}", - Config::mcp_config_file().display() - ) - })? { - debug!( + pub async fn init( + log_path: Option, + start_mcp_servers: bool, + use_mcp_servers: Option, + abort_signal: AbortSignal, + config: &Config, + ) -> Result { + let mut registry = Self { + log_path, + ..Default::default() + }; + if !Config::mcp_config_file().try_exists().with_context(|| { + format!( + "Failed to check MCP config file at {}", + Config::mcp_config_file().display() + ) + })? { + debug!( "MCP config file does not exist at {}, skipping MCP initialization", Config::mcp_config_file().display() ); - return Ok(registry); - } - let err = || { - format!( - "Failed to load MCP config file at {}", - Config::mcp_config_file().display() - ) - }; - let content = tokio::fs::read_to_string(Config::mcp_config_file()) - .await - .with_context(err)?; + return Ok(registry); + } + let err = || { + format!( + "Failed to load MCP config file at {}", + Config::mcp_config_file().display() + ) + }; + let content = tokio::fs::read_to_string(Config::mcp_config_file()) + .await + .with_context(err)?; - if content.trim().is_empty() { - debug!("MCP config file is empty, skipping MCP initialization"); - return Ok(registry); - } + if content.trim().is_empty() { + debug!("MCP config file is empty, skipping MCP initialization"); + return Ok(registry); + } - let (parsed_content, missing_secrets) = interpolate_secrets(&content, &config.vault); + let (parsed_content, missing_secrets) = interpolate_secrets(&content, &config.vault); - if !missing_secrets.is_empty() { - return Err(anyhow!(formatdoc!( + if !missing_secrets.is_empty() { + return Err(anyhow!(formatdoc!( " MCP config file references secrets that are missing from the vault: {:?} Please add these secrets to the vault and try again.", missing_secrets ))); - } + } - let mcp_servers_config: McpServersConfig = - serde_json::from_str(&parsed_content).with_context(err)?; - registry.config = Some(mcp_servers_config); + let mcp_servers_config: McpServersConfig = + serde_json::from_str(&parsed_content).with_context(err)?; + registry.config = Some(mcp_servers_config); - if start_mcp_servers && config.mcp_servers { - abortable_run_with_spinner( - registry.start_select_mcp_servers(use_mcp_servers), - "Loading MCP servers", - abort_signal, - ) - .await?; - } + if start_mcp_servers && config.mcp_servers { + abortable_run_with_spinner( + registry.start_select_mcp_servers(use_mcp_servers), + "Loading MCP servers", + abort_signal, + ) + .await?; + } - Ok(registry) - } + Ok(registry) + } - pub async fn reinit( - registry: McpRegistry, - use_mcp_servers: Option, - abort_signal: AbortSignal, - ) -> Result { - debug!("Reinitializing MCP registry"); - debug!("Stopping all MCP servers"); - let mut new_registry = abortable_run_with_spinner( - registry.stop_all_servers(), - "Stopping MCP servers", - abort_signal.clone(), - ) - .await?; + pub async fn reinit( + registry: McpRegistry, + use_mcp_servers: Option, + abort_signal: AbortSignal, + ) -> Result { + debug!("Reinitializing MCP registry"); + debug!("Stopping all MCP servers"); + let mut new_registry = abortable_run_with_spinner( + registry.stop_all_servers(), + "Stopping MCP servers", + abort_signal.clone(), + ) + .await?; - abortable_run_with_spinner( - new_registry.start_select_mcp_servers(use_mcp_servers), - "Loading MCP servers", - abort_signal, - ) - .await?; + abortable_run_with_spinner( + new_registry.start_select_mcp_servers(use_mcp_servers), + "Loading MCP servers", + abort_signal, + ) + .await?; - Ok(new_registry) - } + Ok(new_registry) + } - async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option) -> Result<()> { - if self.config.is_none() { - debug!("MCP config is not present; assuming MCP servers are disabled globally. Skipping MCP initialization"); - return Ok(()); - } + async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option) -> Result<()> { + if self.config.is_none() { + debug!("MCP config is not present; assuming MCP servers are disabled globally. Skipping MCP initialization"); + return Ok(()); + } - if let Some(servers) = use_mcp_servers { - debug!("Starting selected MCP servers: {:?}", servers); - let config = self - .config - .as_ref() - .with_context(|| "MCP Config not defined. Cannot start servers")?; - let mcp_servers = config.mcp_servers.clone(); + if let Some(servers) = use_mcp_servers { + debug!("Starting selected MCP servers: {:?}", servers); + let config = self + .config + .as_ref() + .with_context(|| "MCP Config not defined. Cannot start servers")?; + let mcp_servers = config.mcp_servers.clone(); - let enabled_servers: HashSet = - servers.split(',').map(|s| s.trim().to_string()).collect(); - let server_ids: Vec = if servers == "all" { - mcp_servers.into_keys().collect() - } else { - mcp_servers - .into_keys() - .filter(|id| enabled_servers.contains(id)) - .collect() - }; + let enabled_servers: HashSet = + servers.split(',').map(|s| s.trim().to_string()).collect(); + let server_ids: Vec = if servers == "all" { + mcp_servers.into_keys().collect() + } else { + mcp_servers + .into_keys() + .filter(|id| enabled_servers.contains(id)) + .collect() + }; - let results: Vec<(String, Arc<_>)> = stream::iter( - server_ids - .into_iter() - .map(|id| async { self.start_server(id).await }), - ) - .buffer_unordered(num_cpus::get()) - .try_collect() - .await?; + let results: Vec<(String, Arc<_>)> = stream::iter( + server_ids + .into_iter() + .map(|id| async { self.start_server(id).await }), + ) + .buffer_unordered(num_cpus::get()) + .try_collect() + .await?; - self.servers = results.into_iter().collect(); - } + self.servers = results.into_iter().collect(); + } - Ok(()) - } + Ok(()) + } - async fn start_server(&self, id: String) -> Result<(String, Arc)> { - let server = self - .config - .as_ref() - .and_then(|c| c.mcp_servers.get(&id)) - .with_context(|| format!("MCP server not found in config: {id}"))?; - let mut cmd = Command::new(&server.command); - if let Some(args) = &server.args { - cmd.args(args); - } - if let Some(env) = &server.env { - let env: HashMap = env - .iter() - .map(|(k, v)| match v { - JsonField::Str(s) => (k.clone(), s.clone()), - JsonField::Bool(b) => (k.clone(), b.to_string()), - JsonField::Int(i) => (k.clone(), i.to_string()), - }) - .collect(); - cmd.envs(env); - } - if let Some(cwd) = &server.cwd { - cmd.current_dir(cwd); - } + async fn start_server(&self, id: String) -> Result<(String, Arc)> { + let server = self + .config + .as_ref() + .and_then(|c| c.mcp_servers.get(&id)) + .with_context(|| format!("MCP server not found in config: {id}"))?; + let mut cmd = Command::new(&server.command); + if let Some(args) = &server.args { + cmd.args(args); + } + if let Some(env) = &server.env { + let env: HashMap = env + .iter() + .map(|(k, v)| match v { + JsonField::Str(s) => (k.clone(), s.clone()), + JsonField::Bool(b) => (k.clone(), b.to_string()), + JsonField::Int(i) => (k.clone(), i.to_string()), + }) + .collect(); + cmd.envs(env); + } + if let Some(cwd) = &server.cwd { + cmd.current_dir(cwd); + } - let transport = if let Some(log_path) = self.log_path.as_ref() { - cmd.stdin(Stdio::piped()).stdout(Stdio::piped()); + let transport = if let Some(log_path) = self.log_path.as_ref() { + cmd.stdin(Stdio::piped()).stdout(Stdio::piped()); - let log_file = OpenOptions::new() - .create(true) - .append(true) - .open(log_path)?; - let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?; - transport - } else { - TokioChildProcess::new(cmd)? - }; + let log_file = OpenOptions::new() + .create(true) + .append(true) + .open(log_path)?; + let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?; + transport + } else { + TokioChildProcess::new(cmd)? + }; - let service = Arc::new( - ().serve(transport) - .await - .with_context(|| format!("Failed to start MCP server: {}", &server.command))?, - ); - debug!( + let service = Arc::new( + ().serve(transport) + .await + .with_context(|| format!("Failed to start MCP server: {}", &server.command))?, + ); + debug!( "Available tools for MCP server {id}: {:?}", service.list_tools(None).await? ); - info!("Started MCP server: {id}"); + info!("Started MCP server: {id}"); - Ok((id.to_string(), service)) - } + Ok((id.to_string(), service)) + } - pub async fn stop_all_servers(mut self) -> Result { - for (id, server) in self.servers { - Arc::try_unwrap(server) - .map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))? - .cancel() - .await - .with_context(|| format!("Failed to stop MCP server: {id}"))?; - info!("Stopped MCP server: {id}"); - } + pub async fn stop_all_servers(mut self) -> Result { + for (id, server) in self.servers { + Arc::try_unwrap(server) + .map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))? + .cancel() + .await + .with_context(|| format!("Failed to stop MCP server: {id}"))?; + info!("Stopped MCP server: {id}"); + } - self.servers = HashMap::new(); + self.servers = HashMap::new(); - Ok(self) - } + Ok(self) + } - pub fn list_started_servers(&self) -> Vec { - self.servers.keys().cloned().collect() - } + pub fn list_started_servers(&self) -> Vec { + self.servers.keys().cloned().collect() + } - pub fn list_configured_servers(&self) -> Vec { - if let Some(config) = &self.config { - config.mcp_servers.keys().cloned().collect() - } else { - vec![] - } - } + pub fn list_configured_servers(&self) -> Vec { + if let Some(config) = &self.config { + config.mcp_servers.keys().cloned().collect() + } else { + vec![] + } + } - pub fn catalog(&self) -> BoxFuture<'static, Result> { - let servers: Vec<(String, Arc)> = self - .servers - .iter() - .map(|(id, s)| (id.clone(), s.clone())) - .collect(); + pub fn catalog(&self) -> BoxFuture<'static, Result> { + let servers: Vec<(String, Arc)> = self + .servers + .iter() + .map(|(id, s)| (id.clone(), s.clone())) + .collect(); - Box::pin(async move { - let mut out = Vec::with_capacity(servers.len()); - for (id, server) in servers { - let tools = server.list_tools(None).await?; - let resources = server.list_resources(None).await.unwrap_or_default(); - // TODO implement prompt sampling for MCP servers - // let prompts = server.service.list_prompts(None).await.unwrap_or_default(); - out.push(json!({ + Box::pin(async move { + let mut out = Vec::with_capacity(servers.len()); + for (id, server) in servers { + let tools = server.list_tools(None).await?; + let resources = server.list_resources(None).await.unwrap_or_default(); + // TODO implement prompt sampling for MCP servers + // let prompts = server.service.list_prompts(None).await.unwrap_or_default(); + out.push(json!({ "server": id, "tools": tools, "resources": resources, })); - } - Ok(Value::Array(out)) - }) - } + } + Ok(Value::Array(out)) + }) + } - pub fn invoke( - &self, - server: &str, - tool: &str, - arguments: Value, - ) -> BoxFuture<'static, Result> { - let server = self - .servers - .get(server) - .cloned() - .with_context(|| format!("Invoked MCP server does not exist: {server}")); + pub fn invoke( + &self, + server: &str, + tool: &str, + arguments: Value, + ) -> BoxFuture<'static, Result> { + let server = self + .servers + .get(server) + .cloned() + .with_context(|| format!("Invoked MCP server does not exist: {server}")); - let tool = tool.to_owned(); - Box::pin(async move { - let server = server?; - let call_tool_request = CallToolRequestParam { - name: Cow::Owned(tool.to_owned()), - arguments: arguments.as_object().cloned(), - }; + let tool = tool.to_owned(); + Box::pin(async move { + let server = server?; + let call_tool_request = CallToolRequestParam { + name: Cow::Owned(tool.to_owned()), + arguments: arguments.as_object().cloned(), + }; - let result = server.call_tool(call_tool_request).await?; - Ok(result) - }) - } + let result = server.call_tool(call_tool_request).await?; + Ok(result) + }) + } - pub fn is_empty(&self) -> bool { - self.servers.is_empty() - } + pub fn is_empty(&self) -> bool { + self.servers.is_empty() + } } diff --git a/src/vault/mod.rs b/src/vault/mod.rs index bc7fb89..003326c 100644 --- a/src/vault/mod.rs +++ b/src/vault/mod.rs @@ -17,118 +17,118 @@ static SECRET_RE: LazyLock = LazyLock::new(|| Regex::new(r"\{\{(.+)}}").u #[derive(Debug, Default, Clone)] pub struct Vault { - local_provider: LocalProvider, + local_provider: LocalProvider, } impl Vault { - pub fn init(config: &Config) -> Self { - let vault_password_file = config.vault_password_file(); - let mut local_provider = LocalProvider { - password_file: Some(vault_password_file), - git_branch: None, - ..LocalProvider::default() - }; + pub fn init(config: &Config) -> Self { + let vault_password_file = config.vault_password_file(); + let mut local_provider = LocalProvider { + password_file: Some(vault_password_file), + git_branch: None, + ..LocalProvider::default() + }; - ensure_password_file_initialized(&mut local_provider) - .expect("Failed to initialize password file"); + ensure_password_file_initialized(&mut local_provider) + .expect("Failed to initialize password file"); - Self { local_provider } - } + Self { local_provider } + } - pub fn add_secret(&self, secret_name: &str) -> Result<()> { - let secret_value = Password::new("Enter the secret value:") - .with_validator(required!()) - .with_display_mode(PasswordDisplayMode::Masked) - .prompt() - .with_context(|| "unable to read secret from input")?; + pub fn add_secret(&self, secret_name: &str) -> Result<()> { + let secret_value = Password::new("Enter the secret value:") + .with_validator(required!()) + .with_display_mode(PasswordDisplayMode::Masked) + .prompt() + .with_context(|| "unable to read secret from input")?; - let h = Handle::current(); - tokio::task::block_in_place(|| { - h.block_on(self.local_provider.set_secret(secret_name, &secret_value)) - })?; - println!("✓ Secret '{secret_name}' added to the vault."); + let h = Handle::current(); + tokio::task::block_in_place(|| { + h.block_on(self.local_provider.set_secret(secret_name, &secret_value)) + })?; + println!("✓ Secret '{secret_name}' added to the vault."); - Ok(()) - } + Ok(()) + } - pub fn get_secret(&self, secret_name: &str, display_output: bool) -> Result { - let h = Handle::current(); - let secret = tokio::task::block_in_place(|| { - h.block_on(self.local_provider.get_secret(secret_name)) - })?; + pub fn get_secret(&self, secret_name: &str, display_output: bool) -> Result { + let h = Handle::current(); + let secret = tokio::task::block_in_place(|| { + h.block_on(self.local_provider.get_secret(secret_name)) + })?; - if display_output { - println!("{}", secret); - } + if display_output { + println!("{}", secret); + } - Ok(secret) - } + Ok(secret) + } - pub fn update_secret(&self, secret_name: &str) -> Result<()> { - let secret_value = Password::new("Enter the secret value:") - .with_validator(required!()) - .with_display_mode(PasswordDisplayMode::Masked) - .prompt() - .with_context(|| "unable to read secret from input")?; - let h = Handle::current(); - tokio::task::block_in_place(|| { - h.block_on( - self.local_provider - .update_secret(secret_name, &secret_value), - ) - })?; - println!("✓ Secret '{secret_name}' updated in the vault."); + pub fn update_secret(&self, secret_name: &str) -> Result<()> { + let secret_value = Password::new("Enter the secret value:") + .with_validator(required!()) + .with_display_mode(PasswordDisplayMode::Masked) + .prompt() + .with_context(|| "unable to read secret from input")?; + let h = Handle::current(); + tokio::task::block_in_place(|| { + h.block_on( + self.local_provider + .update_secret(secret_name, &secret_value), + ) + })?; + println!("✓ Secret '{secret_name}' updated in the vault."); - Ok(()) - } + Ok(()) + } - pub fn delete_secret(&self, secret_name: &str) -> Result<()> { - let h = Handle::current(); - tokio::task::block_in_place(|| h.block_on(self.local_provider.delete_secret(secret_name)))?; - println!("✓ Secret '{secret_name}' deleted from the vault."); + pub fn delete_secret(&self, secret_name: &str) -> Result<()> { + let h = Handle::current(); + tokio::task::block_in_place(|| h.block_on(self.local_provider.delete_secret(secret_name)))?; + println!("✓ Secret '{secret_name}' deleted from the vault."); - Ok(()) - } + Ok(()) + } - pub fn list_secrets(&self, display_output: bool) -> Result> { - let h = Handle::current(); - let secrets = - tokio::task::block_in_place(|| h.block_on(self.local_provider.list_secrets()))?; + pub fn list_secrets(&self, display_output: bool) -> Result> { + let h = Handle::current(); + let secrets = + tokio::task::block_in_place(|| h.block_on(self.local_provider.list_secrets()))?; - if display_output { - if secrets.is_empty() { - println!("The vault is empty."); - } else { - for key in &secrets { - println!("{}", key); - } - } - } + if display_output { + if secrets.is_empty() { + println!("The vault is empty."); + } else { + for key in &secrets { + println!("{}", key); + } + } + } - Ok(secrets) - } + Ok(secrets) + } - pub fn handle_vault_flags(cli: Cli, config: Config) -> Result<()> { - if let Some(secret_name) = cli.add_secret { - config.vault.add_secret(&secret_name)?; - } + pub fn handle_vault_flags(cli: Cli, config: Config) -> Result<()> { + if let Some(secret_name) = cli.add_secret { + config.vault.add_secret(&secret_name)?; + } - if let Some(secret_name) = cli.get_secret { - config.vault.get_secret(&secret_name, true)?; - } + if let Some(secret_name) = cli.get_secret { + config.vault.get_secret(&secret_name, true)?; + } - if let Some(secret_name) = cli.update_secret { - config.vault.update_secret(&secret_name)?; - } + if let Some(secret_name) = cli.update_secret { + config.vault.update_secret(&secret_name)?; + } - if let Some(secret_name) = cli.delete_secret { - config.vault.delete_secret(&secret_name)?; - } + if let Some(secret_name) = cli.delete_secret { + config.vault.delete_secret(&secret_name)?; + } - if cli.list_secrets { - config.vault.list_secrets(true)?; - } + if cli.list_secrets { + config.vault.list_secrets(true)?; + } - Ok(()) - } + Ok(()) + } } diff --git a/src/vault/utils.rs b/src/vault/utils.rs index 0b99f2b..d15bd1b 100644 --- a/src/vault/utils.rs +++ b/src/vault/utils.rs @@ -10,90 +10,90 @@ use std::borrow::Cow; use std::path::PathBuf; pub fn ensure_password_file_initialized(local_provider: &mut LocalProvider) -> Result<()> { - let vault_password_file = local_provider - .password_file - .clone() - .ok_or_else(|| anyhow!("Password file is not configured"))?; + let vault_password_file = local_provider + .password_file + .clone() + .ok_or_else(|| anyhow!("Password file is not configured"))?; - if vault_password_file.exists() { - { - let file_contents = std::fs::read_to_string(&vault_password_file)?; - if !file_contents.trim().is_empty() { - return Ok(()); - } - } + if vault_password_file.exists() { + { + let file_contents = std::fs::read_to_string(&vault_password_file)?; + if !file_contents.trim().is_empty() { + return Ok(()); + } + } - let ans = Confirm::new( - format!( - "The configured password file '{}' is empty. Create a password?", - vault_password_file.display() - ) - .as_str(), - ) - .with_default(true) - .prompt()?; + let ans = Confirm::new( + format!( + "The configured password file '{}' is empty. Create a password?", + vault_password_file.display() + ) + .as_str(), + ) + .with_default(true) + .prompt()?; - if !ans { - return Err(anyhow!("The configured password file '{}' is empty. Please populate it with a password and try again.", vault_password_file.display())); - } + if !ans { + return Err(anyhow!("The configured password file '{}' is empty. Please populate it with a password and try again.", vault_password_file.display())); + } - let password = Password::new("Enter a password to encrypt all vault secrets:") - .with_validator(required!()) - .with_validator(min_length!(10)) - .with_display_mode(PasswordDisplayMode::Masked) - .prompt(); + let password = Password::new("Enter a password to encrypt all vault secrets:") + .with_validator(required!()) + .with_validator(min_length!(10)) + .with_display_mode(PasswordDisplayMode::Masked) + .prompt(); - match password { - Ok(pw) => { - std::fs::write(&vault_password_file, pw.as_bytes())?; - println!( - "✓ Password file '{}' updated.", - vault_password_file.display() - ); - } - Err(_) => { - return Err(anyhow!( + match password { + Ok(pw) => { + std::fs::write(&vault_password_file, pw.as_bytes())?; + println!( + "✓ Password file '{}' updated.", + vault_password_file.display() + ); + } + Err(_) => { + return Err(anyhow!( "Failed to read password from input. Password file not updated." )); - } - } - } else { - let ans = Confirm::new("No password file configured. Do you want to create one now?") - .with_default(true) - .prompt()?; + } + } + } else { + let ans = Confirm::new("No password file configured. Do you want to create one now?") + .with_default(true) + .prompt()?; - if !ans { - return Err(anyhow!("A password file is required to utilize the Loki vault. Please configure a password file in your config file and try again.")); - } + if !ans { + return Err(anyhow!("A password file is required to utilize the Loki vault. Please configure a password file in your config file and try again.")); + } - let password_file: PathBuf = Text::new("Enter the path to the password file to create:") - .with_default(&vault_password_file.display().to_string()) - .with_validator(required!("Password file path is required")) - .with_validator(|input: &str| { - let path = PathBuf::from(input); - if path.exists() { - Ok(Validation::Invalid( - "File already exists. Please choose a different path.".into(), - )) - } else if let Some(parent) = path.parent() { - if !parent.exists() { - Ok(Validation::Invalid( - "Parent directory does not exist.".into(), - )) - } else { - Ok(Validation::Valid) - } - } else { - Ok(Validation::Valid) - } - }) - .prompt()? - .into(); + let password_file: PathBuf = Text::new("Enter the path to the password file to create:") + .with_default(&vault_password_file.display().to_string()) + .with_validator(required!("Password file path is required")) + .with_validator(|input: &str| { + let path = PathBuf::from(input); + if path.exists() { + Ok(Validation::Invalid( + "File already exists. Please choose a different path.".into(), + )) + } else if let Some(parent) = path.parent() { + if !parent.exists() { + Ok(Validation::Invalid( + "Parent directory does not exist.".into(), + )) + } else { + Ok(Validation::Valid) + } + } else { + Ok(Validation::Valid) + } + }) + .prompt()? + .into(); - if password_file != vault_password_file { - println!( - "{}", - formatdoc!( + if password_file != vault_password_file { + println!( + "{}", + formatdoc!( " Note: The default password file path is '{}'. You have chosen to create a different path: '{}'. @@ -102,49 +102,49 @@ pub fn ensure_password_file_initialized(local_provider: &mut LocalProvider) -> R vault_password_file.display(), password_file.display() ) - ); - } + ); + } - ensure_parent_exists(&password_file)?; + ensure_parent_exists(&password_file)?; - let password = Password::new("Enter a password to encrypt all vault secrets:") - .with_display_mode(PasswordDisplayMode::Masked) - .with_validator(required!()) - .with_validator(min_length!(10)) - .prompt(); + let password = Password::new("Enter a password to encrypt all vault secrets:") + .with_display_mode(PasswordDisplayMode::Masked) + .with_validator(required!()) + .with_validator(min_length!(10)) + .prompt(); - match password { - Ok(pw) => { - std::fs::write(&password_file, pw.as_bytes())?; - local_provider.password_file = Some(password_file); - println!( - "✓ Password file '{}' created.", - vault_password_file.display() - ); - } - Err(_) => { - return Err(anyhow!( + match password { + Ok(pw) => { + std::fs::write(&password_file, pw.as_bytes())?; + local_provider.password_file = Some(password_file); + println!( + "✓ Password file '{}' created.", + vault_password_file.display() + ); + } + Err(_) => { + return Err(anyhow!( "Failed to read password from input. Password file not created." )); - } - } - } + } + } + } - Ok(()) + Ok(()) } pub fn interpolate_secrets<'a>(content: &'a str, vault: &Vault) -> (Cow<'a, str>, Vec) { - let mut missing_secrets = vec![]; - let parsed_content = SECRET_RE.replace_all(content, |caps: &fancy_regex::Captures<'_>| { - let secret = vault.get_secret(caps[1].trim(), false); - match secret { - Ok(s) => s, - Err(_) => { - missing_secrets.push(caps[1].to_string()); - "".to_string() - } - } - }); + let mut missing_secrets = vec![]; + let parsed_content = SECRET_RE.replace_all(content, |caps: &fancy_regex::Captures<'_>| { + let secret = vault.get_secret(caps[1].trim(), false); + match secret { + Ok(s) => s, + Err(_) => { + missing_secrets.push(caps[1].to_string()); + "".to_string() + } + } + }); - (parsed_content, missing_secrets) + (parsed_content, missing_secrets) }