This commit is contained in:
2026-04-16 10:17:03 -06:00
parent ff3419a714
commit d906713d7d
82 changed files with 14886 additions and 3310 deletions
+20 -21
View File
@@ -1,7 +1,8 @@
use super::*;
use crate::config::paths;
use crate::{
config::{Config, GlobalConfig, Input},
config::{AppConfig, Input, RequestContext},
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
render::render_stream,
utils::*,
@@ -24,7 +25,7 @@ use tokio::sync::mpsc::unbounded_channel;
pub const MODELS_YAML: &str = include_str!("../../models.yaml");
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
Config::local_models_override()
paths::local_models_override()
.ok()
.unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
});
@@ -37,7 +38,7 @@ static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/
#[async_trait::async_trait]
pub trait Client: Sync + Send {
fn global_config(&self) -> &GlobalConfig;
fn app_config(&self) -> &AppConfig;
fn extra_config(&self) -> Option<&ExtraConfig>;
@@ -58,7 +59,7 @@ pub trait Client: Sync + Send {
if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
builder = set_proxy(builder, proxy)?;
}
if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
if let Some(user_agent) = self.app_config().user_agent.as_ref() {
builder = builder.user_agent(user_agent);
}
let client = builder
@@ -69,7 +70,7 @@ pub trait Client: Sync + Send {
}
async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
if self.global_config().read().dry_run {
if self.app_config().dry_run {
let content = input.echo_messages();
return Ok(ChatCompletionsOutput::new(&content));
}
@@ -89,7 +90,7 @@ pub trait Client: Sync + Send {
let input = input.clone();
tokio::select! {
ret = async {
if self.global_config().read().dry_run {
if self.app_config().dry_run {
let content = input.echo_messages();
handler.text(&content)?;
return Ok(());
@@ -413,9 +414,10 @@ pub async fn call_chat_completions(
print: bool,
extract_code: bool,
client: &dyn Client,
ctx: &mut RequestContext,
abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
let is_child_agent = client.global_config().read().current_depth > 0;
let is_child_agent = ctx.current_depth() > 0;
let spinner_message = if is_child_agent { "" } else { "Generating" };
let ret = abortable_run_with_spinner(
client.chat_completions(input.clone()),
@@ -436,15 +438,13 @@ pub async fn call_chat_completions(
text = extract_code_block(&strip_think_tag(&text)).to_string();
}
if print {
client.global_config().read().print_markdown(&text)?;
ctx.app.config.print_markdown(&text)?;
}
}
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
tool_results
.iter()
.for_each(|res| tracker.record_call(res.call.clone()));
}
let tool_results = eval_tool_calls(ctx, tool_calls).await?;
tool_results
.iter()
.for_each(|res| ctx.tool_scope.tool_tracker.record_call(res.call.clone()));
Ok((text, tool_results))
}
Err(err) => Err(err),
@@ -454,6 +454,7 @@ pub async fn call_chat_completions(
pub async fn call_chat_completions_streaming(
input: &Input,
client: &dyn Client,
ctx: &mut RequestContext,
abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
let (tx, rx) = unbounded_channel();
@@ -461,7 +462,7 @@ pub async fn call_chat_completions_streaming(
let (send_ret, render_ret) = tokio::join!(
client.chat_completions_streaming(input, &mut handler),
render_stream(rx, client.global_config(), abort_signal.clone()),
render_stream(rx, client.app_config(), abort_signal.clone()),
);
if handler.abort().aborted() {
@@ -476,12 +477,10 @@ pub async fn call_chat_completions_streaming(
if !text.is_empty() && !text.ends_with('\n') {
println!();
}
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
tool_results
.iter()
.for_each(|res| tracker.record_call(res.call.clone()));
}
let tool_results = eval_tool_calls(ctx, tool_calls).await?;
tool_results
.iter()
.for_each(|res| ctx.tool_scope.tool_tracker.record_call(res.call.clone()));
Ok((text, tool_results))
}
Err(err) => {
+11 -12
View File
@@ -24,7 +24,7 @@ macro_rules! register_client {
$(
#[derive(Debug)]
pub struct $client {
global_config: $crate::config::GlobalConfig,
app_config: std::sync::Arc<$crate::config::AppConfig>,
config: $config,
model: $crate::client::Model,
}
@@ -32,8 +32,8 @@ macro_rules! register_client {
impl $client {
pub const NAME: &'static str = $name;
pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
let config = global_config.read().clients.iter().find_map(|client_config| {
pub fn init(app_config: &std::sync::Arc<$crate::config::AppConfig>, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
let config = app_config.clients.iter().find_map(|client_config| {
if let ClientConfig::$config(c) = client_config {
if Self::name(c) == model.client_name() {
return Some(c.clone())
@@ -43,7 +43,7 @@ macro_rules! register_client {
})?;
Some(Box::new(Self {
global_config: global_config.clone(),
app_config: std::sync::Arc::clone(app_config),
config,
model: model.clone(),
}))
@@ -72,10 +72,9 @@ macro_rules! register_client {
)+
pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
let model = model.unwrap_or_else(|| config.read().model.clone());
pub fn init_client(app_config: &std::sync::Arc<$crate::config::AppConfig>, model: $crate::client::Model) -> anyhow::Result<Box<dyn Client>> {
None
$(.or_else(|| $client::init(config, &model)))+
$(.or_else(|| $client::init(app_config, &model)))+
.ok_or_else(|| {
anyhow::anyhow!("Invalid model '{}'", model.id())
})
@@ -101,7 +100,7 @@ macro_rules! register_client {
static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();
pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
pub fn list_client_names(config: &$crate::config::AppConfig) -> Vec<&'static String> {
let names = ALL_CLIENT_NAMES.get_or_init(|| {
config
.clients
@@ -117,7 +116,7 @@ macro_rules! register_client {
static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();
pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
pub fn list_all_models(config: &$crate::config::AppConfig) -> Vec<&'static $crate::client::Model> {
let models = ALL_MODELS.get_or_init(|| {
config
.clients
@@ -131,7 +130,7 @@ macro_rules! register_client {
models.iter().collect()
}
pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
pub fn list_models(config: &$crate::config::AppConfig, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
}
};
@@ -140,8 +139,8 @@ macro_rules! register_client {
#[macro_export]
macro_rules! client_common_fns {
() => {
fn global_config(&self) -> &$crate::config::GlobalConfig {
&self.global_config
fn app_config(&self) -> &$crate::config::AppConfig {
&self.app_config
}
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
+6 -2
View File
@@ -3,7 +3,7 @@ use super::{
message::{Message, MessageContent, MessageContentPart},
};
use crate::config::Config;
use crate::config::AppConfig;
use crate::utils::{estimate_token_length, strip_think_tag};
use anyhow::{Result, bail};
@@ -44,7 +44,11 @@ impl Model {
.collect()
}
pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result<Self> {
pub fn retrieve_model(
config: &AppConfig,
model_id: &str,
model_type: ModelType,
) -> Result<Self> {
let models = list_all_models(config);
let (client_name, model_name) = match model_id.split_once(':') {
Some((client_name, model_name)) => {
+3 -3
View File
@@ -1,6 +1,6 @@
use super::ClientConfig;
use super::access_token::{is_valid_access_token, set_access_token};
use crate::config::Config;
use crate::config::paths;
use anyhow::{Result, anyhow, bail};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
@@ -178,13 +178,13 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
}
pub fn load_oauth_tokens(client_name: &str) -> Option<OAuthTokens> {
let path = Config::token_file(client_name);
let path = paths::token_file(client_name);
let content = fs::read_to_string(path).ok()?;
serde_json::from_str(&content).ok()
}
fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
let path = Config::token_file(client_name);
let path = paths::token_file(client_name);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}