testing
This commit is contained in:
+20
-21
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
|
||||
use crate::config::paths;
|
||||
use crate::{
|
||||
config::{Config, GlobalConfig, Input},
|
||||
config::{AppConfig, Input, RequestContext},
|
||||
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
|
||||
render::render_stream,
|
||||
utils::*,
|
||||
@@ -24,7 +25,7 @@ use tokio::sync::mpsc::unbounded_channel;
|
||||
pub const MODELS_YAML: &str = include_str!("../../models.yaml");
|
||||
|
||||
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
|
||||
Config::local_models_override()
|
||||
paths::local_models_override()
|
||||
.ok()
|
||||
.unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
|
||||
});
|
||||
@@ -37,7 +38,7 @@ static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Client: Sync + Send {
|
||||
fn global_config(&self) -> &GlobalConfig;
|
||||
fn app_config(&self) -> &AppConfig;
|
||||
|
||||
fn extra_config(&self) -> Option<&ExtraConfig>;
|
||||
|
||||
@@ -58,7 +59,7 @@ pub trait Client: Sync + Send {
|
||||
if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
|
||||
builder = set_proxy(builder, proxy)?;
|
||||
}
|
||||
if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
|
||||
if let Some(user_agent) = self.app_config().user_agent.as_ref() {
|
||||
builder = builder.user_agent(user_agent);
|
||||
}
|
||||
let client = builder
|
||||
@@ -69,7 +70,7 @@ pub trait Client: Sync + Send {
|
||||
}
|
||||
|
||||
async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
|
||||
if self.global_config().read().dry_run {
|
||||
if self.app_config().dry_run {
|
||||
let content = input.echo_messages();
|
||||
return Ok(ChatCompletionsOutput::new(&content));
|
||||
}
|
||||
@@ -89,7 +90,7 @@ pub trait Client: Sync + Send {
|
||||
let input = input.clone();
|
||||
tokio::select! {
|
||||
ret = async {
|
||||
if self.global_config().read().dry_run {
|
||||
if self.app_config().dry_run {
|
||||
let content = input.echo_messages();
|
||||
handler.text(&content)?;
|
||||
return Ok(());
|
||||
@@ -413,9 +414,10 @@ pub async fn call_chat_completions(
|
||||
print: bool,
|
||||
extract_code: bool,
|
||||
client: &dyn Client,
|
||||
ctx: &mut RequestContext,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<(String, Vec<ToolResult>)> {
|
||||
let is_child_agent = client.global_config().read().current_depth > 0;
|
||||
let is_child_agent = ctx.current_depth() > 0;
|
||||
let spinner_message = if is_child_agent { "" } else { "Generating" };
|
||||
let ret = abortable_run_with_spinner(
|
||||
client.chat_completions(input.clone()),
|
||||
@@ -436,15 +438,13 @@ pub async fn call_chat_completions(
|
||||
text = extract_code_block(&strip_think_tag(&text)).to_string();
|
||||
}
|
||||
if print {
|
||||
client.global_config().read().print_markdown(&text)?;
|
||||
ctx.app.config.print_markdown(&text)?;
|
||||
}
|
||||
}
|
||||
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||
}
|
||||
let tool_results = eval_tool_calls(ctx, tool_calls).await?;
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| ctx.tool_scope.tool_tracker.record_call(res.call.clone()));
|
||||
Ok((text, tool_results))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
@@ -454,6 +454,7 @@ pub async fn call_chat_completions(
|
||||
pub async fn call_chat_completions_streaming(
|
||||
input: &Input,
|
||||
client: &dyn Client,
|
||||
ctx: &mut RequestContext,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<(String, Vec<ToolResult>)> {
|
||||
let (tx, rx) = unbounded_channel();
|
||||
@@ -461,7 +462,7 @@ pub async fn call_chat_completions_streaming(
|
||||
|
||||
let (send_ret, render_ret) = tokio::join!(
|
||||
client.chat_completions_streaming(input, &mut handler),
|
||||
render_stream(rx, client.global_config(), abort_signal.clone()),
|
||||
render_stream(rx, client.app_config(), abort_signal.clone()),
|
||||
);
|
||||
|
||||
if handler.abort().aborted() {
|
||||
@@ -476,12 +477,10 @@ pub async fn call_chat_completions_streaming(
|
||||
if !text.is_empty() && !text.ends_with('\n') {
|
||||
println!();
|
||||
}
|
||||
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||
}
|
||||
let tool_results = eval_tool_calls(ctx, tool_calls).await?;
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| ctx.tool_scope.tool_tracker.record_call(res.call.clone()));
|
||||
Ok((text, tool_results))
|
||||
}
|
||||
Err(err) => {
|
||||
|
||||
+11
-12
@@ -24,7 +24,7 @@ macro_rules! register_client {
|
||||
$(
|
||||
#[derive(Debug)]
|
||||
pub struct $client {
|
||||
global_config: $crate::config::GlobalConfig,
|
||||
app_config: std::sync::Arc<$crate::config::AppConfig>,
|
||||
config: $config,
|
||||
model: $crate::client::Model,
|
||||
}
|
||||
@@ -32,8 +32,8 @@ macro_rules! register_client {
|
||||
impl $client {
|
||||
pub const NAME: &'static str = $name;
|
||||
|
||||
pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
|
||||
let config = global_config.read().clients.iter().find_map(|client_config| {
|
||||
pub fn init(app_config: &std::sync::Arc<$crate::config::AppConfig>, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
|
||||
let config = app_config.clients.iter().find_map(|client_config| {
|
||||
if let ClientConfig::$config(c) = client_config {
|
||||
if Self::name(c) == model.client_name() {
|
||||
return Some(c.clone())
|
||||
@@ -43,7 +43,7 @@ macro_rules! register_client {
|
||||
})?;
|
||||
|
||||
Some(Box::new(Self {
|
||||
global_config: global_config.clone(),
|
||||
app_config: std::sync::Arc::clone(app_config),
|
||||
config,
|
||||
model: model.clone(),
|
||||
}))
|
||||
@@ -72,10 +72,9 @@ macro_rules! register_client {
|
||||
|
||||
)+
|
||||
|
||||
pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
|
||||
let model = model.unwrap_or_else(|| config.read().model.clone());
|
||||
pub fn init_client(app_config: &std::sync::Arc<$crate::config::AppConfig>, model: $crate::client::Model) -> anyhow::Result<Box<dyn Client>> {
|
||||
None
|
||||
$(.or_else(|| $client::init(config, &model)))+
|
||||
$(.or_else(|| $client::init(app_config, &model)))+
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("Invalid model '{}'", model.id())
|
||||
})
|
||||
@@ -101,7 +100,7 @@ macro_rules! register_client {
|
||||
|
||||
static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
|
||||
pub fn list_client_names(config: &$crate::config::AppConfig) -> Vec<&'static String> {
|
||||
let names = ALL_CLIENT_NAMES.get_or_init(|| {
|
||||
config
|
||||
.clients
|
||||
@@ -117,7 +116,7 @@ macro_rules! register_client {
|
||||
|
||||
static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
|
||||
pub fn list_all_models(config: &$crate::config::AppConfig) -> Vec<&'static $crate::client::Model> {
|
||||
let models = ALL_MODELS.get_or_init(|| {
|
||||
config
|
||||
.clients
|
||||
@@ -131,7 +130,7 @@ macro_rules! register_client {
|
||||
models.iter().collect()
|
||||
}
|
||||
|
||||
pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
|
||||
pub fn list_models(config: &$crate::config::AppConfig, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
|
||||
list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
|
||||
}
|
||||
};
|
||||
@@ -140,8 +139,8 @@ macro_rules! register_client {
|
||||
#[macro_export]
|
||||
macro_rules! client_common_fns {
|
||||
() => {
|
||||
fn global_config(&self) -> &$crate::config::GlobalConfig {
|
||||
&self.global_config
|
||||
fn app_config(&self) -> &$crate::config::AppConfig {
|
||||
&self.app_config
|
||||
}
|
||||
|
||||
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
|
||||
|
||||
+6
-2
@@ -3,7 +3,7 @@ use super::{
|
||||
message::{Message, MessageContent, MessageContentPart},
|
||||
};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::config::AppConfig;
|
||||
use crate::utils::{estimate_token_length, strip_think_tag};
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
@@ -44,7 +44,11 @@ impl Model {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result<Self> {
|
||||
pub fn retrieve_model(
|
||||
config: &AppConfig,
|
||||
model_id: &str,
|
||||
model_type: ModelType,
|
||||
) -> Result<Self> {
|
||||
let models = list_all_models(config);
|
||||
let (client_name, model_name) = match model_id.split_once(':') {
|
||||
Some((client_name, model_name)) => {
|
||||
|
||||
+3
-3
@@ -1,6 +1,6 @@
|
||||
use super::ClientConfig;
|
||||
use super::access_token::{is_valid_access_token, set_access_token};
|
||||
use crate::config::Config;
|
||||
use crate::config::paths;
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
@@ -178,13 +178,13 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
|
||||
}
|
||||
|
||||
pub fn load_oauth_tokens(client_name: &str) -> Option<OAuthTokens> {
|
||||
let path = Config::token_file(client_name);
|
||||
let path = paths::token_file(client_name);
|
||||
let content = fs::read_to_string(path).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
}
|
||||
|
||||
fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
|
||||
let path = Config::token_file(client_name);
|
||||
let path = paths::token_file(client_name);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user