684 lines
20 KiB
Rust
684 lines
20 KiB
Rust
use super::*;
|
|
|
|
use crate::{
|
|
config::{Config, GlobalConfig, Input},
|
|
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
|
|
render::render_stream,
|
|
utils::*,
|
|
};
|
|
|
|
use crate::vault::Vault;
|
|
use anyhow::{Context, Result, bail};
|
|
use fancy_regex::Regex;
|
|
use indexmap::IndexMap;
|
|
use inquire::{
|
|
MultiSelect, Select, Text, list_option::ListOption, required, validator::Validation,
|
|
};
|
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
|
use serde::Deserialize;
|
|
use serde_json::{Value, json};
|
|
use std::sync::LazyLock;
|
|
use std::time::Duration;
|
|
use tokio::sync::mpsc::unbounded_channel;
|
|
|
|
/// Built-in provider/model catalog, embedded at compile time from `models.yaml`.
pub const MODELS_YAML: &str = include_str!("../../models.yaml");
|
|
|
|
/// Parsed model catalog for all providers.
///
/// A local override (see `Config::local_models_override`) takes precedence;
/// otherwise the embedded `MODELS_YAML` is parsed. The `unwrap` is acceptable
/// because the embedded YAML is validated at build/test time.
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
    Config::local_models_override()
        .ok()
        .unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
});
|
|
|
|
/// Heuristic for classifying a model name (lowercased) as an embedding model:
/// common embedding family prefixes (bge-/e5-/uae-/gte-/text-) or keywords.
static EMBEDDING_MODEL_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"((^|/)(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap()
});
|
|
|
|
/// Matches unescaped `/` (negative lookbehind for `\`); used to escape patch-map
/// keys before embedding them into a regex in `patch_request_data`.
static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/").unwrap());
|
|
|
|
/// Common interface implemented by every LLM provider backend.
///
/// Provides default implementations for the high-level operations
/// (chat completions, streaming, embeddings, rerank); implementors supply
/// the provider-specific `*_inner` methods and configuration accessors.
#[async_trait::async_trait]
pub trait Client: Sync + Send {
    /// Shared application configuration.
    fn global_config(&self) -> &GlobalConfig;

    /// Optional per-client transport settings (proxy, connect timeout).
    fn extra_config(&self) -> Option<&ExtraConfig>;

    /// Optional per-client request patches, keyed by API kind.
    fn patch_config(&self) -> Option<&RequestPatch>;

    /// Client (provider) name.
    fn name(&self) -> &str;

    /// Currently selected model for this client.
    fn model(&self) -> &Model;

    /// Whether this client supports OAuth authentication (defaults to no).
    fn supports_oauth(&self) -> bool {
        false
    }

    /// Build a reqwest client honoring proxy, connect timeout (default 10s),
    /// and the globally configured user agent.
    fn build_client(&self) -> Result<ReqwestClient> {
        let mut builder = ReqwestClient::builder();
        let extra = self.extra_config();
        // Default connect timeout: 10 seconds.
        let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
        if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
            builder = set_proxy(builder, proxy)?;
        }
        if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
            builder = builder.user_agent(user_agent);
        }
        let client = builder
            .connect_timeout(Duration::from_secs(timeout))
            .build()
            .with_context(|| "Failed to build client")?;
        Ok(client)
    }

    /// Run a non-streaming chat completion.
    ///
    /// In dry-run mode, returns the echoed input messages without any network call.
    async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
        if self.global_config().read().dry_run {
            let content = input.echo_messages();
            return Ok(ChatCompletionsOutput::new(&content));
        }
        let client = self.build_client()?;
        let data = input.prepare_completion_data(self.model(), false)?;
        self.chat_completions_inner(&client, data)
            .await
            .with_context(|| "Failed to call chat-completions api")
    }

    /// Run a streaming chat completion, feeding events into `handler`.
    ///
    /// Races the request against the handler's abort signal; in both outcomes
    /// `handler.done()` is called exactly once. Aborting yields `Ok(())`.
    async fn chat_completions_streaming(
        &self,
        input: &Input,
        handler: &mut SseHandler,
    ) -> Result<()> {
        let abort_signal = handler.abort();
        let input = input.clone();
        tokio::select! {
            ret = async {
                // Dry-run: echo the input through the handler, no network.
                if self.global_config().read().dry_run {
                    let content = input.echo_messages();
                    handler.text(&content)?;
                    return Ok(());
                }
                let client = self.build_client()?;
                let data = input.prepare_completion_data(self.model(), true)?;
                self.chat_completions_streaming_inner(&client, handler, data).await
            } => {
                handler.done();
                ret.with_context(|| "Failed to call chat-completions api")
            }
            _ = wait_abort_signal(&abort_signal) => {
                handler.done();
                Ok(())
            },
        }
    }

    /// Compute embedding vectors for the given texts.
    async fn embeddings(&self, data: &EmbeddingsData) -> Result<Vec<Vec<f32>>> {
        let client = self.build_client()?;
        self.embeddings_inner(&client, data)
            .await
            .context("Failed to call embeddings api")
    }

    /// Rerank documents against a query.
    async fn rerank(&self, data: &RerankData) -> Result<RerankOutput> {
        let client = self.build_client()?;
        self.rerank_inner(&client, data)
            .await
            .context("Failed to call rerank api")
    }

    /// Provider-specific non-streaming chat completion (required).
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput>;

    /// Provider-specific streaming chat completion (required).
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()>;

    /// Provider-specific embeddings call; defaults to "unsupported".
    async fn embeddings_inner(
        &self,
        _client: &ReqwestClient,
        _data: &EmbeddingsData,
    ) -> Result<EmbeddingsOutput> {
        bail!("The client doesn't support embeddings api")
    }

    /// Provider-specific rerank call; defaults to "unsupported".
    async fn rerank_inner(
        &self,
        _client: &ReqwestClient,
        _data: &RerankData,
    ) -> Result<RerankOutput> {
        bail!("The client doesn't support rerank api")
    }

    /// Apply configured patches to `request_data`, then convert it into a
    /// ready-to-send `RequestBuilder`.
    fn request_builder(
        &self,
        client: &reqwest::Client,
        mut request_data: RequestData,
    ) -> RequestBuilder {
        self.patch_request_data(&mut request_data);
        request_data.into_builder(client)
    }

    /// Apply request patches, in precedence order:
    /// 1. the model's own patch (always applied first, if present);
    /// 2. a patch map from the `PATCH_<CLIENT>_<API>` env var (JSON), falling
    ///    back to the client's `patch_config`;
    /// 3. within the map, the first key whose regex matches the model name wins.
    fn patch_request_data(&self, request_data: &mut RequestData) {
        let model_type = self.model().model_type();
        if let Some(patch) = self.model().patch() {
            request_data.apply_patch(patch.clone());
        }

        let patch_map = std::env::var(get_env_name(&format!(
            "patch_{}_{}",
            self.model().client_name(),
            model_type.api_name(),
        )))
        .ok()
        .and_then(|v| serde_json::from_str(&v).ok())
        .or_else(|| {
            self.patch_config()
                .and_then(|v| model_type.extract_patch(v))
                .cloned()
        });
        let patch_map = match patch_map {
            Some(v) => v,
            _ => return,
        };
        for (key, patch) in patch_map {
            // Escape bare slashes so the key can be embedded in a regex safely.
            let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
            if let Ok(regex) = Regex::new(&format!("^({key})$"))
                && let Ok(true) = regex.is_match(self.model().name())
            {
                request_data.apply_patch(patch);
                return;
            }
        }
    }
}
|
|
|
|
impl Default for ClientConfig {
|
|
fn default() -> Self {
|
|
Self::OpenAIConfig(OpenAIConfig::default())
|
|
}
|
|
}
|
|
|
|
/// Optional transport-level settings shared by all client kinds.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
    // Proxy URL passed to `set_proxy` when building the HTTP client.
    pub proxy: Option<String>,
    // Connect timeout in seconds; `build_client` defaults to 10 when unset.
    pub connect_timeout: Option<u64>,
}
|
|
|
|
/// Per-API request patches for a client; each field maps model-name regex
/// patterns to patch payloads (see `Client::patch_request_data`).
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RequestPatch {
    pub chat_completions: Option<ApiPatch>,
    pub embeddings: Option<ApiPatch>,
    pub rerank: Option<ApiPatch>,
}
|
|
|
|
/// Ordered map from a model-name regex pattern to the patch value to apply.
pub type ApiPatch = IndexMap<String, Value>;
|
|
|
|
/// An HTTP request in provider-neutral form: target URL, headers, JSON body.
/// Converted into a `reqwest::RequestBuilder` via `into_builder` (always POST).
pub struct RequestData {
    pub url: String,
    pub headers: IndexMap<String, String>,
    pub body: Value,
}
|
|
|
|
impl RequestData {
|
|
pub fn new<T>(url: T, body: Value) -> Self
|
|
where
|
|
T: std::fmt::Display,
|
|
{
|
|
Self {
|
|
url: url.to_string(),
|
|
headers: Default::default(),
|
|
body,
|
|
}
|
|
}
|
|
|
|
pub fn bearer_auth<T>(&mut self, auth: T)
|
|
where
|
|
T: std::fmt::Display,
|
|
{
|
|
self.headers
|
|
.insert("authorization".into(), format!("Bearer {auth}"));
|
|
}
|
|
|
|
pub fn header<K, V>(&mut self, key: K, value: V)
|
|
where
|
|
K: std::fmt::Display,
|
|
V: std::fmt::Display,
|
|
{
|
|
self.headers.insert(key.to_string(), value.to_string());
|
|
}
|
|
|
|
pub fn into_builder(self, client: &ReqwestClient) -> RequestBuilder {
|
|
let RequestData { url, headers, body } = self;
|
|
debug!("Request {url} {body}");
|
|
|
|
let mut builder = client.post(url);
|
|
for (key, value) in headers {
|
|
builder = builder.header(key, value);
|
|
}
|
|
builder = builder.json(&body);
|
|
builder
|
|
}
|
|
|
|
pub fn apply_patch(&mut self, patch: Value) {
|
|
if let Some(patch_url) = patch["url"].as_str() {
|
|
self.url = patch_url.into();
|
|
}
|
|
if let Some(patch_body) = patch.get("body") {
|
|
json_patch::merge(&mut self.body, patch_body)
|
|
}
|
|
if let Some(patch_headers) = patch["headers"].as_object() {
|
|
for (key, value) in patch_headers {
|
|
if let Some(value) = value.as_str() {
|
|
self.header(key, value)
|
|
} else if value.is_null() {
|
|
self.headers.swap_remove(key);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Provider-neutral payload for a chat-completions call.
#[derive(Debug)]
pub struct ChatCompletionsData {
    pub messages: Vec<Message>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    // Tool/function declarations offered to the model, if any.
    pub functions: Option<Vec<FunctionDeclaration>>,
    // Whether the response should be streamed (SSE).
    pub stream: bool,
}
|
|
|
|
/// Result of a chat-completions call: generated text plus any tool calls
/// the model requested.
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionsOutput {
    pub text: String,
    pub tool_calls: Vec<ToolCall>,
}
|
|
|
|
impl ChatCompletionsOutput {
|
|
pub fn new(text: &str) -> Self {
|
|
Self {
|
|
text: text.to_string(),
|
|
..Default::default()
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Payload for an embeddings call.
#[derive(Debug)]
pub struct EmbeddingsData {
    pub texts: Vec<String>,
    // True when the texts are search queries (some providers embed queries
    // and documents differently).
    pub query: bool,
}
|
|
|
|
impl EmbeddingsData {
|
|
pub fn new(texts: Vec<String>, query: bool) -> Self {
|
|
Self { texts, query }
|
|
}
|
|
}
|
|
|
|
/// One embedding vector per input text.
pub type EmbeddingsOutput = Vec<Vec<f32>>;
|
|
|
|
/// Payload for a rerank call.
#[derive(Debug)]
pub struct RerankData {
    pub query: String,
    pub documents: Vec<String>,
    // Number of top results requested from the reranker.
    pub top_n: usize,
}
|
|
|
|
impl RerankData {
|
|
pub fn new(query: String, documents: Vec<String>, top_n: usize) -> Self {
|
|
Self {
|
|
query,
|
|
documents,
|
|
top_n,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Rerank results, ordered as returned by the provider.
pub type RerankOutput = Vec<RerankResult>;
|
|
|
|
/// A single rerank hit: the index of the document in the original input list.
#[derive(Debug, Deserialize)]
pub struct RerankResult {
    pub index: usize,
}
|
|
|
|
/// One interactive setup prompt: (config key, description, help message, is_secret).
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool);
|
|
|
|
pub async fn create_config(
|
|
prompts: &[PromptAction<'static>],
|
|
client: &str,
|
|
vault: &Vault,
|
|
) -> Result<(String, Value)> {
|
|
let mut config = json!({
|
|
"type": client,
|
|
});
|
|
for (key, desc, help_message, is_secret) in prompts {
|
|
let env_name = format!("{client}_{key}").to_ascii_uppercase();
|
|
let required = std::env::var(&env_name).is_err();
|
|
let value = if !is_secret {
|
|
prompt_input_string(desc, required, *help_message)?
|
|
} else {
|
|
vault.add_secret(&env_name)?;
|
|
format!("{{{{{}}}}}", env_name)
|
|
};
|
|
if !value.is_empty() {
|
|
config[key] = value.into();
|
|
}
|
|
}
|
|
let model = set_client_models_config(&mut config, client).await?;
|
|
let clients = json!(vec![config]);
|
|
Ok((model, clients))
|
|
}
|
|
|
|
pub async fn create_openai_compatible_client_config(
|
|
client: &str,
|
|
) -> Result<Option<(String, Value)>> {
|
|
let api_base = OPENAI_COMPATIBLE_PROVIDERS
|
|
.into_iter()
|
|
.find(|(name, _)| client == *name)
|
|
.map(|(_, api_base)| api_base)
|
|
.unwrap_or("http(s)://{API_ADDR}/v1");
|
|
|
|
let name = if client == OpenAICompatibleClient::NAME {
|
|
let value = prompt_input_string("Provider Name", true, None)?;
|
|
value.replace(' ', "-")
|
|
} else {
|
|
client.to_string()
|
|
};
|
|
|
|
let mut config = json!({
|
|
"type": OpenAICompatibleClient::NAME,
|
|
"name": &name,
|
|
});
|
|
|
|
let api_base = if api_base.contains('{') {
|
|
prompt_input_string("API Base", true, Some(&format!("e.g. {api_base}")))?
|
|
} else {
|
|
api_base.to_string()
|
|
};
|
|
config["api_base"] = api_base.into();
|
|
|
|
let api_key = prompt_input_string("API Key", false, None)?;
|
|
if !api_key.is_empty() {
|
|
config["api_key"] = api_key.into();
|
|
}
|
|
|
|
let model = set_client_models_config(&mut config, &name).await?;
|
|
let clients = json!(vec![config]);
|
|
Ok(Some((model, clients)))
|
|
}
|
|
|
|
pub async fn call_chat_completions(
|
|
input: &Input,
|
|
print: bool,
|
|
extract_code: bool,
|
|
client: &dyn Client,
|
|
abort_signal: AbortSignal,
|
|
) -> Result<(String, Vec<ToolResult>)> {
|
|
let is_child_agent = client.global_config().read().current_depth > 0;
|
|
let spinner_message = if is_child_agent { "" } else { "Generating" };
|
|
let ret = abortable_run_with_spinner(
|
|
client.chat_completions(input.clone()),
|
|
spinner_message,
|
|
abort_signal,
|
|
)
|
|
.await;
|
|
|
|
match ret {
|
|
Ok(ret) => {
|
|
let ChatCompletionsOutput {
|
|
mut text,
|
|
tool_calls,
|
|
..
|
|
} = ret;
|
|
if !text.is_empty() {
|
|
if extract_code {
|
|
text = extract_code_block(&strip_think_tag(&text)).to_string();
|
|
}
|
|
if print {
|
|
client.global_config().read().print_markdown(&text)?;
|
|
}
|
|
}
|
|
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
|
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
|
tool_results
|
|
.iter()
|
|
.for_each(|res| tracker.record_call(res.call.clone()));
|
|
}
|
|
Ok((text, tool_results))
|
|
}
|
|
Err(err) => Err(err),
|
|
}
|
|
}
|
|
|
|
/// Run a streaming chat completion, rendering the stream as it arrives.
///
/// Drives the provider request and the terminal renderer concurrently over an
/// unbounded channel. On success, evaluates any tool calls and records them in
/// the tracker; on failure, preserves partial output before propagating the
/// error. Bails with "Aborted." if the user aborted mid-stream.
pub async fn call_chat_completions_streaming(
    input: &Input,
    client: &dyn Client,
    abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
    let (tx, rx) = unbounded_channel();
    let mut handler = SseHandler::new(tx, abort_signal.clone());

    // Producer (API request) and consumer (renderer) must run together:
    // the renderer drains the channel the handler writes into.
    let (send_ret, render_ret) = tokio::join!(
        client.chat_completions_streaming(input, &mut handler),
        render_stream(rx, client.global_config(), abort_signal.clone()),
    );

    if handler.abort().aborted() {
        bail!("Aborted.");
    }

    // Surface renderer failures before inspecting the request result.
    render_ret?;

    let (text, tool_calls) = handler.take();
    match send_ret {
        Ok(_) => {
            // Keep the prompt on its own line after streamed output.
            if !text.is_empty() && !text.ends_with('\n') {
                println!();
            }
            let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
            if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
                tool_results
                    .iter()
                    .for_each(|res| tracker.record_call(res.call.clone()));
            }
            Ok((text, tool_results))
        }
        Err(err) => {
            // Partial output was already rendered; terminate its line.
            if !text.is_empty() {
                println!();
            }
            Err(err)
        }
    }
}
|
|
|
|
/// Placeholder rerank request builder for clients without rerank support;
/// always fails.
pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
    bail!("The client doesn't support rerank api")
}
|
|
|
|
/// Placeholder rerank executor for clients without rerank support; always fails.
pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
    bail!("The client doesn't support rerank api")
}
|
|
|
|
/// Turn a non-2xx API response into an error with the best message available.
///
/// Providers disagree on error payload shape, so several known layouts are
/// probed in order; the first match wins. Falls back to dumping the raw body.
///
/// # Errors
/// Always returns `Err` for a non-2xx `status`.
pub fn catch_error(data: &Value, status: u16) -> Result<()> {
    if (200..300).contains(&status) {
        return Ok(());
    }
    debug!("Invalid response, status: {status}, data: {data}");
    // Shape: { "error": { "type"|"code", "message" } } (OpenAI-style).
    if let Some(error) = data["error"].as_object() {
        if let (Some(typ), Some(message)) = (
            json_str_from_map(error, "type"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (type: {typ})");
        } else if let (Some(typ), Some(message)) = (
            json_str_from_map(error, "code"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (code: {typ})");
        }
    // Shape: { "errors": [ { "code": <u64>, "message" } ] }.
    } else if let Some(error) = data["errors"][0].as_object() {
        if let (Some(code), Some(message)) = (
            error.get("code").and_then(|v| v.as_u64()),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (status: {code})")
        }
    // Shape: [ { "error": { "status", "message" } } ] (array-wrapped).
    } else if let Some(error) = data[0]["error"].as_object() {
        if let (Some(status), Some(message)) = (
            json_str_from_map(error, "status"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (status: {status})")
        }
    // Shape: { "detail": <str>, "status": <i64> }.
    } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
    {
        bail!("{detail} (status: {status})");
    // Shape: { "error": <str> }.
    } else if let Some(error) = data["error"].as_str() {
        bail!("{error}");
    // Shape: { "message": <str> }.
    } else if let Some(message) = data["message"].as_str() {
        bail!("{message}");
    }
    // Unknown shape: include the whole body for diagnosis.
    bail!("Invalid response data: {data} (status: {status})");
}
|
|
|
|
pub fn json_str_from_map<'a>(
|
|
map: &'a serde_json::Map<String, Value>,
|
|
field_name: &str,
|
|
) -> Option<&'a str> {
|
|
map.get(field_name).and_then(|v| v.as_str())
|
|
}
|
|
|
|
/// Populate `client_config["models"]` and pick the default model.
///
/// For known providers, offers the catalog's chat models. Otherwise, for
/// OpenAI-compatible clients, tries to fetch the model list from the API
/// (falling back to the `<CLIENT>_API_KEY` env var for auth) and lets the
/// user multi-select; if fetching fails or is unavailable, asks for a
/// comma-separated list. Model kind (chat/embedding/reranker/vision) is
/// inferred from the name. Returns `"<client>:<model>"`.
pub async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
    // Known provider: pick from the built-in catalog's chat models.
    if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
        let models: Vec<String> = provider
            .models
            .iter()
            .filter(|v| v.model_type == "chat")
            .map(|v| v.name.clone())
            .collect();
        let model_name = select_model(models)?;
        return Ok(format!("{client}:{model_name}"));
    }
    let mut model_names = vec![];
    // Only OpenAI-compatible clients with a known api_base can list models.
    if let (Some(true), Some(api_base), api_key) = (
        client_config["type"]
            .as_str()
            .map(|v| v == OpenAICompatibleClient::NAME),
        client_config["api_base"].as_str(),
        client_config["api_key"]
            .as_str()
            .map(|v| v.to_string())
            .or_else(|| {
                // Fall back to the conventional <CLIENT>_API_KEY env var.
                let env_name = format!("{client}_api_key").to_ascii_uppercase();
                std::env::var(&env_name).ok()
            }),
    ) {
        match abortable_run_with_spinner(
            fetch_models(api_base, api_key.as_deref()),
            "Fetching models",
            create_abort_signal(),
        )
        .await
        {
            Ok(fetched_models) => {
                model_names = MultiSelect::new("LLMs to include (required):", fetched_models)
                    .with_validator(|list: &[ListOption<&String>]| {
                        if list.is_empty() {
                            Ok(Validation::Invalid(
                                "At least one item must be selected".into(),
                            ))
                        } else {
                            Ok(Validation::Valid)
                        }
                    })
                    .prompt()?;
            }
            Err(err) => {
                // Non-fatal: fall through to manual entry below.
                eprintln!("✗ Fetch models failed: {err}");
            }
        }
    }
    // Manual entry when nothing was fetched or selected.
    if model_names.is_empty() {
        model_names = prompt_input_string(
            "LLMs to add",
            true,
            Some("Separated by commas, e.g. llama3.3,qwen2.5"),
        )?
        .split(',')
        .filter_map(|v| {
            let v = v.trim();
            if v.is_empty() {
                None
            } else {
                Some(v.to_string())
            }
        })
        .collect::<Vec<_>>();
    }
    if model_names.is_empty() {
        bail!("No models");
    }
    // Infer each model's kind from its (lowercased) name.
    let models: Vec<Value> = model_names
        .iter()
        .map(|v| {
            let l = v.to_lowercase();
            if l.contains("rank") {
                json!({
                    "name": v,
                    "type": "reranker",
                })
            } else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) {
                json!({
                    "name": v,
                    "type": "embedding",
                    "default_chunk_size": 1000,
                    "max_batch_size": 100
                })
            } else if v.contains("vision") {
                json!({
                    "name": v,
                    "supports_vision": true
                })
            } else {
                json!({
                    "name": v,
                })
            }
        })
        .collect();
    client_config["models"] = models.into();
    let model_name = select_model(model_names)?;
    Ok(format!("{client}:{model_name}"))
}
|
|
|
|
fn select_model(model_names: Vec<String>) -> Result<String> {
|
|
if model_names.is_empty() {
|
|
bail!("No models");
|
|
}
|
|
let model = if model_names.len() == 1 {
|
|
model_names[0].clone()
|
|
} else {
|
|
Select::new("Default Model (required):", model_names).prompt()?
|
|
};
|
|
Ok(model)
|
|
}
|
|
|
|
fn prompt_input_string(desc: &str, required: bool, help_message: Option<&str>) -> Result<String> {
|
|
let desc = if required {
|
|
format!("{desc} (required):")
|
|
} else {
|
|
format!("{desc} (optional):")
|
|
};
|
|
let mut text = Text::new(&desc);
|
|
if required {
|
|
text = text.with_validator(required!("This field is required"))
|
|
}
|
|
if let Some(help_message) = help_message {
|
|
text = text.with_help_message(help_message);
|
|
}
|
|
let text = text.prompt()?;
|
|
Ok(text)
|
|
}
|