diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index 729fcc0..e64277f 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -18,16 +18,16 @@ pub struct AzureOpenAIConfig { impl AzureOpenAIClient { config_get_fn!(api_base, get_api_base); config_get_fn!(api_key, get_api_key); - - pub const PROMPTS: [PromptAction<'static>; 2] = [ + + create_client_config!([ ( "api_base", "API Base", Some("e.g. https://{RESOURCE}.openai.azure.com"), - false + false, ), ("api_key", "API Key", None, true), - ]; + ]); } impl_client_trait!( diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index 107b5f5..b521d25 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -32,11 +32,11 @@ impl BedrockClient { config_get_fn!(region, get_region); config_get_fn!(session_token, get_session_token); - pub const PROMPTS: [PromptAction<'static>; 3] = [ + create_client_config!([ ("access_key_id", "AWS Access Key ID", None, true), ("secret_access_key", "AWS Secret Access Key", None, true), ("region", "AWS Region", None, false), - ]; + ]); fn chat_completions_builder( &self, diff --git a/src/client/claude.rs b/src/client/claude.rs index aa526a3..283be26 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -27,8 +27,8 @@ pub struct ClaudeConfig { impl ClaudeClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; + + create_oauth_supported_client_config!(); } #[async_trait::async_trait] diff --git a/src/client/cohere.rs b/src/client/cohere.rs index d132714..f34e719 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -24,7 +24,7 @@ impl CohereClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; + create_client_config!([("api_key", "API Key", None, true)]); } impl_client_trait!( diff 
--git a/src/client/common.rs b/src/client/common.rs index f5c8c2d..c332dd9 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -546,7 +546,7 @@ pub fn json_str_from_map<'a>( map.get(field_name).and_then(|v| v.as_str()) } -async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result { +pub async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result { if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) { let models: Vec = provider .models diff --git a/src/client/gemini.rs b/src/client/gemini.rs index 52b8e45..c0bafaa 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -4,7 +4,7 @@ use super::*; use anyhow::{Context, Result}; use reqwest::RequestBuilder; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta"; @@ -23,7 +23,7 @@ impl GeminiClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; + create_client_config!([("api_key", "API Key", None, true)]); } impl_client_trait!( diff --git a/src/client/macros.rs b/src/client/macros.rs index a1f5695..0c8ac9d 100644 --- a/src/client/macros.rs +++ b/src/client/macros.rs @@ -90,7 +90,7 @@ macro_rules! register_client { pub async fn create_client_config(client: &str, vault: &$crate::vault::Vault) -> anyhow::Result<(String, serde_json::Value)> { $( if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME { - return create_config(&$client::PROMPTS, $client::NAME, vault).await + return $client::create_client_config(vault).await } )+ if let Some(ret) = create_openai_compatible_client_config(client).await? { @@ -218,6 +218,44 @@ macro_rules! impl_client_trait { }; } +#[macro_export] +macro_rules! 
create_client_config { + ($prompts:expr) => { + pub async fn create_client_config( + vault: &$crate::vault::Vault, + ) -> anyhow::Result<(String, serde_json::Value)> { + $crate::client::create_config(&$prompts, Self::NAME, vault).await + } + }; +} + +#[macro_export] +macro_rules! create_oauth_supported_client_config { + () => { + pub async fn create_client_config(vault: &$crate::vault::Vault) -> anyhow::Result<(String, serde_json::Value)> { + let mut config = serde_json::json!({ "type": Self::NAME }); + + let auth_method = inquire::Select::new( + "Authentication method:", + vec!["API Key", "OAuth"], + ) + .prompt()?; + + if auth_method == "API Key" { + let env_name = format!("{}_API_KEY", Self::NAME).to_ascii_uppercase(); + vault.add_secret(&env_name)?; + config["api_key"] = format!("{{{{{env_name}}}}}").into(); + } else { + config["auth"] = "oauth".into(); + } + + let model = $crate::client::set_client_models_config(&mut config, Self::NAME).await?; + let clients = serde_json::json!(vec![config]); + Ok((model, clients)) + } + } +} + +#[macro_export] +macro_rules! 
config_get_fn { ($field_name:ident, $fn_name:ident) => { diff --git a/src/client/openai.rs b/src/client/openai.rs index 506f593..f5a4e2c 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -2,10 +2,10 @@ use super::*; use crate::utils::strip_think_tag; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result, bail}; use reqwest::RequestBuilder; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; const API_BASE: &str = "https://api.openai.com/v1"; @@ -25,7 +25,7 @@ impl OpenAIClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; + create_client_config!([("api_key", "API Key", None, true)]); } impl_client_trait!( @@ -114,7 +114,9 @@ pub async fn openai_chat_completions_streaming( function_arguments = String::from("{}"); } let arguments: Value = function_arguments.parse().with_context(|| { - format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'") + format!( + "Tool call '{function_name}' has non-JSON arguments '{function_arguments}'" + ) })?; handler.tool_call(ToolCall::new( function_name.clone(), diff --git a/src/client/openai_compatible.rs b/src/client/openai_compatible.rs index 3ee28ba..ecaf2b5 100644 --- a/src/client/openai_compatible.rs +++ b/src/client/openai_compatible.rs @@ -21,7 +21,7 @@ impl OpenAICompatibleClient { config_get_fn!(api_base, get_api_base); config_get_fn!(api_key, get_api_key); - pub const PROMPTS: [PromptAction<'static>; 0] = []; + create_client_config!([]); } impl_client_trait!( diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index dd29072..d4be930 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -3,11 +3,11 @@ use super::claude::*; use super::openai::*; use super::*; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{Context, Result, anyhow, bail}; use chrono::{Duration, Utc}; use 
reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::{path::PathBuf, str::FromStr}; #[derive(Debug, Clone, Deserialize, Default)] @@ -26,10 +26,10 @@ impl VertexAIClient { config_get_fn!(project_id, get_project_id); config_get_fn!(location, get_location); - pub const PROMPTS: [PromptAction<'static>; 2] = [ + create_client_config!([ ("project_id", "Project ID", None, false), ("location", "Location", None, false), - ]; + ]); } #[async_trait::async_trait] @@ -99,9 +99,13 @@ fn prepare_chat_completions( let access_token = get_access_token(self_.name())?; let base_url = if location == "global" { - format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers") + format!( + "https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers" + ) } else { - format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers") + format!( + "https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers" + ) }; let model_name = self_.model.real_name(); @@ -158,9 +162,13 @@ fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result Result