diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index a46b818..729fcc0 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -24,8 +24,9 @@ impl AzureOpenAIClient { "api_base", "API Base", Some("e.g. https://{RESOURCE}.openai.azure.com"), + false ), - ("api_key", "API Key", None), + ("api_key", "API Key", None, true), ]; } diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index 88d57d9..cedd3a2 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -33,9 +33,9 @@ impl BedrockClient { config_get_fn!(session_token, get_session_token); pub const PROMPTS: [PromptAction<'static>; 3] = [ - ("access_key_id", "AWS Access Key ID", None), - ("secret_access_key", "AWS Secret Access Key", None), - ("region", "AWS Region", None), + ("access_key_id", "AWS Access Key ID", None, true), + ("secret_access_key", "AWS Secret Access Key", None, true), + ("region", "AWS Region", None, false), ]; fn chat_completions_builder( diff --git a/src/client/claude.rs b/src/client/claude.rs index 28f8225..12e2559 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -24,7 +24,7 @@ impl ClaudeClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; } impl_client_trait!( diff --git a/src/client/cohere.rs b/src/client/cohere.rs index 9dc6da9..502b29b 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -24,7 +24,7 @@ impl CohereClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; } impl_client_trait!( diff --git a/src/client/common.rs b/src/client/common.rs index c7ac96f..e903815 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -7,6 +7,7 @@ use crate::{ utils::*, }; +use crate::vault::Vault; use anyhow::{bail, Context, Result}; use fancy_regex::Regex; use indexmap::IndexMap; @@ -343,19 +344,25 @@ pub struct RerankResult { pub relevance_score: f64, } -pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>); +pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool); pub async fn create_config( prompts: &[PromptAction<'static>], client: &str, + vault: &Vault, ) -> Result<(String, Value)> { let mut config = json!({ "type": client, }); - for (key, desc, help_message) in prompts { + for (key, desc, help_message, is_secret) in prompts { let env_name = format!("{client}_{key}").to_ascii_uppercase(); let required = std::env::var(&env_name).is_err(); - let value = prompt_input_string(desc, required, *help_message)?; + let value = if !is_secret { + prompt_input_string(desc, required, *help_message)? + } else { + vault.add_secret(&env_name)?; + format!("{{{{{}}}}}", env_name) + }; if !value.is_empty() { config[key] = value.into(); } diff --git a/src/client/gemini.rs b/src/client/gemini.rs index 3c8778d..52b8e45 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -23,7 +23,7 @@ impl GeminiClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; } impl_client_trait!( diff --git a/src/client/macros.rs b/src/client/macros.rs index a3730bd..0b76974 100644 --- a/src/client/macros.rs +++ b/src/client/macros.rs @@ -87,10 +87,10 @@ macro_rules! register_client { client_types } - pub async fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> { + pub async fn create_client_config(client: &str, vault: &$crate::vault::Vault) -> anyhow::Result<(String, serde_json::Value)> { $( if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME { - return create_config(&$client::PROMPTS, $client::NAME).await + return create_config(&$client::PROMPTS, $client::NAME, vault).await } )+ if let Some(ret) = create_openai_compatible_client_config(client).await? { diff --git a/src/client/openai.rs b/src/client/openai.rs index 59496ef..b2432d4 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -25,7 +25,7 @@ impl OpenAIClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)]; + pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; } impl_client_trait!( diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 16b55d1..43e15bc 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -27,8 +27,8 @@ impl VertexAIClient { config_get_fn!(location, get_location); pub const PROMPTS: [PromptAction<'static>; 2] = [ - ("project_id", "Project ID", None), - ("location", "Location", None), + ("project_id", "Project ID", None, false), + ("location", "Location", None, false), ]; } diff --git a/src/config/mod.rs b/src/config/mod.rs index 7939242..2e9fc01 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -24,7 +24,7 @@ use crate::utils::*; use crate::mcp::{ McpRegistry, MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX, }; -use crate::vault::{interpolate_secrets, Vault}; +use crate::vault::{create_vault_password_file, interpolate_secrets, Vault}; use anyhow::{anyhow, bail, Context, Result}; use fancy_regex::Regex; use indexmap::IndexMap; @@ -3135,11 +3135,15 @@ async fn create_config_file(config_path: &Path) -> Result<()> { process::exit(0); } + let mut vault = Vault::init_bare(); + create_vault_password_file(&mut vault)?; + let client = Select::new("API Provider (required):", list_client_types()).prompt()?; let mut config = json!({}); - let (model, clients_config) = create_client_config(client).await?; + let (model, clients_config) = create_client_config(client, &vault).await?; config["model"] = model.into(); + config["vault_password_file"] = vault.password_file()?.display().to_string().into(); config[CLIENTS_FIELD] = clients_config; let config_data = serde_yaml::to_string(&config).with_context(|| "Failed to create config")?; diff --git a/src/vault/mod.rs b/src/vault/mod.rs index 003326c..df4e9ad 100644 --- a/src/vault/mod.rs +++ b/src/vault/mod.rs @@ -1,5 +1,7 @@ mod utils; +use std::path::PathBuf; +pub use utils::create_vault_password_file; pub use utils::interpolate_secrets; use crate::cli::Cli; @@ -21,6 +23,17 @@ pub struct Vault { } impl Vault { + pub fn init_bare() -> Self { + let vault_password_file = Config::default().vault_password_file(); + let local_provider = LocalProvider { + password_file: Some(vault_password_file), + git_branch: None, + ..LocalProvider::default() + }; + + Self { local_provider } + } + pub fn init(config: &Config) -> Self { let vault_password_file = config.vault_password_file(); let mut local_provider = LocalProvider { @@ -35,6 +48,13 @@ impl Vault { Self { local_provider } } + pub fn password_file(&self) -> Result { + self.local_provider + .password_file + .clone() + .with_context(|| "A password file is required for the local provider") + } + pub fn add_secret(&self, secret_name: &str) -> Result<()> { let secret_value = Password::new("Enter the secret value:") .with_validator(required!()) diff --git a/src/vault/utils.rs b/src/vault/utils.rs index d15bd1b..d6da91c 100644 --- a/src/vault/utils.rs +++ b/src/vault/utils.rs @@ -19,6 +19,28 @@ pub fn ensure_password_file_initialized(local_provider: &mut LocalProvider) -> R { let file_contents = std::fs::read_to_string(&vault_password_file)?; if !file_contents.trim().is_empty() { + Ok(()) + } else { + Err(anyhow!("The configured password file '{}' is empty. Please populate it with a password and try again.", vault_password_file.display())) + } + } + } else { + Err(anyhow!("A password file is required to utilize the Loki vault. Please configure a password file in your config file and try again.")) + } +} + +pub fn create_vault_password_file(vault: &mut Vault) -> Result<()> { + let vault_password_file = vault + .local_provider + .password_file + .clone() + .ok_or_else(|| anyhow!("Password file is not configured"))?; + + if vault_password_file.exists() { + { + let file_contents = std::fs::read_to_string(&vault_password_file)?; + if !file_contents.trim().is_empty() { + debug!("create_vault_password_file was called but the password file already exists and is non-empty"); return Ok(()); } } @@ -91,13 +113,12 @@ pub fn ensure_password_file_initialized(local_provider: &mut LocalProvider) -> R .into(); if password_file != vault_password_file { - println!( + debug!( "{}", formatdoc!( " - Note: The default password file path is '{}'. - You have chosen to create a different path: '{}'. - Please ensure your configuration is updated accordingly. + The default password file path is '{}'. + User chose to create file at a different path: '{}'. ", vault_password_file.display(), password_file.display() @@ -116,7 +137,7 @@ pub fn ensure_password_file_initialized(local_provider: &mut LocalProvider) -> R match password { Ok(pw) => { std::fs::write(&password_file, pw.as_bytes())?; - local_provider.password_file = Some(password_file); + vault.local_provider.password_file = Some(password_file); println!( "✓ Password file '{}' created.", vault_password_file.display()