Files
loki/src/client/macros.rs

280 lines
9.9 KiB
Rust

/// Registers every supported client backend in one place.
///
/// Each entry is a `(module, type-name literal, config type, client type)` tuple.
/// For the whole list the macro generates:
///
/// * a `mod` declaration per entry and a `use` of its config type,
/// * the `ClientConfig` enum, deserialized from the `type` tag of each configured client,
/// * one client struct per entry with `init`, `list_models` and `name` helpers,
/// * the free functions `init_client`, `list_client_types`, `create_client_config`,
///   `list_client_names`, `list_all_models` and `list_models`.
///
/// All project items are referenced through `$crate::` so the expansion does not depend on
/// what the invoking module happens to have imported. The one exception is
/// `create_openai_compatible_client_config`, which must be in scope at the call site.
#[macro_export]
macro_rules! register_client {
    (
        $(($module:ident, $name:literal, $config:ident, $client:ident),)+
    ) => {
        $(
            mod $module;
        )+
        $(
            use self::$module::$config;
        )+

        /// Configuration of a single client, selected by its `type` tag.
        #[derive(Debug, Clone, serde::Deserialize)]
        #[serde(tag = "type")]
        pub enum ClientConfig {
            $(
                #[serde(rename = $name)]
                $config($config),
            )+
            /// Any `type` tag that matches no registered client.
            #[serde(other)]
            Unknown,
        }

        $(
            #[derive(Debug)]
            pub struct $client {
                global_config: $crate::config::GlobalConfig,
                config: $config,
                model: $crate::client::Model,
            }

            impl $client {
                /// The `type` tag under which this client is registered.
                pub const NAME: &'static str = $name;

                /// Builds this client for `model`.
                ///
                /// Returns `None` unless a configured client of this type has a name equal to
                /// `model.client_name()`.
                pub fn init(
                    global_config: &$crate::config::GlobalConfig,
                    model: &$crate::client::Model,
                ) -> Option<Box<dyn $crate::client::Client>> {
                    let config = global_config.read().clients.iter().find_map(|client_config| {
                        match client_config {
                            ClientConfig::$config(c) if Self::name(c) == model.client_name() => {
                                Some(c.clone())
                            }
                            _ => None,
                        }
                    })?;
                    Some(Box::new(Self {
                        global_config: global_config.clone(),
                        config,
                        model: model.clone(),
                    }))
                }

                /// Lists the models available through `local_config`.
                ///
                /// Explicitly configured models win. Otherwise the built-in provider table is
                /// searched for an entry whose provider equals this client's type name or, for
                /// the OpenAI-compatible client only, whose provider is a prefix of the
                /// configured client name. Returns an empty list when nothing matches.
                pub fn list_models(local_config: &$config) -> Vec<$crate::client::Model> {
                    let client_name = Self::name(local_config);
                    if !local_config.models.is_empty() {
                        return $crate::client::Model::from_config(client_name, &local_config.models);
                    }
                    $crate::client::ALL_PROVIDER_MODELS
                        .iter()
                        .find(|v| {
                            v.provider == $name
                                || ($name == $crate::client::OpenAICompatibleClient::NAME
                                    && local_config
                                        .name
                                        .as_ref()
                                        .map(|name| name.starts_with(&v.provider))
                                        .unwrap_or_default())
                        })
                        .map(|v| $crate::client::Model::from_config(client_name, &v.models))
                        .unwrap_or_default()
                }

                /// Returns the configured client name, falling back to [`Self::NAME`].
                pub fn name(local_config: &$config) -> &str {
                    local_config.name.as_deref().unwrap_or(Self::NAME)
                }
            }
        )+

        /// Creates the client that serves `model`, defaulting to the model stored in `config`.
        ///
        /// Registered clients are tried in registration order; the first whose `init`
        /// succeeds is returned.
        ///
        /// # Errors
        ///
        /// Fails with `Invalid model '<id>'` when no registered client accepts the model.
        pub fn init_client(
            config: &$crate::config::GlobalConfig,
            model: Option<$crate::client::Model>,
        ) -> anyhow::Result<Box<dyn $crate::client::Client>> {
            let model = model.unwrap_or_else(|| config.read().model.clone());
            None
                $(.or_else(|| $client::init(config, &model)))+
                .ok_or_else(|| anyhow::anyhow!("Invalid model '{}'", model.id()))
        }

        /// Lists every selectable client type: the registered type names followed by the
        /// names in `OPENAI_COMPATIBLE_PROVIDERS`.
        pub fn list_client_types() -> Vec<&'static str> {
            let mut client_types: Vec<_> = vec![$($client::NAME,)+];
            client_types.extend(
                $crate::client::OPENAI_COMPATIBLE_PROVIDERS.iter().map(|(name, _)| *name),
            );
            client_types
        }

        /// Runs the configuration flow for the client type `client`.
        ///
        /// Registered clients other than the OpenAI-compatible one use their own
        /// `create_client_config`; everything else is offered to
        /// `create_openai_compatible_client_config`.
        ///
        /// # Errors
        ///
        /// Propagates errors from the selected flow and fails with `Unknown client '<name>'`
        /// when no flow recognises `client`.
        pub async fn create_client_config(
            client: &str,
            vault: &$crate::vault::Vault,
        ) -> anyhow::Result<(String, serde_json::Value)> {
            $(
                if client == $client::NAME
                    && client != $crate::client::OpenAICompatibleClient::NAME
                {
                    return $client::create_client_config(vault).await
                }
            )+
            if let Some(ret) = create_openai_compatible_client_config(client).await? {
                return Ok(ret);
            }
            anyhow::bail!("Unknown client '{}'", client)
        }

        static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();

        /// Returns the names of all configured clients.
        ///
        /// The list is computed from `config` on the first call and cached for the lifetime
        /// of the process; later calls return the cached list regardless of their argument.
        pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
            let names = ALL_CLIENT_NAMES.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .filter_map(|v| match v {
                        $(ClientConfig::$config(c) => Some($client::name(c).to_string()),)+
                        ClientConfig::Unknown => None,
                    })
                    .collect()
            });
            names.iter().collect()
        }

        static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();

        /// Returns the models of all configured clients.
        ///
        /// Like [`list_client_names`], the result is computed once from the first `config`
        /// passed in and then served from a process-wide cache.
        pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
            let models = ALL_MODELS.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .flat_map(|v| match v {
                        $(ClientConfig::$config(c) => $client::list_models(c),)+
                        ClientConfig::Unknown => vec![],
                    })
                    .collect()
            });
            models.iter().collect()
        }

        /// Returns the cached models whose `model_type()` equals `model_type`.
        pub fn list_models(
            config: &$crate::config::Config,
            model_type: $crate::client::ModelType,
        ) -> Vec<&'static $crate::client::Model> {
            list_all_models(config)
                .into_iter()
                .filter(|v| v.model_type() == model_type)
                .collect()
        }
    };
}
/// Expands to the accessor methods shared by every client implementation.
///
/// Intended for use inside an `impl` of the client trait on a type that has
/// `global_config`, `config` and `model` fields plus an inherent associated function
/// `name(&config) -> &str`. Every project type is referenced through `$crate::`, so the
/// expansion compiles regardless of the imports at the call site.
#[macro_export]
macro_rules! client_common_fns {
    () => {
        /// Returns the global configuration this client was created with.
        fn global_config(&self) -> &$crate::config::GlobalConfig {
            &self.global_config
        }

        /// Returns the client's `extra` configuration section, if one was set.
        fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
            self.config.extra.as_ref()
        }

        /// Returns the client's request `patch` section, if one was set.
        fn patch_config(&self) -> Option<&$crate::client::RequestPatch> {
            self.config.patch.as_ref()
        }

        /// Returns the client's name, as resolved by the inherent `name` function.
        fn name(&self) -> &str {
            // Inherent associated functions take priority over trait methods in path
            // resolution, so this calls the type's own `name`, not this method.
            Self::name(&self.config)
        }

        /// Returns the model this client instance is bound to.
        fn model(&self) -> &$crate::client::Model {
            &self.model
        }
    };
}
/// Implements `$crate::client::Client` for a registered client type.
///
/// Arguments:
///
/// * the client type name (resolved as `$crate::client::<name>`),
/// * `(prepare, call, call_streaming)` paths for chat completions,
/// * `(prepare, call)` paths for embeddings,
/// * `(prepare, call)` paths for reranking.
///
/// Every generated method follows the same steps: the `prepare` function turns the input
/// into request data, `self.request_builder` turns that into a request builder, and the
/// `call` function executes it against `self.model()`.
///
/// All return types are spelled `anyhow::Result` and the shared accessors are pulled in
/// through `$crate::client_common_fns!`, so the expansion does not rely on a `Result`
/// alias or a textual macro import being present at the call site.
#[macro_export]
macro_rules! impl_client_trait {
    (
        $client:ident,
        ($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path),
        ($prepare_embeddings:path, $embeddings:path),
        ($prepare_rerank:path, $rerank:path),
    ) => {
        #[async_trait::async_trait]
        impl $crate::client::Client for $crate::client::$client {
            // Shared accessors: global_config, extra_config, patch_config, name, model.
            $crate::client_common_fns!();

            /// Sends a non-streaming chat-completions request.
            async fn chat_completions_inner(
                &self,
                client: &reqwest::Client,
                data: $crate::client::ChatCompletionsData,
            ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions(builder, self.model()).await
            }

            /// Sends a streaming chat-completions request, passing `handler` to the
            /// streaming call.
            async fn chat_completions_streaming_inner(
                &self,
                client: &reqwest::Client,
                handler: &mut $crate::client::SseHandler,
                data: $crate::client::ChatCompletionsData,
            ) -> anyhow::Result<()> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions_streaming(builder, handler, self.model()).await
            }

            /// Sends an embeddings request.
            async fn embeddings_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::EmbeddingsData,
            ) -> anyhow::Result<$crate::client::EmbeddingsOutput> {
                let request_data = $prepare_embeddings(self, data)?;
                let builder = self.request_builder(client, request_data);
                $embeddings(builder, self.model()).await
            }

            /// Sends a rerank request.
            async fn rerank_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::RerankData,
            ) -> anyhow::Result<$crate::client::RerankOutput> {
                let request_data = $prepare_rerank(self, data)?;
                let builder = self.request_builder(client, request_data);
                $rerank(builder, self.model()).await
            }
        }
    };
}
/// Generates an associated `create_client_config` async function for a client type.
///
/// `$prompts` is any expression; a reference to it is forwarded, together with
/// `Self::NAME` and the vault, to `$crate::client::create_config`, whose result is
/// returned unchanged. The macro must therefore be invoked inside an `impl` block whose
/// type defines a `NAME` constant.
///
/// NOTE(review): the returned pair presumably mirrors the `(model, clients)` pair built by
/// `create_oauth_supported_client_config` — confirm against `create_config`.
#[macro_export]
macro_rules! create_client_config {
($prompts:expr) => {
pub async fn create_client_config(
vault: &$crate::vault::Vault,
) -> anyhow::Result<(String, serde_json::Value)> {
// `&$prompts` is borrowed inside the awaited call expression, so any temporary it
// creates stays alive until the future has completed.
$crate::client::create_config(&$prompts, Self::NAME, vault).await
}
};
}
/// Generates an interactive `create_client_config` for clients that can authenticate with
/// either an API key or OAuth.
///
/// The generated function must live in an `impl` block whose type defines `NAME`. It:
///
/// 1. starts a JSON config object `{ "type": Self::NAME }`,
/// 2. asks the user to pick an authentication method,
/// 3. for "API Key", registers a secret named `<NAME>_API_KEY` (upper-cased) in the vault
///    and stores a `{{<NAME>_API_KEY}}` placeholder under `api_key`; otherwise stores
///    `"oauth"` under `auth`,
/// 4. lets `$crate::client::set_client_models_config` fill in the model settings,
/// 5. returns the value produced by step 4 together with a one-element JSON array holding
///    the config.
///
/// # Errors
///
/// Propagates failures from the prompt, the vault and the model configuration step.
#[macro_export]
macro_rules! create_oauth_supported_client_config {
    () => {
        pub async fn create_client_config(
            vault: &$crate::vault::Vault,
        ) -> anyhow::Result<(String, serde_json::Value)> {
            // Shared between the prompt options and the branch below so they cannot drift.
            const API_KEY_CHOICE: &str = "API Key";
            const OAUTH_CHOICE: &str = "OAuth";

            let mut config = serde_json::json!({ "type": Self::NAME });
            let auth_method = inquire::Select::new(
                "Authentication method:",
                vec![API_KEY_CHOICE, OAUTH_CHOICE],
            )
            .prompt()?;

            if auth_method == API_KEY_CHOICE {
                let env_name = format!("{}_API_KEY", Self::NAME).to_ascii_uppercase();
                vault.add_secret(&env_name)?;
                // Escaped braces: the stored value is the literal text `{{ENV_NAME}}`.
                config["api_key"] = format!("{{{{{env_name}}}}}").into();
            } else {
                config["auth"] = "oauth".into();
            }

            let model = $crate::client::set_client_models_config(&mut config, Self::NAME).await?;
            // Built directly instead of through a bare `json!`, which would require the
            // call site to import `serde_json::json`.
            let clients = serde_json::Value::Array(vec![config]);
            Ok((model, clients))
        }
    };
}
/// Generates a getter `$fn_name(&self) -> anyhow::Result<String>` for the optional config
/// field `$field_name`.
///
/// Lookup order:
///
/// 1. the environment variable `<CLIENT NAME>_<FIELD NAME>`, upper-cased, where the client
///    name comes from `Self::name(&self.config)`;
/// 2. the value stored in `self.config.<field>`.
///
/// # Errors
///
/// Fails with `Missing '<field>'` when neither source provides a value.
#[macro_export]
macro_rules! config_get_fn {
    ($field_name:ident, $fn_name:ident) => {
        fn $fn_name(&self) -> anyhow::Result<String> {
            let field = stringify!($field_name);
            let env_name = format!("{}_{}", Self::name(&self.config), field).to_ascii_uppercase();

            // The environment wins over the configured value; an unset or unreadable
            // variable simply falls through.
            if let Ok(from_env) = std::env::var(&env_name) {
                return Ok(from_env);
            }

            match self.config.$field_name.clone() {
                Some(configured) => Ok(configured),
                None => Err(anyhow::anyhow!("Missing '{}'", field)),
            }
        }
    };
}
/// Returns early from the enclosing function with the error `Unsupported model '<name>'`.
///
/// `$name` is any expression that can be formatted with `{}`. The enclosing function must
/// return a `Result` whose error type is `anyhow::Error`.
#[macro_export]
macro_rules! unsupported_model {
    ($name:expr) => {
        // Spelled-out form of `anyhow::bail!`.
        return std::result::Result::Err(anyhow::anyhow!("Unsupported model '{}'", $name))
    };
}