/// Generates the client registry from a list of
/// `(module, type name, config type, client type)` tuples.
///
/// For every entry this macro:
/// - declares `mod $module;` and imports `$module::$config`,
/// - adds a `$config` variant to the `ClientConfig` enum, selected when the
///   config's `type` field equals `$name` (any other `type` deserializes to
///   `ClientConfig::Unknown`),
/// - defines `pub struct $client` holding the global config, its own config and
///   the selected model, with `NAME`, `init`, `list_models` and `name`.
///
/// It also generates the registry-wide helpers `init_client`,
/// `list_client_types`, `create_client_config`, `list_client_names`,
/// `list_all_models` and `list_models`.
///
/// The generated code refers to `$crate::client::...` items (including
/// `$crate::client::OpenAICompatibleClient`), so it presumably must be invoked
/// inside the `crate::client` module with an `OpenAICompatibleClient` entry —
/// NOTE(review): confirm against the invocation site.
#[macro_export]
macro_rules! register_client {
    (
        $(($module:ident, $name:literal, $config:ident, $client:ident),)+
    ) => {
        $(
            mod $module;
        )+
        $(
            use self::$module::$config;
        )+

        #[derive(Debug, Clone, serde::Deserialize)]
        #[serde(tag = "type")]
        pub enum ClientConfig {
            $(
                #[serde(rename = $name)]
                $config($config),
            )+
            #[serde(other)]
            Unknown,
        }

        $(
            #[derive(Debug)]
            pub struct $client {
                global_config: $crate::config::GlobalConfig,
                config: $config,
                model: $crate::client::Model,
            }

            impl $client {
                pub const NAME: &'static str = $name;

                /// Builds a boxed client for `model` from the first `$config`
                /// entry in the global config's `clients` whose name equals
                /// `model.client_name()`.
                ///
                /// Returns `None` when no such entry exists.
                pub fn init(
                    global_config: &$crate::config::GlobalConfig,
                    model: &$crate::client::Model,
                ) -> Option<Box<dyn $crate::client::Client>> {
                    let config = global_config.read().clients.iter().find_map(|client_config| {
                        if let ClientConfig::$config(c) = client_config {
                            if Self::name(c) == model.client_name() {
                                return Some(c.clone());
                            }
                        }
                        None
                    })?;

                    Some(Box::new(Self {
                        global_config: global_config.clone(),
                        config,
                        model: model.clone(),
                    }))
                }

                /// Lists the models available to this client configuration.
                ///
                /// Models declared in `local_config.models` take precedence.
                /// When none are declared, the built-in
                /// `ALL_PROVIDER_MODELS` table is searched for an entry whose
                /// `provider` equals `$name`. For the OpenAI-compatible client
                /// an entry also matches when the configured client name starts
                /// with the provider name. Returns an empty list when nothing
                /// matches.
                pub fn list_models(local_config: &$config) -> Vec<$crate::client::Model> {
                    let client_name = Self::name(local_config);
                    if local_config.models.is_empty() {
                        if let Some(v) = $crate::client::ALL_PROVIDER_MODELS.iter().find(|v| {
                            v.provider == $name
                                || ($name == $crate::client::OpenAICompatibleClient::NAME
                                    && local_config
                                        .name
                                        .as_ref()
                                        .map(|name| name.starts_with(&v.provider))
                                        .unwrap_or_default())
                        }) {
                            return $crate::client::Model::from_config(client_name, &v.models);
                        }
                        vec![]
                    } else {
                        $crate::client::Model::from_config(client_name, &local_config.models)
                    }
                }

                /// Returns the name configured for this client, or `NAME` when
                /// the config does not set one.
                pub fn name(local_config: &$config) -> &str {
                    local_config.name.as_deref().unwrap_or(Self::NAME)
                }
            }
        )+

        /// Creates a client for `model`, or for the global config's current
        /// `model` when `model` is `None`.
        ///
        /// Registered client types are tried in declaration order and the first
        /// one whose `init` succeeds wins.
        ///
        /// # Errors
        /// Returns `Invalid model '<id>'` when no registered client accepts the
        /// model.
        pub fn init_client(
            config: &$crate::config::GlobalConfig,
            model: Option<$crate::client::Model>,
        ) -> anyhow::Result<Box<dyn $crate::client::Client>> {
            let model = model.unwrap_or_else(|| config.read().model.clone());
            None
                $(.or_else(|| $client::init(config, &model)))+
                .ok_or_else(|| anyhow::anyhow!("Invalid model '{}'", model.id()))
        }

        /// Lists every client type name.
        ///
        /// The list holds the registered `NAME`s followed by the names in
        /// `OPENAI_COMPATIBLE_PROVIDERS`.
        pub fn list_client_types() -> Vec<&'static str> {
            let mut client_types: Vec<_> = vec![$($client::NAME,)+];
            client_types.extend(
                $crate::client::OPENAI_COMPATIBLE_PROVIDERS
                    .iter()
                    .map(|(name, _)| *name),
            );
            client_types
        }

        /// Builds a new client config for the client type named `client`.
        ///
        /// Registered types (other than the OpenAI-compatible one) delegate to
        /// their own `create_client_config`. Otherwise the call falls back to
        /// `create_openai_compatible_client_config`.
        ///
        /// # Errors
        /// Returns `Unknown client '<client>'` when neither path handles the
        /// name, and propagates errors from the delegated builders.
        pub async fn create_client_config(
            client: &str,
            vault: &$crate::vault::Vault,
        ) -> anyhow::Result<(String, serde_json::Value)> {
            $(
                if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
                    return $client::create_client_config(vault).await;
                }
            )+
            if let Some(ret) = create_openai_compatible_client_config(client).await? {
                return Ok(ret);
            }
            anyhow::bail!("Unknown client '{}'", client)
        }

        static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();

        /// Lists the names of all configured clients (`Unknown` entries are
        /// skipped).
        ///
        /// NOTE: the list is computed from `config` on the first call and cached
        /// in a process-wide `OnceLock`. Later calls ignore their `config`
        /// argument.
        pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
            let names = ALL_CLIENT_NAMES.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .flat_map(|v| match v {
                        $(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+
                        ClientConfig::Unknown => vec![],
                    })
                    .collect()
            });
            names.iter().collect()
        }

        static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> =
            std::sync::OnceLock::new();

        /// Lists the models of all configured clients (`Unknown` entries are
        /// skipped).
        ///
        /// NOTE: the list is computed from `config` on the first call and cached
        /// in a process-wide `OnceLock`. Later calls ignore their `config`
        /// argument.
        pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
            let models = ALL_MODELS.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .flat_map(|v| match v {
                        $(ClientConfig::$config(c) => $client::list_models(c),)+
                        ClientConfig::Unknown => vec![],
                    })
                    .collect()
            });
            models.iter().collect()
        }

        /// Lists the models from `list_all_models` whose `model_type()` equals
        /// `model_type`.
        pub fn list_models(
            config: &$crate::config::Config,
            model_type: $crate::client::ModelType,
        ) -> Vec<&'static $crate::client::Model> {
            list_all_models(config)
                .into_iter()
                .filter(|v| v.model_type() == model_type)
                .collect()
        }
    };
}

/// Expands to the `Client` trait methods shared by every generated client.
///
/// The methods are accessors for:
/// - the global config,
/// - the optional `extra` and `patch` sections of the client's own config,
/// - the client's name,
/// - the client's model.
///
/// `Self` must have `global_config`, `config` and `model` fields, as the
/// structs generated by `register_client!` do.
#[macro_export]
macro_rules! client_common_fns {
    () => {
        fn global_config(&self) -> &$crate::config::GlobalConfig {
            &self.global_config
        }

        fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
            self.config.extra.as_ref()
        }

        fn patch_config(&self) -> Option<&$crate::client::RequestPatch> {
            self.config.patch.as_ref()
        }

        fn name(&self) -> &str {
            // Resolves to the inherent `$client::name(&config)`, which falls
            // back to `NAME` when no name is configured.
            Self::name(&self.config)
        }

        fn model(&self) -> &$crate::client::Model {
            &self.model
        }
    };
}

/// Implements `$crate::client::Client` for `$crate::client::$client`.
///
/// Each `*_inner` method follows the same steps:
/// 1. Run the matching `$prepare_*` function on the input data.
/// 2. Pass the resulting request data to `self.request_builder`.
/// 3. Hand the builder to the matching send/parse function, together with the
///    client's model.
///
/// The shared accessors come from `client_common_fns!`.
#[macro_export]
macro_rules! impl_client_trait {
    (
        $client:ident,
        ($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path),
        ($prepare_embeddings:path, $embeddings:path),
        ($prepare_rerank:path, $rerank:path),
    ) => {
        #[async_trait::async_trait]
        impl $crate::client::Client for $crate::client::$client {
            $crate::client_common_fns!();

            async fn chat_completions_inner(
                &self,
                client: &reqwest::Client,
                data: $crate::client::ChatCompletionsData,
            ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions(builder, self.model()).await
            }

            async fn chat_completions_streaming_inner(
                &self,
                client: &reqwest::Client,
                handler: &mut $crate::client::SseHandler,
                data: $crate::client::ChatCompletionsData,
            ) -> anyhow::Result<()> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions_streaming(builder, handler, self.model()).await
            }

            async fn embeddings_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::EmbeddingsData,
            ) -> anyhow::Result<$crate::client::EmbeddingsOutput> {
                let request_data = $prepare_embeddings(self, data)?;
                let builder = self.request_builder(client, request_data);
                $embeddings(builder, self.model()).await
            }

            async fn rerank_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::RerankData,
            ) -> anyhow::Result<$crate::client::RerankOutput> {
                let request_data = $prepare_rerank(self, data)?;
                let builder = self.request_builder(client, request_data);
                $rerank(builder, self.model()).await
            }
        }
    };
}

/// Generates an associated `create_client_config(vault)` that delegates to
/// `$crate::client::create_config` with the given prompt list and `Self::NAME`.
#[macro_export]
macro_rules! create_client_config {
    ($prompts:expr) => {
        pub async fn create_client_config(
            vault: &$crate::vault::Vault,
        ) -> anyhow::Result<(String, serde_json::Value)> {
            $crate::client::create_config(&$prompts, Self::NAME, vault).await
        }
    };
}

/// Generates an associated `create_client_config(vault)` for clients that
/// accept either an API key or OAuth.
///
/// The function prompts for the authentication method:
/// - **API Key** stores a secret named `<NAME>_API_KEY` (uppercased) in the
///   vault and writes a `{{<NAME>_API_KEY}}` placeholder into `api_key`.
/// - **OAuth** sets `auth = "oauth"`.
///
/// It then fills in the model settings and returns `(model, [config])`.
#[macro_export]
macro_rules! create_oauth_supported_client_config {
    () => {
        pub async fn create_client_config(
            vault: &$crate::vault::Vault,
        ) -> anyhow::Result<(String, serde_json::Value)> {
            let mut config = serde_json::json!({ "type": Self::NAME });

            let auth_method = inquire::Select::new(
                "Authentication method:",
                vec!["API Key", "OAuth"],
            )
            .prompt()?;

            if auth_method == "API Key" {
                let env_name = format!("{}_API_KEY", Self::NAME).to_ascii_uppercase();
                vault.add_secret(&env_name)?;
                // Produces the literal text `{{ENV_NAME}}`: a placeholder for the
                // vault secret rather than the secret itself.
                config["api_key"] = format!("{{{{{env_name}}}}}").into();
            } else {
                config["auth"] = "oauth".into();
            }

            let model = $crate::client::set_client_models_config(&mut config, Self::NAME).await?;
            let clients = serde_json::json!(vec![config]);
            Ok((model, clients))
        }
    };
}

/// Generates a getter `fn $fn_name(&self) -> anyhow::Result<String>` for an
/// optional string config field.
///
/// The environment variable `<CLIENT_NAME>_<FIELD_NAME>` (uppercased) takes
/// precedence over `self.config.$field_name`.
///
/// # Errors
/// Returns `Missing '<field_name>'` when neither source provides a value.
#[macro_export]
macro_rules! config_get_fn {
    ($field_name:ident, $fn_name:ident) => {
        fn $fn_name(&self) -> anyhow::Result<String> {
            let env_prefix = Self::name(&self.config);
            let env_name =
                format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
            std::env::var(&env_name)
                .ok()
                .or_else(|| self.config.$field_name.clone())
                .ok_or_else(|| anyhow::anyhow!("Missing '{}'", stringify!($field_name)))
        }
    };
}

/// Returns early from the enclosing function with an
/// `Unsupported model '<name>'` error.
#[macro_export]
macro_rules! unsupported_model {
    ($name:expr) => {
        anyhow::bail!("Unsupported model '{}'", $name)
    };
}