feat: 99% complete migration to new state structs to get away from God-Config struct; i.e. AppConfig, AppState, and RequestContext
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use crate::client::{ModelType, list_models};
|
||||
use crate::config::paths;
|
||||
use crate::config::{Config, list_agents};
|
||||
use clap_complete::{CompletionCandidate, Shell, generate};
|
||||
use clap_complete_nushell::Nushell;
|
||||
@@ -33,7 +34,7 @@ impl ShellCompletion {
|
||||
pub(super) fn model_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
match Config::init_bare() {
|
||||
Ok(config) => list_models(&config, ModelType::Chat)
|
||||
Ok(config) => list_models(&config.to_app_config(), ModelType::Chat)
|
||||
.into_iter()
|
||||
.filter(|&m| m.id().starts_with(&*cur))
|
||||
.map(|m| CompletionCandidate::new(m.id()))
|
||||
@@ -44,7 +45,7 @@ pub(super) fn model_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
|
||||
pub(super) fn role_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
Config::list_roles(true)
|
||||
paths::list_roles(true)
|
||||
.into_iter()
|
||||
.filter(|r| r.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
@@ -62,7 +63,7 @@ pub(super) fn agent_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
|
||||
pub(super) fn rag_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
Config::list_rags()
|
||||
paths::list_rags()
|
||||
.into_iter()
|
||||
.filter(|r| r.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
@@ -71,7 +72,7 @@ pub(super) fn rag_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
|
||||
pub(super) fn macro_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
Config::list_macros()
|
||||
paths::list_macros()
|
||||
.into_iter()
|
||||
.filter(|m| m.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
|
||||
+20
-21
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
|
||||
use crate::config::paths;
|
||||
use crate::{
|
||||
config::{Config, GlobalConfig, Input},
|
||||
config::{AppConfig, Input, RequestContext},
|
||||
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
|
||||
render::render_stream,
|
||||
utils::*,
|
||||
@@ -24,7 +25,7 @@ use tokio::sync::mpsc::unbounded_channel;
|
||||
pub const MODELS_YAML: &str = include_str!("../../models.yaml");
|
||||
|
||||
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
|
||||
Config::local_models_override()
|
||||
paths::local_models_override()
|
||||
.ok()
|
||||
.unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
|
||||
});
|
||||
@@ -37,7 +38,7 @@ static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Client: Sync + Send {
|
||||
fn global_config(&self) -> &GlobalConfig;
|
||||
fn app_config(&self) -> &AppConfig;
|
||||
|
||||
fn extra_config(&self) -> Option<&ExtraConfig>;
|
||||
|
||||
@@ -58,7 +59,7 @@ pub trait Client: Sync + Send {
|
||||
if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
|
||||
builder = set_proxy(builder, proxy)?;
|
||||
}
|
||||
if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
|
||||
if let Some(user_agent) = self.app_config().user_agent.as_ref() {
|
||||
builder = builder.user_agent(user_agent);
|
||||
}
|
||||
let client = builder
|
||||
@@ -69,7 +70,7 @@ pub trait Client: Sync + Send {
|
||||
}
|
||||
|
||||
async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
|
||||
if self.global_config().read().dry_run {
|
||||
if self.app_config().dry_run {
|
||||
let content = input.echo_messages();
|
||||
return Ok(ChatCompletionsOutput::new(&content));
|
||||
}
|
||||
@@ -89,7 +90,7 @@ pub trait Client: Sync + Send {
|
||||
let input = input.clone();
|
||||
tokio::select! {
|
||||
ret = async {
|
||||
if self.global_config().read().dry_run {
|
||||
if self.app_config().dry_run {
|
||||
let content = input.echo_messages();
|
||||
handler.text(&content)?;
|
||||
return Ok(());
|
||||
@@ -413,9 +414,10 @@ pub async fn call_chat_completions(
|
||||
print: bool,
|
||||
extract_code: bool,
|
||||
client: &dyn Client,
|
||||
ctx: &mut RequestContext,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<(String, Vec<ToolResult>)> {
|
||||
let is_child_agent = client.global_config().read().current_depth > 0;
|
||||
let is_child_agent = ctx.current_depth > 0;
|
||||
let spinner_message = if is_child_agent { "" } else { "Generating" };
|
||||
let ret = abortable_run_with_spinner(
|
||||
client.chat_completions(input.clone()),
|
||||
@@ -436,15 +438,13 @@ pub async fn call_chat_completions(
|
||||
text = extract_code_block(&strip_think_tag(&text)).to_string();
|
||||
}
|
||||
if print {
|
||||
client.global_config().read().print_markdown(&text)?;
|
||||
ctx.app.config.print_markdown(&text)?;
|
||||
}
|
||||
}
|
||||
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||
}
|
||||
let tool_results = eval_tool_calls(ctx, tool_calls).await?;
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| ctx.tool_scope.tool_tracker.record_call(res.call.clone()));
|
||||
Ok((text, tool_results))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
@@ -454,6 +454,7 @@ pub async fn call_chat_completions(
|
||||
pub async fn call_chat_completions_streaming(
|
||||
input: &Input,
|
||||
client: &dyn Client,
|
||||
ctx: &mut RequestContext,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<(String, Vec<ToolResult>)> {
|
||||
let (tx, rx) = unbounded_channel();
|
||||
@@ -461,7 +462,7 @@ pub async fn call_chat_completions_streaming(
|
||||
|
||||
let (send_ret, render_ret) = tokio::join!(
|
||||
client.chat_completions_streaming(input, &mut handler),
|
||||
render_stream(rx, client.global_config(), abort_signal.clone()),
|
||||
render_stream(rx, client.app_config(), abort_signal.clone()),
|
||||
);
|
||||
|
||||
if handler.abort().aborted() {
|
||||
@@ -476,12 +477,10 @@ pub async fn call_chat_completions_streaming(
|
||||
if !text.is_empty() && !text.ends_with('\n') {
|
||||
println!();
|
||||
}
|
||||
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||
}
|
||||
let tool_results = eval_tool_calls(ctx, tool_calls).await?;
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| ctx.tool_scope.tool_tracker.record_call(res.call.clone()));
|
||||
Ok((text, tool_results))
|
||||
}
|
||||
Err(err) => {
|
||||
|
||||
+11
-12
@@ -24,7 +24,7 @@ macro_rules! register_client {
|
||||
$(
|
||||
#[derive(Debug)]
|
||||
pub struct $client {
|
||||
global_config: $crate::config::GlobalConfig,
|
||||
app_config: std::sync::Arc<$crate::config::AppConfig>,
|
||||
config: $config,
|
||||
model: $crate::client::Model,
|
||||
}
|
||||
@@ -32,8 +32,8 @@ macro_rules! register_client {
|
||||
impl $client {
|
||||
pub const NAME: &'static str = $name;
|
||||
|
||||
pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
|
||||
let config = global_config.read().clients.iter().find_map(|client_config| {
|
||||
pub fn init(app_config: &std::sync::Arc<$crate::config::AppConfig>, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
|
||||
let config = app_config.clients.iter().find_map(|client_config| {
|
||||
if let ClientConfig::$config(c) = client_config {
|
||||
if Self::name(c) == model.client_name() {
|
||||
return Some(c.clone())
|
||||
@@ -43,7 +43,7 @@ macro_rules! register_client {
|
||||
})?;
|
||||
|
||||
Some(Box::new(Self {
|
||||
global_config: global_config.clone(),
|
||||
app_config: std::sync::Arc::clone(app_config),
|
||||
config,
|
||||
model: model.clone(),
|
||||
}))
|
||||
@@ -72,10 +72,9 @@ macro_rules! register_client {
|
||||
|
||||
)+
|
||||
|
||||
pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
|
||||
let model = model.unwrap_or_else(|| config.read().model.clone());
|
||||
pub fn init_client(app_config: &std::sync::Arc<$crate::config::AppConfig>, model: $crate::client::Model) -> anyhow::Result<Box<dyn Client>> {
|
||||
None
|
||||
$(.or_else(|| $client::init(config, &model)))+
|
||||
$(.or_else(|| $client::init(app_config, &model)))+
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("Invalid model '{}'", model.id())
|
||||
})
|
||||
@@ -101,7 +100,7 @@ macro_rules! register_client {
|
||||
|
||||
static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
|
||||
pub fn list_client_names(config: &$crate::config::AppConfig) -> Vec<&'static String> {
|
||||
let names = ALL_CLIENT_NAMES.get_or_init(|| {
|
||||
config
|
||||
.clients
|
||||
@@ -117,7 +116,7 @@ macro_rules! register_client {
|
||||
|
||||
static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
|
||||
pub fn list_all_models(config: &$crate::config::AppConfig) -> Vec<&'static $crate::client::Model> {
|
||||
let models = ALL_MODELS.get_or_init(|| {
|
||||
config
|
||||
.clients
|
||||
@@ -131,7 +130,7 @@ macro_rules! register_client {
|
||||
models.iter().collect()
|
||||
}
|
||||
|
||||
pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
|
||||
pub fn list_models(config: &$crate::config::AppConfig, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
|
||||
list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
|
||||
}
|
||||
};
|
||||
@@ -140,8 +139,8 @@ macro_rules! register_client {
|
||||
#[macro_export]
|
||||
macro_rules! client_common_fns {
|
||||
() => {
|
||||
fn global_config(&self) -> &$crate::config::GlobalConfig {
|
||||
&self.global_config
|
||||
fn app_config(&self) -> &$crate::config::AppConfig {
|
||||
&self.app_config
|
||||
}
|
||||
|
||||
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
|
||||
|
||||
+6
-2
@@ -3,7 +3,7 @@ use super::{
|
||||
message::{Message, MessageContent, MessageContentPart},
|
||||
};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::config::AppConfig;
|
||||
use crate::utils::{estimate_token_length, strip_think_tag};
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
@@ -44,7 +44,11 @@ impl Model {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result<Self> {
|
||||
pub fn retrieve_model(
|
||||
config: &AppConfig,
|
||||
model_id: &str,
|
||||
model_type: ModelType,
|
||||
) -> Result<Self> {
|
||||
let models = list_all_models(config);
|
||||
let (client_name, model_name) = match model_id.split_once(':') {
|
||||
Some((client_name, model_name)) => {
|
||||
|
||||
+3
-3
@@ -1,6 +1,6 @@
|
||||
use super::ClientConfig;
|
||||
use super::access_token::{is_valid_access_token, set_access_token};
|
||||
use crate::config::Config;
|
||||
use crate::config::paths;
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
@@ -178,13 +178,13 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
|
||||
}
|
||||
|
||||
pub fn load_oauth_tokens(client_name: &str) -> Option<OAuthTokens> {
|
||||
let path = Config::token_file(client_name);
|
||||
let path = paths::token_file(client_name);
|
||||
let content = fs::read_to_string(path).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
}
|
||||
|
||||
fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
|
||||
let path = Config::token_file(client_name);
|
||||
let path = paths::token_file(client_name);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
+126
-106
@@ -1,4 +1,3 @@
|
||||
use super::todo::TodoList;
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
@@ -6,6 +5,7 @@ use crate::{
|
||||
function::{Functions, run_llm_function},
|
||||
};
|
||||
|
||||
use crate::config::paths;
|
||||
use crate::config::prompts::{
|
||||
DEFAULT_SPAWN_INSTRUCTIONS, DEFAULT_TEAMMATE_INSTRUCTIONS, DEFAULT_TODO_INSTRUCTIONS,
|
||||
DEFAULT_USER_INTERACTION_INSTRUCTIONS,
|
||||
@@ -38,16 +38,13 @@ pub struct Agent {
|
||||
rag: Option<Arc<Rag>>,
|
||||
model: Model,
|
||||
vault: GlobalVault,
|
||||
todo_list: TodoList,
|
||||
continuation_count: usize,
|
||||
last_continuation_response: Option<String>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub fn install_builtin_agents() -> Result<()> {
|
||||
info!(
|
||||
"Installing built-in agents in {}",
|
||||
Config::agents_data_dir().display()
|
||||
paths::agents_data_dir().display()
|
||||
);
|
||||
|
||||
for file in AgentAssets::iter() {
|
||||
@@ -56,7 +53,7 @@ impl Agent {
|
||||
let embedded_file = AgentAssets::get(&file)
|
||||
.ok_or_else(|| anyhow!("Failed to load embedded agent file: {}", file.as_ref()))?;
|
||||
let content = unsafe { std::str::from_utf8_unchecked(&embedded_file.data) };
|
||||
let file_path = Config::agents_data_dir().join(file.as_ref());
|
||||
let file_path = paths::agents_data_dir().join(file.as_ref());
|
||||
let file_extension = file_path
|
||||
.extension()
|
||||
.and_then(OsStr::to_str)
|
||||
@@ -88,14 +85,17 @@ impl Agent {
|
||||
}
|
||||
|
||||
pub async fn init(
|
||||
config: &GlobalConfig,
|
||||
app: &AppConfig,
|
||||
app_state: &AppState,
|
||||
current_model: &Model,
|
||||
info_flag: bool,
|
||||
name: &str,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
let agent_data_dir = Config::agent_data_dir(name);
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME);
|
||||
let config_path = Config::agent_config_file(name);
|
||||
let agent_data_dir = paths::agent_data_dir(name);
|
||||
let loaders = app.document_loaders.clone();
|
||||
let rag_path = paths::agent_rag_file(name, DEFAULT_AGENT_NAME);
|
||||
let config_path = paths::agent_config_file(name);
|
||||
let mut agent_config = if config_path.exists() {
|
||||
AgentConfig::load(&config_path)?
|
||||
} else {
|
||||
@@ -103,57 +103,24 @@ impl Agent {
|
||||
};
|
||||
let mut functions = Functions::init_agent(name, &agent_config.global_tools)?;
|
||||
|
||||
config.write().functions.clear_mcp_meta_functions();
|
||||
let mcp_servers = if config.read().mcp_server_support {
|
||||
(!agent_config.mcp_servers.is_empty()).then(|| agent_config.mcp_servers.join(","))
|
||||
} else {
|
||||
eprintln!(
|
||||
"{}",
|
||||
formatdoc!(
|
||||
"
|
||||
This agent uses MCP servers, but MCP support is disabled.
|
||||
To enable it, exit the agent and set 'mcp_server_support: true', then try again
|
||||
"
|
||||
)
|
||||
);
|
||||
None
|
||||
};
|
||||
agent_config.load_envs(app);
|
||||
|
||||
let registry = config
|
||||
.write()
|
||||
.mcp_registry
|
||||
.take()
|
||||
.with_context(|| "MCP registry should be populated")?;
|
||||
let new_mcp_registry =
|
||||
McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?;
|
||||
|
||||
if !new_mcp_registry.is_empty() {
|
||||
functions.append_mcp_meta_functions(new_mcp_registry.list_started_servers());
|
||||
}
|
||||
|
||||
config.write().mcp_registry = Some(new_mcp_registry);
|
||||
|
||||
agent_config.load_envs(&config.read());
|
||||
|
||||
let model = {
|
||||
let config = config.read();
|
||||
match agent_config.model_id.as_ref() {
|
||||
Some(model_id) => Model::retrieve_model(&config, model_id, ModelType::Chat)?,
|
||||
None => {
|
||||
if agent_config.temperature.is_none() {
|
||||
agent_config.temperature = config.temperature;
|
||||
}
|
||||
if agent_config.top_p.is_none() {
|
||||
agent_config.top_p = config.top_p;
|
||||
}
|
||||
config.current_model().clone()
|
||||
let model = match agent_config.model_id.as_ref() {
|
||||
Some(model_id) => Model::retrieve_model(app, model_id, ModelType::Chat)?,
|
||||
None => {
|
||||
if agent_config.temperature.is_none() {
|
||||
agent_config.temperature = app.temperature;
|
||||
}
|
||||
if agent_config.top_p.is_none() {
|
||||
agent_config.top_p = app.top_p;
|
||||
}
|
||||
current_model.clone()
|
||||
}
|
||||
};
|
||||
|
||||
let rag = if rag_path.exists() {
|
||||
Some(Arc::new(Rag::load(config, DEFAULT_AGENT_NAME, &rag_path)?))
|
||||
} else if !agent_config.documents.is_empty() && !config.read().info_flag {
|
||||
Some(Arc::new(Rag::load(app, DEFAULT_AGENT_NAME, &rag_path)?))
|
||||
} else if !agent_config.documents.is_empty() && !info_flag {
|
||||
let mut ans = false;
|
||||
if *IS_STDOUT_TERMINAL {
|
||||
ans = Confirm::new("The agent has documents attached, init RAG?")
|
||||
@@ -185,8 +152,7 @@ impl Agent {
|
||||
document_paths.push(path.to_string())
|
||||
}
|
||||
}
|
||||
let rag =
|
||||
Rag::init(config, "rag", &rag_path, &document_paths, abort_signal).await?;
|
||||
let rag = Rag::init(app, "rag", &rag_path, &document_paths, abort_signal).await?;
|
||||
Some(Arc::new(rag))
|
||||
} else {
|
||||
None
|
||||
@@ -218,10 +184,7 @@ impl Agent {
|
||||
functions,
|
||||
rag,
|
||||
model,
|
||||
vault: Arc::clone(&config.read().vault),
|
||||
todo_list: TodoList::default(),
|
||||
continuation_count: 0,
|
||||
last_continuation_response: None,
|
||||
vault: app_state.vault.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -295,11 +258,11 @@ impl Agent {
|
||||
let mut config = self.config.clone();
|
||||
config.instructions = self.interpolated_instructions();
|
||||
value["definition"] = json!(config);
|
||||
value["data_dir"] = Config::agent_data_dir(&self.name)
|
||||
value["data_dir"] = paths::agent_data_dir(&self.name)
|
||||
.display()
|
||||
.to_string()
|
||||
.into();
|
||||
value["config_file"] = Config::agent_config_file(&self.name)
|
||||
value["config_file"] = paths::agent_config_file(&self.name)
|
||||
.display()
|
||||
.to_string()
|
||||
.into();
|
||||
@@ -323,6 +286,14 @@ impl Agent {
|
||||
self.rag.clone()
|
||||
}
|
||||
|
||||
pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec<String>) {
|
||||
self.functions.append_mcp_meta_functions(mcp_servers);
|
||||
}
|
||||
|
||||
pub fn mcp_server_names(&self) -> &[String] {
|
||||
&self.config.mcp_servers
|
||||
}
|
||||
|
||||
pub fn conversation_starters(&self) -> Vec<String> {
|
||||
self.config
|
||||
.conversation_starters
|
||||
@@ -443,44 +414,6 @@ impl Agent {
|
||||
self.config.escalation_timeout
|
||||
}
|
||||
|
||||
pub fn continuation_count(&self) -> usize {
|
||||
self.continuation_count
|
||||
}
|
||||
|
||||
pub fn increment_continuation(&mut self) {
|
||||
self.continuation_count += 1;
|
||||
}
|
||||
|
||||
pub fn reset_continuation(&mut self) {
|
||||
self.continuation_count = 0;
|
||||
self.last_continuation_response = None;
|
||||
}
|
||||
|
||||
pub fn set_last_continuation_response(&mut self, response: String) {
|
||||
self.last_continuation_response = Some(response);
|
||||
}
|
||||
|
||||
pub fn todo_list(&self) -> &TodoList {
|
||||
&self.todo_list
|
||||
}
|
||||
|
||||
pub fn init_todo_list(&mut self, goal: &str) {
|
||||
self.todo_list = TodoList::new(goal);
|
||||
}
|
||||
|
||||
pub fn add_todo(&mut self, task: &str) -> usize {
|
||||
self.todo_list.add(task)
|
||||
}
|
||||
|
||||
pub fn mark_todo_done(&mut self, id: usize) -> bool {
|
||||
self.todo_list.mark_done(id)
|
||||
}
|
||||
|
||||
pub fn clear_todo_list(&mut self) {
|
||||
self.todo_list.clear();
|
||||
self.reset_continuation();
|
||||
}
|
||||
|
||||
pub fn continuation_prompt(&self) -> String {
|
||||
self.config.continuation_prompt.clone().unwrap_or_else(|| {
|
||||
formatdoc! {"
|
||||
@@ -696,12 +629,12 @@ impl AgentConfig {
|
||||
Ok(agent_config)
|
||||
}
|
||||
|
||||
fn load_envs(&mut self, config: &Config) {
|
||||
fn load_envs(&mut self, app: &AppConfig) {
|
||||
let name = &self.name;
|
||||
let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}"));
|
||||
|
||||
if self.agent_session.is_none() {
|
||||
self.agent_session = config.agent_session.clone();
|
||||
self.agent_session = app.agent_session.clone();
|
||||
}
|
||||
|
||||
if let Some(v) = read_env_value::<String>(&with_prefix("model")) {
|
||||
@@ -793,7 +726,7 @@ pub struct AgentVariable {
|
||||
}
|
||||
|
||||
pub fn list_agents() -> Vec<String> {
|
||||
let agents_data_dir = Config::agents_data_dir();
|
||||
let agents_data_dir = paths::agents_data_dir();
|
||||
if !agents_data_dir.exists() {
|
||||
return vec![];
|
||||
}
|
||||
@@ -803,6 +736,7 @@ pub fn list_agents() -> Vec<String> {
|
||||
for entry in entries.flatten() {
|
||||
if entry.path().is_dir()
|
||||
&& let Some(name) = entry.file_name().to_str()
|
||||
&& !name.starts_with('.')
|
||||
{
|
||||
agents.push(name.to_string());
|
||||
}
|
||||
@@ -813,7 +747,7 @@ pub fn list_agents() -> Vec<String> {
|
||||
}
|
||||
|
||||
pub fn complete_agent_variables(agent_name: &str) -> Vec<(String, Option<String>)> {
|
||||
let config_path = Config::agent_config_file(agent_name);
|
||||
let config_path = paths::agent_config_file(agent_name);
|
||||
if !config_path.exists() {
|
||||
return vec![];
|
||||
}
|
||||
@@ -832,3 +766,89 @@ pub fn complete_agent_variables(agent_name: &str) -> Vec<(String, Option<String>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn agent_config_parses_from_yaml() {
|
||||
let yaml = r#"
|
||||
name: test-agent
|
||||
description: A test agent
|
||||
instructions: You are helpful
|
||||
auto_continue: true
|
||||
max_auto_continues: 5
|
||||
can_spawn_agents: true
|
||||
max_concurrent_agents: 8
|
||||
max_agent_depth: 2
|
||||
mcp_servers:
|
||||
- github
|
||||
- jira
|
||||
global_tools:
|
||||
- execute_command.sh
|
||||
- fs_read.sh
|
||||
conversation_starters:
|
||||
- "Hello!"
|
||||
- "How are you?"
|
||||
variables:
|
||||
- name: username
|
||||
description: Your name
|
||||
"#;
|
||||
|
||||
let config: AgentConfig = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(config.name, "test-agent");
|
||||
assert_eq!(config.description, "A test agent");
|
||||
assert!(config.auto_continue);
|
||||
assert_eq!(config.max_auto_continues, 5);
|
||||
assert!(config.can_spawn_agents);
|
||||
assert_eq!(config.max_concurrent_agents, 8);
|
||||
assert_eq!(config.max_agent_depth, 2);
|
||||
assert_eq!(config.mcp_servers, vec!["github", "jira"]);
|
||||
assert_eq!(config.global_tools.len(), 2);
|
||||
assert_eq!(config.conversation_starters.len(), 2);
|
||||
assert_eq!(config.variables.len(), 1);
|
||||
assert_eq!(config.variables[0].name, "username");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_config_defaults() {
|
||||
let yaml = "name: minimal\ninstructions: hi\n";
|
||||
let config: AgentConfig = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(config.name, "minimal");
|
||||
assert!(!config.auto_continue);
|
||||
assert!(!config.can_spawn_agents);
|
||||
assert_eq!(config.max_concurrent_agents, 4);
|
||||
assert_eq!(config.max_agent_depth, 3);
|
||||
assert_eq!(config.max_auto_continues, 10);
|
||||
assert!(config.mcp_servers.is_empty());
|
||||
assert!(config.global_tools.is_empty());
|
||||
assert!(config.conversation_starters.is_empty());
|
||||
assert!(config.variables.is_empty());
|
||||
assert!(config.model_id.is_none());
|
||||
assert!(config.temperature.is_none());
|
||||
assert!(config.top_p.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_config_with_model() {
|
||||
let yaml =
|
||||
"name: test\nmodel: openai:gpt-4\ntemperature: 0.7\ntop_p: 0.9\ninstructions: hi\n";
|
||||
let config: AgentConfig = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(config.model_id, Some("openai:gpt-4".to_string()));
|
||||
assert_eq!(config.temperature, Some(0.7));
|
||||
assert_eq!(config.top_p, Some(0.9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_config_inject_defaults_true() {
|
||||
let yaml = "name: test\ninstructions: hi\n";
|
||||
let config: AgentConfig = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert!(config.inject_todo_instructions);
|
||||
assert!(config.inject_spawn_instructions);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,586 @@
|
||||
//! Immutable, server-wide application configuration.
|
||||
//!
|
||||
//! `AppConfig` contains the settings loaded from `config.yaml` that are
|
||||
//! global to the Loki process: LLM provider configs, UI preferences, tool
|
||||
//! and MCP settings, RAG defaults, etc.
|
||||
//!
|
||||
//! This is Phase 1, Step 0 of the REST API refactor: the struct is
|
||||
//! introduced alongside the existing [`Config`](super::Config) and is not
|
||||
//! yet wired into the runtime. See `docs/PHASE-1-IMPLEMENTATION-PLAN.md`
|
||||
//! for the full migration plan.
|
||||
//!
|
||||
//! # Relationship to `Config`
|
||||
//!
|
||||
//! `AppConfig` mirrors the **serialized** fields of [`Config`] — that is,
|
||||
//! every field that is NOT marked `#[serde(skip)]`. The deserialization
|
||||
//! shape is identical so an existing `config.yaml` can be loaded into
|
||||
//! either type without modification.
|
||||
//!
|
||||
//! Runtime-only state (current role, session, agent, supervisor, etc.)
|
||||
//! lives on [`RequestContext`](super::request_context::RequestContext).
|
||||
|
||||
use crate::client::ClientConfig;
|
||||
use crate::render::{MarkdownRender, RenderOptions};
|
||||
use crate::utils::{IS_STDOUT_TERMINAL, NO_COLOR, decode_bin, get_env_name};
|
||||
|
||||
use super::paths;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use indexmap::IndexMap;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
use syntect::highlighting::ThemeSet;
|
||||
use terminal_colorsaurus::{ColorScheme, QueryOptions, color_scheme};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct AppConfig {
|
||||
#[serde(rename(serialize = "model", deserialize = "model"))]
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub model_id: String,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
|
||||
pub dry_run: bool,
|
||||
pub stream: bool,
|
||||
pub save: bool,
|
||||
pub keybindings: String,
|
||||
pub editor: Option<String>,
|
||||
pub wrap: Option<String>,
|
||||
pub wrap_code: bool,
|
||||
pub(crate) vault_password_file: Option<PathBuf>,
|
||||
|
||||
pub function_calling_support: bool,
|
||||
pub mapping_tools: IndexMap<String, String>,
|
||||
pub enabled_tools: Option<String>,
|
||||
pub visible_tools: Option<Vec<String>>,
|
||||
|
||||
pub mcp_server_support: bool,
|
||||
pub mapping_mcp_servers: IndexMap<String, String>,
|
||||
pub enabled_mcp_servers: Option<String>,
|
||||
|
||||
pub repl_prelude: Option<String>,
|
||||
pub cmd_prelude: Option<String>,
|
||||
pub agent_session: Option<String>,
|
||||
|
||||
pub save_session: Option<bool>,
|
||||
pub compression_threshold: usize,
|
||||
pub summarization_prompt: Option<String>,
|
||||
pub summary_context_prompt: Option<String>,
|
||||
|
||||
pub rag_embedding_model: Option<String>,
|
||||
pub rag_reranker_model: Option<String>,
|
||||
pub rag_top_k: usize,
|
||||
pub rag_chunk_size: Option<usize>,
|
||||
pub rag_chunk_overlap: Option<usize>,
|
||||
pub rag_template: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub document_loaders: HashMap<String, String>,
|
||||
|
||||
pub highlight: bool,
|
||||
pub theme: Option<String>,
|
||||
pub left_prompt: Option<String>,
|
||||
pub right_prompt: Option<String>,
|
||||
|
||||
pub user_agent: Option<String>,
|
||||
pub save_shell_history: bool,
|
||||
pub sync_models_url: Option<String>,
|
||||
|
||||
pub clients: Vec<ClientConfig>,
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_id: Default::default(),
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
|
||||
dry_run: false,
|
||||
stream: true,
|
||||
save: false,
|
||||
keybindings: "emacs".into(),
|
||||
editor: None,
|
||||
wrap: None,
|
||||
wrap_code: false,
|
||||
vault_password_file: None,
|
||||
|
||||
function_calling_support: true,
|
||||
mapping_tools: Default::default(),
|
||||
enabled_tools: None,
|
||||
visible_tools: None,
|
||||
|
||||
mcp_server_support: true,
|
||||
mapping_mcp_servers: Default::default(),
|
||||
enabled_mcp_servers: None,
|
||||
|
||||
repl_prelude: None,
|
||||
cmd_prelude: None,
|
||||
agent_session: None,
|
||||
|
||||
save_session: None,
|
||||
compression_threshold: 4000,
|
||||
summarization_prompt: None,
|
||||
summary_context_prompt: None,
|
||||
|
||||
rag_embedding_model: None,
|
||||
rag_reranker_model: None,
|
||||
rag_top_k: 5,
|
||||
rag_chunk_size: None,
|
||||
rag_chunk_overlap: None,
|
||||
rag_template: None,
|
||||
|
||||
document_loaders: Default::default(),
|
||||
|
||||
highlight: true,
|
||||
theme: None,
|
||||
left_prompt: None,
|
||||
right_prompt: None,
|
||||
|
||||
user_agent: None,
|
||||
save_shell_history: true,
|
||||
sync_models_url: None,
|
||||
|
||||
clients: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
pub fn vault_password_file(&self) -> PathBuf {
|
||||
match &self.vault_password_file {
|
||||
Some(path) => match path.exists() {
|
||||
true => path.clone(),
|
||||
false => gman::config::Config::local_provider_password_file(),
|
||||
},
|
||||
None => gman::config::Config::local_provider_password_file(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn editor(&self) -> Result<String> {
|
||||
super::EDITOR.get_or_init(move || {
|
||||
let editor = self.editor.clone()
|
||||
.or_else(|| env::var("VISUAL").ok().or_else(|| env::var("EDITOR").ok()))
|
||||
.unwrap_or_else(|| {
|
||||
if cfg!(windows) {
|
||||
"notepad".to_string()
|
||||
} else {
|
||||
"nano".to_string()
|
||||
}
|
||||
});
|
||||
which::which(&editor).ok().map(|_| editor)
|
||||
})
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("Editor not found. Please add the `editor` configuration or set the $EDITOR or $VISUAL environment variable."))
|
||||
}
|
||||
|
||||
pub fn sync_models_url(&self) -> String {
|
||||
self.sync_models_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| super::SYNC_MODELS_URL.into())
|
||||
}
|
||||
|
||||
pub fn light_theme(&self) -> bool {
|
||||
matches!(self.theme.as_deref(), Some("light"))
|
||||
}
|
||||
|
||||
pub fn render_options(&self) -> Result<RenderOptions> {
|
||||
let theme = if self.highlight {
|
||||
let theme_mode = if self.light_theme() { "light" } else { "dark" };
|
||||
let theme_filename = format!("{theme_mode}.tmTheme");
|
||||
let theme_path = paths::local_path(&theme_filename);
|
||||
if theme_path.exists() {
|
||||
let theme = ThemeSet::get_theme(&theme_path)
|
||||
.with_context(|| format!("Invalid theme at '{}'", theme_path.display()))?;
|
||||
Some(theme)
|
||||
} else {
|
||||
let theme = if self.light_theme() {
|
||||
decode_bin(super::LIGHT_THEME).context("Invalid builtin light theme")?
|
||||
} else {
|
||||
decode_bin(super::DARK_THEME).context("Invalid builtin dark theme")?
|
||||
};
|
||||
Some(theme)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let wrap = if *IS_STDOUT_TERMINAL {
|
||||
self.wrap.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let truecolor = matches!(
|
||||
env::var("COLORTERM").as_ref().map(|v| v.as_str()),
|
||||
Ok("truecolor")
|
||||
);
|
||||
Ok(RenderOptions::new(theme, wrap, self.wrap_code, truecolor))
|
||||
}
|
||||
|
||||
pub fn print_markdown(&self, text: &str) -> Result<()> {
|
||||
if *IS_STDOUT_TERMINAL {
|
||||
let render_options = self.render_options()?;
|
||||
let mut markdown_render = MarkdownRender::init(render_options)?;
|
||||
println!("{}", markdown_render.render(text));
|
||||
} else {
|
||||
println!("{text}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
#[allow(dead_code)]
|
||||
pub fn set_wrap(&mut self, value: &str) -> Result<()> {
|
||||
if value == "no" {
|
||||
self.wrap = None;
|
||||
} else if value == "auto" {
|
||||
self.wrap = Some(value.into());
|
||||
} else {
|
||||
value
|
||||
.parse::<u16>()
|
||||
.map_err(|_| anyhow!("Invalid wrap value"))?;
|
||||
self.wrap = Some(value.into())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn setup_document_loaders(&mut self) {
|
||||
[("pdf", "pdftotext $1 -"), ("docx", "pandoc --to plain $1")]
|
||||
.into_iter()
|
||||
.for_each(|(k, v)| {
|
||||
let (k, v) = (k.to_string(), v.to_string());
|
||||
self.document_loaders.entry(k).or_insert(v);
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn setup_user_agent(&mut self) {
|
||||
if let Some("auto") = self.user_agent.as_deref() {
|
||||
self.user_agent = Some(format!(
|
||||
"{}/{}",
|
||||
env!("CARGO_CRATE_NAME"),
|
||||
env!("CARGO_PKG_VERSION")
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn load_envs(&mut self) {
|
||||
if let Ok(v) = env::var(get_env_name("model")) {
|
||||
self.model_id = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<f64>(&get_env_name("temperature")) {
|
||||
self.temperature = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<f64>(&get_env_name("top_p")) {
|
||||
self.top_p = v;
|
||||
}
|
||||
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("dry_run")) {
|
||||
self.dry_run = v;
|
||||
}
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("stream")) {
|
||||
self.stream = v;
|
||||
}
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("save")) {
|
||||
self.save = v;
|
||||
}
|
||||
if let Ok(v) = env::var(get_env_name("keybindings"))
|
||||
&& v == "vi"
|
||||
{
|
||||
self.keybindings = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("editor")) {
|
||||
self.editor = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("wrap")) {
|
||||
self.wrap = v;
|
||||
}
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("wrap_code")) {
|
||||
self.wrap_code = v;
|
||||
}
|
||||
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("function_calling_support")) {
|
||||
self.function_calling_support = v;
|
||||
}
|
||||
if let Ok(v) = env::var(get_env_name("mapping_tools"))
|
||||
&& let Ok(v) = serde_json::from_str(&v)
|
||||
{
|
||||
self.mapping_tools = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("enabled_tools")) {
|
||||
self.enabled_tools = v;
|
||||
}
|
||||
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("mcp_server_support")) {
|
||||
self.mcp_server_support = v;
|
||||
}
|
||||
if let Ok(v) = env::var(get_env_name("mapping_mcp_servers"))
|
||||
&& let Ok(v) = serde_json::from_str(&v)
|
||||
{
|
||||
self.mapping_mcp_servers = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("enabled_mcp_servers")) {
|
||||
self.enabled_mcp_servers = v;
|
||||
}
|
||||
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("repl_prelude")) {
|
||||
self.repl_prelude = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("cmd_prelude")) {
|
||||
self.cmd_prelude = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("agent_session")) {
|
||||
self.agent_session = v;
|
||||
}
|
||||
|
||||
if let Some(v) = super::read_env_bool(&get_env_name("save_session")) {
|
||||
self.save_session = v;
|
||||
}
|
||||
if let Some(Some(v)) =
|
||||
super::read_env_value::<usize>(&get_env_name("compression_threshold"))
|
||||
{
|
||||
self.compression_threshold = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("summarization_prompt")) {
|
||||
self.summarization_prompt = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("summary_context_prompt")) {
|
||||
self.summary_context_prompt = v;
|
||||
}
|
||||
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("rag_embedding_model")) {
|
||||
self.rag_embedding_model = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("rag_reranker_model")) {
|
||||
self.rag_reranker_model = v;
|
||||
}
|
||||
if let Some(Some(v)) = super::read_env_value::<usize>(&get_env_name("rag_top_k")) {
|
||||
self.rag_top_k = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<usize>(&get_env_name("rag_chunk_size")) {
|
||||
self.rag_chunk_size = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<usize>(&get_env_name("rag_chunk_overlap")) {
|
||||
self.rag_chunk_overlap = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("rag_template")) {
|
||||
self.rag_template = v;
|
||||
}
|
||||
|
||||
if let Ok(v) = env::var(get_env_name("document_loaders"))
|
||||
&& let Ok(v) = serde_json::from_str(&v)
|
||||
{
|
||||
self.document_loaders = v;
|
||||
}
|
||||
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("highlight")) {
|
||||
self.highlight = v;
|
||||
}
|
||||
if *NO_COLOR {
|
||||
self.highlight = false;
|
||||
}
|
||||
if self.highlight && self.theme.is_none() {
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("theme")) {
|
||||
self.theme = v;
|
||||
} else if *IS_STDOUT_TERMINAL
|
||||
&& let Ok(color_scheme) = color_scheme(QueryOptions::default())
|
||||
{
|
||||
let theme = match color_scheme {
|
||||
ColorScheme::Dark => "dark",
|
||||
ColorScheme::Light => "light",
|
||||
};
|
||||
self.theme = Some(theme.into());
|
||||
}
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("left_prompt")) {
|
||||
self.left_prompt = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("right_prompt")) {
|
||||
self.right_prompt = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("user_agent")) {
|
||||
self.user_agent = v;
|
||||
}
|
||||
if let Some(Some(v)) = super::read_env_bool(&get_env_name("save_shell_history")) {
|
||||
self.save_shell_history = v;
|
||||
}
|
||||
if let Some(v) = super::read_env_value::<String>(&get_env_name("sync_models_url")) {
|
||||
self.sync_models_url = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
#[allow(dead_code)]
|
||||
pub fn set_temperature_default(&mut self, value: Option<f64>) {
|
||||
self.temperature = value;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_top_p_default(&mut self, value: Option<f64>) {
|
||||
self.top_p = value;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_enabled_tools_default(&mut self, value: Option<String>) {
|
||||
self.enabled_tools = value;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_enabled_mcp_servers_default(&mut self, value: Option<String>) {
|
||||
self.enabled_mcp_servers = value;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_save_session_default(&mut self, value: Option<bool>) {
|
||||
self.save_session = value;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_compression_threshold_default(&mut self, value: Option<usize>) {
|
||||
self.compression_threshold = value.unwrap_or_default();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_rag_reranker_model_default(&mut self, value: Option<String>) {
|
||||
self.rag_reranker_model = value;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_rag_top_k_default(&mut self, value: usize) {
|
||||
self.rag_top_k = value;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn set_model_id_default(&mut self, model_id: String) {
|
||||
self.model_id = model_id;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn ensure_default_model_id(&mut self) -> Result<String> {
|
||||
if self.model_id.is_empty() {
|
||||
let models = crate::client::list_models(self, crate::client::ModelType::Chat);
|
||||
if models.is_empty() {
|
||||
anyhow::bail!("No available model");
|
||||
}
|
||||
self.model_id = models[0].id();
|
||||
}
|
||||
Ok(self.model_id.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::Config;
|
||||
|
||||
fn cached_editor() -> Option<String> {
|
||||
super::super::EDITOR.get().cloned().flatten()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_app_config_copies_serialized_fields() {
|
||||
let cfg = Config {
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.9),
|
||||
dry_run: true,
|
||||
stream: false,
|
||||
save: true,
|
||||
highlight: false,
|
||||
compression_threshold: 2000,
|
||||
rag_top_k: 10,
|
||||
..Config::default()
|
||||
};
|
||||
|
||||
let app = cfg.to_app_config();
|
||||
|
||||
assert_eq!(app.model_id, "test-model");
|
||||
assert_eq!(app.temperature, Some(0.7));
|
||||
assert_eq!(app.top_p, Some(0.9));
|
||||
assert!(app.dry_run);
|
||||
assert!(!app.stream);
|
||||
assert!(app.save);
|
||||
assert!(!app.highlight);
|
||||
assert_eq!(app.compression_threshold, 2000);
|
||||
assert_eq!(app.rag_top_k, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_app_config_copies_clients() {
|
||||
let cfg = Config::default();
|
||||
let app = cfg.to_app_config();
|
||||
|
||||
assert!(app.clients.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_app_config_copies_mapping_fields() {
|
||||
let mut cfg = Config::default();
|
||||
cfg.mapping_tools
|
||||
.insert("alias".to_string(), "real_tool".to_string());
|
||||
cfg.mapping_mcp_servers
|
||||
.insert("gh".to_string(), "github-mcp".to_string());
|
||||
|
||||
let app = cfg.to_app_config();
|
||||
|
||||
assert_eq!(
|
||||
app.mapping_tools.get("alias"),
|
||||
Some(&"real_tool".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
app.mapping_mcp_servers.get("gh"),
|
||||
Some(&"github-mcp".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn editor_returns_configured_value() {
|
||||
let configured = cached_editor()
|
||||
.unwrap_or_else(|| std::env::current_exe().unwrap().display().to_string());
|
||||
let app = AppConfig {
|
||||
editor: Some(configured.clone()),
|
||||
..AppConfig::default()
|
||||
};
|
||||
|
||||
assert_eq!(app.editor().unwrap(), configured);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn editor_falls_back_to_env() {
|
||||
if let Some(expected) = cached_editor() {
|
||||
let app = AppConfig::default();
|
||||
assert_eq!(app.editor().unwrap(), expected);
|
||||
return;
|
||||
}
|
||||
|
||||
let expected = std::env::current_exe().unwrap().display().to_string();
|
||||
unsafe {
|
||||
std::env::set_var("VISUAL", &expected);
|
||||
}
|
||||
|
||||
let app = AppConfig::default();
|
||||
let result = app.editor();
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_theme_default_is_false() {
|
||||
let app = AppConfig::default();
|
||||
assert!(!app.light_theme());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_models_url_has_default() {
|
||||
let app = AppConfig::default();
|
||||
let url = app.sync_models_url();
|
||||
assert!(!url.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
//! Shared global services for a running Loki process.
|
||||
//!
|
||||
//! `AppState` holds the services that are genuinely process-wide and
|
||||
//! immutable during request handling: the frozen [`AppConfig`], the
|
||||
//! credential [`Vault`](GlobalVault), the [`McpFactory`](super::mcp_factory::McpFactory)
|
||||
//! for MCP subprocess sharing, and the [`RagCache`](super::rag_cache::RagCache)
|
||||
//! for shared RAG instances. It is intended to be wrapped in `Arc`
|
||||
//! and shared across every [`RequestContext`] that a frontend (CLI,
|
||||
//! REPL, API) creates.
|
||||
//!
|
||||
//! This struct deliberately does **not** hold a live `McpRegistry`.
|
||||
//! MCP server processes are scoped to whichever `RoleLike`
|
||||
//! (role/session/agent) is currently active, because each scope may
|
||||
//! demand a different enabled server set. Live MCP processes are
|
||||
//! owned by per-scope
|
||||
//! [`ToolScope`](super::tool_scope::ToolScope)s on the
|
||||
//! [`RequestContext`] and acquired through `McpFactory`.
|
||||
//!
|
||||
//! # Phase 1 scope
|
||||
//!
|
||||
//! This is Phase 1 of the REST API refactor:
|
||||
//!
|
||||
//! * **Step 0** introduced this struct alongside the existing
|
||||
//! [`Config`](super::Config)
|
||||
//! * **Step 6.5** added the `mcp_factory` and `rag_cache` fields
|
||||
//!
|
||||
//! Neither field is wired into the runtime yet — they exist as
|
||||
//! additive scaffolding that Step 8+ will connect when the entry
|
||||
//! points migrate. See `docs/PHASE-1-IMPLEMENTATION-PLAN.md`.
|
||||
|
||||
use super::mcp_factory::McpFactory;
|
||||
use super::rag_cache::RagCache;
|
||||
use crate::config::AppConfig;
|
||||
use crate::mcp::McpServersConfig;
|
||||
use crate::vault::GlobalVault;
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub config: Arc<AppConfig>,
|
||||
pub vault: GlobalVault,
|
||||
pub mcp_factory: Arc<McpFactory>,
|
||||
#[allow(dead_code)]
|
||||
pub rag_cache: Arc<RagCache>,
|
||||
pub mcp_config: Option<McpServersConfig>,
|
||||
pub mcp_log_path: Option<PathBuf>,
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
//! Transitional conversions between the legacy [`Config`] struct and the
|
||||
//! new [`AppConfig`] + [`RequestContext`] split.
|
||||
|
||||
use crate::config::todo::TodoList;
|
||||
|
||||
use super::{AppConfig, AppState, Config, RequestContext};
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
impl Config {
|
||||
pub fn to_app_config(&self) -> AppConfig {
|
||||
AppConfig {
|
||||
model_id: self.model_id.clone(),
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
|
||||
dry_run: self.dry_run,
|
||||
stream: self.stream,
|
||||
save: self.save,
|
||||
keybindings: self.keybindings.clone(),
|
||||
editor: self.editor.clone(),
|
||||
wrap: self.wrap.clone(),
|
||||
wrap_code: self.wrap_code,
|
||||
vault_password_file: self.vault_password_file.clone(),
|
||||
|
||||
function_calling_support: self.function_calling_support,
|
||||
mapping_tools: self.mapping_tools.clone(),
|
||||
enabled_tools: self.enabled_tools.clone(),
|
||||
visible_tools: self.visible_tools.clone(),
|
||||
|
||||
mcp_server_support: self.mcp_server_support,
|
||||
mapping_mcp_servers: self.mapping_mcp_servers.clone(),
|
||||
enabled_mcp_servers: self.enabled_mcp_servers.clone(),
|
||||
|
||||
repl_prelude: self.repl_prelude.clone(),
|
||||
cmd_prelude: self.cmd_prelude.clone(),
|
||||
agent_session: self.agent_session.clone(),
|
||||
|
||||
save_session: self.save_session,
|
||||
compression_threshold: self.compression_threshold,
|
||||
summarization_prompt: self.summarization_prompt.clone(),
|
||||
summary_context_prompt: self.summary_context_prompt.clone(),
|
||||
|
||||
rag_embedding_model: self.rag_embedding_model.clone(),
|
||||
rag_reranker_model: self.rag_reranker_model.clone(),
|
||||
rag_top_k: self.rag_top_k,
|
||||
rag_chunk_size: self.rag_chunk_size,
|
||||
rag_chunk_overlap: self.rag_chunk_overlap,
|
||||
rag_template: self.rag_template.clone(),
|
||||
|
||||
document_loaders: self.document_loaders.clone(),
|
||||
|
||||
highlight: self.highlight,
|
||||
theme: self.theme.clone(),
|
||||
left_prompt: self.left_prompt.clone(),
|
||||
right_prompt: self.right_prompt.clone(),
|
||||
|
||||
user_agent: self.user_agent.clone(),
|
||||
save_shell_history: self.save_shell_history,
|
||||
sync_models_url: self.sync_models_url.clone(),
|
||||
|
||||
clients: self.clients.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_request_context(&self, app: Arc<AppState>) -> RequestContext {
|
||||
let mut mcp_runtime = super::tool_scope::McpRuntime::default();
|
||||
if let Some(registry) = &self.mcp_registry {
|
||||
mcp_runtime.sync_from_registry(registry);
|
||||
}
|
||||
let tool_tracker = self
|
||||
.tool_call_tracker
|
||||
.clone()
|
||||
.unwrap_or_else(crate::function::ToolCallTracker::default);
|
||||
RequestContext {
|
||||
app,
|
||||
macro_flag: self.macro_flag,
|
||||
info_flag: self.info_flag,
|
||||
working_mode: self.working_mode,
|
||||
model: self.model.clone(),
|
||||
agent_variables: self.agent_variables.clone(),
|
||||
role: self.role.clone(),
|
||||
session: self.session.clone(),
|
||||
rag: self.rag.clone(),
|
||||
agent: self.agent.clone(),
|
||||
last_message: self.last_message.clone(),
|
||||
tool_scope: super::tool_scope::ToolScope {
|
||||
functions: self.functions.clone(),
|
||||
mcp_runtime,
|
||||
tool_tracker,
|
||||
},
|
||||
supervisor: self.supervisor.clone(),
|
||||
parent_supervisor: self.parent_supervisor.clone(),
|
||||
self_agent_id: self.self_agent_id.clone(),
|
||||
inbox: self.inbox.clone(),
|
||||
escalation_queue: self.root_escalation_queue.clone(),
|
||||
current_depth: self.current_depth,
|
||||
auto_continue_count: 0,
|
||||
todo_list: TodoList::default(),
|
||||
last_continuation_response: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
+57
-28
@@ -9,7 +9,7 @@ use crate::utils::{AbortSignal, base64_encode, is_loader_protocol, sha256};
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use indexmap::IndexSet;
|
||||
use std::{collections::HashMap, fs::File, io::Read};
|
||||
use std::{collections::HashMap, fs::File, io::Read, sync::Arc};
|
||||
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
|
||||
|
||||
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
|
||||
@@ -17,7 +17,11 @@ const SUMMARY_MAX_WIDTH: usize = 80;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Input {
|
||||
config: GlobalConfig,
|
||||
app_config: Arc<AppConfig>,
|
||||
stream_enabled: bool,
|
||||
session: Option<Session>,
|
||||
rag: Option<Arc<Rag>>,
|
||||
functions: Option<Vec<FunctionDeclaration>>,
|
||||
text: String,
|
||||
raw: (String, Vec<String>),
|
||||
patched_text: Option<String>,
|
||||
@@ -34,10 +38,15 @@ pub struct Input {
|
||||
}
|
||||
|
||||
impl Input {
|
||||
pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
pub fn from_str(ctx: &RequestContext, text: &str, role: Option<Role>) -> Self {
|
||||
let (role, with_session, with_agent) = resolve_role(ctx, role);
|
||||
let captured = capture_input_config(ctx, &role);
|
||||
Self {
|
||||
config: config.clone(),
|
||||
app_config: Arc::clone(&ctx.app.config),
|
||||
stream_enabled: captured.stream_enabled,
|
||||
session: captured.session,
|
||||
rag: captured.rag,
|
||||
functions: captured.functions,
|
||||
text: text.to_string(),
|
||||
raw: (text.to_string(), vec![]),
|
||||
patched_text: None,
|
||||
@@ -55,12 +64,12 @@ impl Input {
|
||||
}
|
||||
|
||||
pub async fn from_files(
|
||||
config: &GlobalConfig,
|
||||
ctx: &RequestContext,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
) -> Result<Self> {
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let loaders = ctx.app.config.document_loaders.clone();
|
||||
let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
|
||||
resolve_paths(&loaders, paths)?;
|
||||
let mut last_reply = None;
|
||||
@@ -78,7 +87,7 @@ impl Input {
|
||||
texts.push(raw_text.to_string());
|
||||
};
|
||||
if with_last_reply {
|
||||
if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
|
||||
if let Some(LastMessage { input, output, .. }) = ctx.last_message.as_ref() {
|
||||
if !output.is_empty() {
|
||||
last_reply = Some(output.clone())
|
||||
} else if let Some(v) = input.last_reply.as_ref() {
|
||||
@@ -102,9 +111,14 @@ impl Input {
|
||||
));
|
||||
}
|
||||
}
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
let (role, with_session, with_agent) = resolve_role(ctx, role);
|
||||
let captured = capture_input_config(ctx, &role);
|
||||
Ok(Self {
|
||||
config: config.clone(),
|
||||
app_config: Arc::clone(&ctx.app.config),
|
||||
stream_enabled: captured.stream_enabled,
|
||||
session: captured.session,
|
||||
rag: captured.rag,
|
||||
functions: captured.functions,
|
||||
text: texts.join("\n"),
|
||||
raw: (raw_text.to_string(), raw_paths),
|
||||
patched_text: None,
|
||||
@@ -122,14 +136,14 @@ impl Input {
|
||||
}
|
||||
|
||||
pub async fn from_files_with_spinner(
|
||||
config: &GlobalConfig,
|
||||
ctx: &RequestContext,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
abortable_run_with_spinner(
|
||||
Input::from_files(config, raw_text, paths, role),
|
||||
Input::from_files(ctx, raw_text, paths, role),
|
||||
"Loading files",
|
||||
abort_signal,
|
||||
)
|
||||
@@ -164,7 +178,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> bool {
|
||||
self.config.read().stream && !self.role().model().no_stream()
|
||||
self.stream_enabled && !self.role().model().no_stream()
|
||||
}
|
||||
|
||||
pub fn continue_output(&self) -> Option<&str> {
|
||||
@@ -183,10 +197,9 @@ impl Input {
|
||||
self.regenerate
|
||||
}
|
||||
|
||||
pub fn set_regenerate(&mut self) {
|
||||
let role = self.config.read().extract_role();
|
||||
if role.name() == self.role().name() {
|
||||
self.role = role;
|
||||
pub fn set_regenerate(&mut self, current_role: Role) {
|
||||
if current_role.name() == self.role().name() {
|
||||
self.role = current_role;
|
||||
}
|
||||
self.regenerate = true;
|
||||
self.tool_calls = None;
|
||||
@@ -196,9 +209,9 @@ impl Input {
|
||||
if self.text.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let rag = self.config.read().rag.clone();
|
||||
if let Some(rag) = rag {
|
||||
let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
|
||||
if let Some(rag) = &self.rag {
|
||||
let result =
|
||||
Config::search_rag(&self.app_config, rag, &self.text, abort_signal).await?;
|
||||
self.patched_text = Some(result);
|
||||
self.rag_name = Some(rag.name().to_string());
|
||||
}
|
||||
@@ -220,7 +233,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn create_client(&self) -> Result<Box<dyn Client>> {
|
||||
init_client(&self.config, Some(self.role().model().clone()))
|
||||
init_client(&self.app_config, self.role().model().clone())
|
||||
}
|
||||
|
||||
pub async fn fetch_chat_text(&self) -> Result<String> {
|
||||
@@ -240,7 +253,7 @@ impl Input {
|
||||
model.guard_max_input_tokens(&messages)?;
|
||||
let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
|
||||
let functions = if model.supports_function_calling() {
|
||||
let fns = self.config.read().select_functions(self.role());
|
||||
let fns = self.functions.clone();
|
||||
if let Some(vec) = &fns {
|
||||
for def in vec {
|
||||
debug!("Function definition: {:?}", def.name);
|
||||
@@ -260,7 +273,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn build_messages(&self) -> Result<Vec<Message>> {
|
||||
let mut messages = if let Some(session) = self.session(&self.config.read().session) {
|
||||
let mut messages = if let Some(session) = self.session(&self.session) {
|
||||
session.build_messages(self)
|
||||
} else {
|
||||
self.role().build_messages(self)
|
||||
@@ -275,7 +288,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self) -> String {
|
||||
if let Some(session) = self.session(&self.config.read().session) {
|
||||
if let Some(session) = self.session(&self.session) {
|
||||
session.echo_messages(self)
|
||||
} else {
|
||||
self.role().echo_messages(self)
|
||||
@@ -384,17 +397,33 @@ impl Input {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
|
||||
fn resolve_role(ctx: &RequestContext, role: Option<Role>) -> (Role, bool, bool) {
|
||||
match role {
|
||||
Some(v) => (v, false, false),
|
||||
None => (
|
||||
config.extract_role(),
|
||||
config.session.is_some(),
|
||||
config.agent.is_some(),
|
||||
ctx.extract_role(ctx.app.config.as_ref()),
|
||||
ctx.session.is_some(),
|
||||
ctx.agent.is_some(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
struct CapturedInputConfig {
|
||||
stream_enabled: bool,
|
||||
session: Option<Session>,
|
||||
rag: Option<Arc<Rag>>,
|
||||
functions: Option<Vec<FunctionDeclaration>>,
|
||||
}
|
||||
|
||||
fn capture_input_config(ctx: &RequestContext, role: &Role) -> CapturedInputConfig {
|
||||
CapturedInputConfig {
|
||||
stream_enabled: ctx.app.config.stream,
|
||||
session: ctx.session.clone(),
|
||||
rag: ctx.rag.clone(),
|
||||
functions: ctx.select_functions(role),
|
||||
}
|
||||
}
|
||||
|
||||
type ResolvePathsOutput = (
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
|
||||
+38
-22
@@ -1,14 +1,13 @@
|
||||
use crate::config::{Config, GlobalConfig, RoleLike, ensure_parent_exists};
|
||||
use crate::config::paths;
|
||||
use crate::config::{Config, RequestContext, RoleLike, ensure_parent_exists};
|
||||
use crate::repl::{run_repl_command, split_args_text};
|
||||
use crate::utils::{AbortSignal, multiline_text};
|
||||
use anyhow::{Result, anyhow};
|
||||
use indexmap::IndexMap;
|
||||
use parking_lot::RwLock;
|
||||
use rust_embed::Embed;
|
||||
use serde::Deserialize;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "assets/macros"]
|
||||
@@ -16,7 +15,7 @@ struct MacroAssets;
|
||||
|
||||
#[async_recursion::async_recursion]
|
||||
pub async fn macro_execute(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
name: &str,
|
||||
args: Option<&str>,
|
||||
abort_signal: AbortSignal,
|
||||
@@ -29,25 +28,42 @@ pub async fn macro_execute(
|
||||
let variables = macro_value
|
||||
.resolve_variables(&new_args)
|
||||
.map_err(|err| anyhow!("{err}. Usage: {}", macro_value.usage(name)))?;
|
||||
let role = config.read().extract_role();
|
||||
let mut config = config.read().clone();
|
||||
config.temperature = role.temperature();
|
||||
config.top_p = role.top_p();
|
||||
config.enabled_tools = role.enabled_tools().clone();
|
||||
config.enabled_mcp_servers = role.enabled_mcp_servers().clone();
|
||||
config.macro_flag = true;
|
||||
config.model = role.model().clone();
|
||||
config.role = None;
|
||||
config.session = None;
|
||||
config.rag = None;
|
||||
config.agent = None;
|
||||
config.discontinuous_last_message();
|
||||
let config = Arc::new(RwLock::new(config));
|
||||
config.write().macro_flag = true;
|
||||
let role = ctx.extract_role(ctx.app.config.as_ref());
|
||||
let mut app_config = (*ctx.app.config).clone();
|
||||
app_config.temperature = role.temperature();
|
||||
app_config.top_p = role.top_p();
|
||||
app_config.enabled_tools = role.enabled_tools().clone();
|
||||
app_config.enabled_mcp_servers = role.enabled_mcp_servers().clone();
|
||||
|
||||
let mut app_state = (*ctx.app).clone();
|
||||
app_state.config = std::sync::Arc::new(app_config);
|
||||
|
||||
let mut macro_ctx = RequestContext::new(std::sync::Arc::new(app_state), ctx.working_mode);
|
||||
macro_ctx.macro_flag = true;
|
||||
macro_ctx.info_flag = ctx.info_flag;
|
||||
macro_ctx.model = role.model().clone();
|
||||
macro_ctx.agent_variables = ctx.agent_variables.clone();
|
||||
macro_ctx.last_message = ctx.last_message.clone();
|
||||
macro_ctx.supervisor = ctx.supervisor.clone();
|
||||
macro_ctx.parent_supervisor = ctx.parent_supervisor.clone();
|
||||
macro_ctx.self_agent_id = ctx.self_agent_id.clone();
|
||||
macro_ctx.inbox = ctx.inbox.clone();
|
||||
macro_ctx.escalation_queue = ctx.escalation_queue.clone();
|
||||
macro_ctx.current_depth = ctx.current_depth;
|
||||
macro_ctx.auto_continue_count = ctx.auto_continue_count;
|
||||
macro_ctx.todo_list = ctx.todo_list.clone();
|
||||
macro_ctx.tool_scope.tool_tracker = ctx.tool_scope.tool_tracker.clone();
|
||||
macro_ctx.discontinuous_last_message();
|
||||
|
||||
let app = macro_ctx.app.config.clone();
|
||||
macro_ctx
|
||||
.bootstrap_tools(app.as_ref(), true, abort_signal.clone())
|
||||
.await?;
|
||||
|
||||
for step in ¯o_value.steps {
|
||||
let command = Macro::interpolate_command(step, &variables);
|
||||
println!(">> {}", multiline_text(&command));
|
||||
run_repl_command(&config, abort_signal.clone(), &command).await?;
|
||||
run_repl_command(&mut macro_ctx, abort_signal.clone(), &command).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -63,7 +79,7 @@ impl Macro {
|
||||
pub fn install_macros() -> Result<()> {
|
||||
info!(
|
||||
"Installing built-in macros in {}",
|
||||
Config::macros_dir().display()
|
||||
paths::macros_dir().display()
|
||||
);
|
||||
|
||||
for file in MacroAssets::iter() {
|
||||
@@ -71,7 +87,7 @@ impl Macro {
|
||||
let embedded_file = MacroAssets::get(&file)
|
||||
.ok_or_else(|| anyhow!("Failed to load embedded macro file: {}", file.as_ref()))?;
|
||||
let content = unsafe { std::str::from_utf8_unchecked(&embedded_file.data) };
|
||||
let file_path = Config::macros_dir().join(file.as_ref());
|
||||
let file_path = paths::macros_dir().join(file.as_ref());
|
||||
|
||||
if file_path.exists() {
|
||||
debug!(
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
//! Per-process factory for MCP subprocess handles.
|
||||
//!
|
||||
//! `McpFactory` lives on [`AppState`](super::AppState) and is the
|
||||
//! single entrypoint that scopes use to obtain `Arc<ConnectedServer>`
|
||||
//! handles for MCP tool servers. Multiple scopes requesting the same
|
||||
//! server can (eventually) share a single subprocess via `Arc`
|
||||
//! reference counting.
|
||||
//!
|
||||
//! # Phase 1 Step 6.5 scope
|
||||
//!
|
||||
//! This file introduces the factory scaffolding with a trivial
|
||||
//! implementation:
|
||||
//!
|
||||
//! * `active` — `Mutex<HashMap<McpServerKey, Weak<ConnectedServer>>>`
|
||||
//! for future Arc-based sharing across scopes
|
||||
//! * `acquire` — unimplemented stub for now; will be filled in when
|
||||
//! Step 8 rewrites `use_role` / `use_session` / `use_agent` to
|
||||
//! actually build `ToolScope`s
|
||||
//!
|
||||
//! The full design (idle pool, reaper task, per-server TTL, health
|
||||
//! checks, graceful shutdown) lands in **Phase 5** per
|
||||
//! `docs/PHASE-5-IMPLEMENTATION-PLAN.md`. Phase 1 Step 6.5 ships just
|
||||
//! enough for the type to exist on `AppState` and participate in
|
||||
//! construction / test round-trips.
|
||||
//!
|
||||
//! The key type `McpServerKey` hashes the server name plus its full
|
||||
//! command/args/env so that two scopes requesting an identically-
|
||||
//! configured server share an `Arc`, while two scopes requesting
|
||||
//! differently-configured servers (e.g., different API tokens) get
|
||||
//! independent subprocesses. This is the sharing-vs-isolation property
|
||||
//! described in `docs/REST-API-ARCHITECTURE.md` section 5.
|
||||
|
||||
use crate::mcp::{ConnectedServer, JsonField, McpServer, spawn_mcp_server};
|
||||
|
||||
use anyhow::Result;
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub struct McpServerKey {
|
||||
pub name: String,
|
||||
pub command: String,
|
||||
pub args: Vec<String>,
|
||||
pub env: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl McpServerKey {
|
||||
pub fn new(
|
||||
name: impl Into<String>,
|
||||
command: impl Into<String>,
|
||||
args: impl IntoIterator<Item = String>,
|
||||
env: impl IntoIterator<Item = (String, String)>,
|
||||
) -> Self {
|
||||
let mut args: Vec<String> = args.into_iter().collect();
|
||||
args.sort();
|
||||
let mut env: Vec<(String, String)> = env.into_iter().collect();
|
||||
env.sort();
|
||||
Self {
|
||||
name: name.into(),
|
||||
command: command.into(),
|
||||
args,
|
||||
env,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_spec(name: &str, spec: &McpServer) -> Self {
|
||||
let args = spec.args.clone().unwrap_or_default();
|
||||
let env: Vec<(String, String)> = spec
|
||||
.env
|
||||
.as_ref()
|
||||
.map(|e| {
|
||||
e.iter()
|
||||
.map(|(k, v)| {
|
||||
let v_str = match v {
|
||||
JsonField::Str(s) => s.clone(),
|
||||
JsonField::Bool(b) => b.to_string(),
|
||||
JsonField::Int(i) => i.to_string(),
|
||||
};
|
||||
(k.clone(), v_str)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
Self::new(name, &spec.command, args, env)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct McpFactory {
|
||||
active: Mutex<HashMap<McpServerKey, Weak<ConnectedServer>>>,
|
||||
}
|
||||
|
||||
impl McpFactory {
|
||||
pub fn try_get_active(&self, key: &McpServerKey) -> Option<Arc<ConnectedServer>> {
|
||||
let map = self.active.lock();
|
||||
map.get(key).and_then(|weak| weak.upgrade())
|
||||
}
|
||||
|
||||
pub fn insert_active(&self, key: McpServerKey, handle: &Arc<ConnectedServer>) {
|
||||
let mut map = self.active.lock();
|
||||
map.insert(key, Arc::downgrade(handle));
|
||||
}
|
||||
|
||||
pub async fn acquire(
|
||||
&self,
|
||||
name: &str,
|
||||
spec: &McpServer,
|
||||
log_path: Option<&Path>,
|
||||
) -> Result<Arc<ConnectedServer>> {
|
||||
let key = McpServerKey::from_spec(name, spec);
|
||||
|
||||
if let Some(existing) = self.try_get_active(&key) {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let handle = spawn_mcp_server(spec, log_path).await?;
|
||||
self.insert_active(key, &handle);
|
||||
Ok(handle)
|
||||
}
|
||||
}
|
||||
+75
-2283
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,265 @@
|
||||
//! Static path and filesystem-lookup helpers that used to live as
|
||||
//! associated functions on [`Config`](super::Config).
|
||||
//!
|
||||
//! None of these functions depend on any `Config` instance data — they
|
||||
//! compute paths from environment variables, XDG directories, or the
|
||||
//! crate constant for the config root. Moving them here is Phase 1
|
||||
//! Step 2 of the REST API refactor: the `Config` struct is shedding
|
||||
//! anything that doesn't actually need per-instance state so the
|
||||
//! eventual split into `AppConfig` + `RequestContext` has a clean
|
||||
//! division line.
|
||||
//!
|
||||
//! # Compatibility shim during migration
|
||||
//!
|
||||
//! The existing associated functions on `Config` (e.g.,
|
||||
//! `Config::config_dir()`) are kept as `#[deprecated]` forwarders that
|
||||
//! call into this module. Callers are migrated module-by-module; when
|
||||
//! the last caller is updated, the forwarders are deleted in a later
|
||||
//! sub-step of Step 2.
|
||||
|
||||
use super::role::Role;
|
||||
use super::{
|
||||
AGENTS_DIR_NAME, BASH_PROMPT_UTILS_FILE_NAME, CONFIG_FILE_NAME, ENV_FILE_NAME,
|
||||
FUNCTIONS_BIN_DIR_NAME, FUNCTIONS_DIR_NAME, GLOBAL_TOOLS_DIR_NAME, GLOBAL_TOOLS_UTILS_DIR_NAME,
|
||||
MACROS_DIR_NAME, MCP_FILE_NAME, ModelsOverride, RAGS_DIR_NAME, ROLES_DIR_NAME,
|
||||
};
|
||||
use crate::client::ProviderModels;
|
||||
use crate::utils::{get_env_name, list_file_names, normalize_env_name};
|
||||
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use log::LevelFilter;
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::fs::{read_dir, read_to_string};
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub fn config_dir() -> PathBuf {
|
||||
if let Ok(v) = env::var(get_env_name("config_dir")) {
|
||||
PathBuf::from(v)
|
||||
} else if let Ok(v) = env::var("XDG_CONFIG_HOME") {
|
||||
PathBuf::from(v).join(env!("CARGO_CRATE_NAME"))
|
||||
} else {
|
||||
let dir = dirs::config_dir().expect("No user's config directory");
|
||||
dir.join(env!("CARGO_CRATE_NAME"))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_path(name: &str) -> PathBuf {
|
||||
config_dir().join(name)
|
||||
}
|
||||
|
||||
pub fn cache_path() -> PathBuf {
|
||||
let base_dir = dirs::cache_dir().unwrap_or_else(env::temp_dir);
|
||||
base_dir.join(env!("CARGO_CRATE_NAME"))
|
||||
}
|
||||
|
||||
pub fn oauth_tokens_path() -> PathBuf {
|
||||
cache_path().join("oauth")
|
||||
}
|
||||
|
||||
pub fn token_file(client_name: &str) -> PathBuf {
|
||||
oauth_tokens_path().join(format!("{client_name}_oauth_tokens.json"))
|
||||
}
|
||||
|
||||
pub fn log_path() -> PathBuf {
|
||||
cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME")))
|
||||
}
|
||||
|
||||
pub fn config_file() -> PathBuf {
|
||||
match env::var(get_env_name("config_file")) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => local_path(CONFIG_FILE_NAME),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn roles_dir() -> PathBuf {
|
||||
match env::var(get_env_name("roles_dir")) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => local_path(ROLES_DIR_NAME),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn role_file(name: &str) -> PathBuf {
|
||||
roles_dir().join(format!("{name}.md"))
|
||||
}
|
||||
|
||||
pub fn macros_dir() -> PathBuf {
|
||||
match env::var(get_env_name("macros_dir")) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => local_path(MACROS_DIR_NAME),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn macro_file(name: &str) -> PathBuf {
|
||||
macros_dir().join(format!("{name}.yaml"))
|
||||
}
|
||||
|
||||
pub fn env_file() -> PathBuf {
|
||||
match env::var(get_env_name("env_file")) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => local_path(ENV_FILE_NAME),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rags_dir() -> PathBuf {
|
||||
match env::var(get_env_name("rags_dir")) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => local_path(RAGS_DIR_NAME),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn functions_dir() -> PathBuf {
|
||||
match env::var(get_env_name("functions_dir")) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => local_path(FUNCTIONS_DIR_NAME),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn functions_bin_dir() -> PathBuf {
|
||||
functions_dir().join(FUNCTIONS_BIN_DIR_NAME)
|
||||
}
|
||||
|
||||
pub fn mcp_config_file() -> PathBuf {
|
||||
functions_dir().join(MCP_FILE_NAME)
|
||||
}
|
||||
|
||||
pub fn global_tools_dir() -> PathBuf {
|
||||
functions_dir().join(GLOBAL_TOOLS_DIR_NAME)
|
||||
}
|
||||
|
||||
pub fn global_utils_dir() -> PathBuf {
|
||||
functions_dir().join(GLOBAL_TOOLS_UTILS_DIR_NAME)
|
||||
}
|
||||
|
||||
pub fn bash_prompt_utils_file() -> PathBuf {
|
||||
global_utils_dir().join(BASH_PROMPT_UTILS_FILE_NAME)
|
||||
}
|
||||
|
||||
pub fn agents_data_dir() -> PathBuf {
|
||||
local_path(AGENTS_DIR_NAME)
|
||||
}
|
||||
|
||||
pub fn agent_data_dir(name: &str) -> PathBuf {
|
||||
match env::var(format!("{}_DATA_DIR", normalize_env_name(name))) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => agents_data_dir().join(name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn agent_config_file(name: &str) -> PathBuf {
|
||||
match env::var(format!("{}_CONFIG_FILE", normalize_env_name(name))) {
|
||||
Ok(value) => PathBuf::from(value),
|
||||
Err(_) => agent_data_dir(name).join(CONFIG_FILE_NAME),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn agent_bin_dir(name: &str) -> PathBuf {
|
||||
agent_data_dir(name).join(FUNCTIONS_BIN_DIR_NAME)
|
||||
}
|
||||
|
||||
pub fn agent_rag_file(agent_name: &str, rag_name: &str) -> PathBuf {
|
||||
agent_data_dir(agent_name).join(format!("{rag_name}.yaml"))
|
||||
}
|
||||
|
||||
pub fn agent_functions_file(name: &str) -> Result<PathBuf> {
|
||||
let priority = ["tools.sh", "tools.py", "tools.ts", "tools.js"];
|
||||
let dir = agent_data_dir(name);
|
||||
|
||||
for filename in priority {
|
||||
let path = dir.join(filename);
|
||||
if path.exists() {
|
||||
return Ok(path);
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"No tools script found in agent functions directory"
|
||||
))
|
||||
}
|
||||
|
||||
pub fn models_override_file() -> PathBuf {
|
||||
local_path("models-override.yaml")
|
||||
}
|
||||
|
||||
pub fn log_config() -> Result<(LevelFilter, Option<PathBuf>)> {
|
||||
let log_level = env::var(get_env_name("log_level"))
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(match cfg!(debug_assertions) {
|
||||
true => LevelFilter::Debug,
|
||||
false => LevelFilter::Info,
|
||||
});
|
||||
let resolved_log_path = match env::var(get_env_name("log_path")) {
|
||||
Ok(v) => Some(PathBuf::from(v)),
|
||||
Err(_) => Some(log_path()),
|
||||
};
|
||||
Ok((log_level, resolved_log_path))
|
||||
}
|
||||
|
||||
pub fn list_roles(with_builtin: bool) -> Vec<String> {
|
||||
let mut names = HashSet::new();
|
||||
if let Ok(rd) = read_dir(roles_dir()) {
|
||||
for entry in rd.flatten() {
|
||||
if let Some(name) = entry
|
||||
.file_name()
|
||||
.to_str()
|
||||
.and_then(|v| v.strip_suffix(".md"))
|
||||
{
|
||||
names.insert(name.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
if with_builtin {
|
||||
names.extend(Role::list_builtin_role_names());
|
||||
}
|
||||
let mut names: Vec<_> = names.into_iter().collect();
|
||||
names.sort_unstable();
|
||||
names
|
||||
}
|
||||
|
||||
pub fn has_role(name: &str) -> bool {
|
||||
let names = list_roles(true);
|
||||
names.contains(&name.to_string())
|
||||
}
|
||||
|
||||
pub fn list_rags() -> Vec<String> {
|
||||
match read_dir(rags_dir()) {
|
||||
Ok(rd) => {
|
||||
let mut names = vec![];
|
||||
for entry in rd.flatten() {
|
||||
let name = entry.file_name();
|
||||
if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") {
|
||||
names.push(name.to_string());
|
||||
}
|
||||
}
|
||||
names.sort_unstable();
|
||||
names
|
||||
}
|
||||
Err(_) => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_macros() -> Vec<String> {
|
||||
list_file_names(macros_dir(), ".yaml")
|
||||
}
|
||||
|
||||
pub fn has_macro(name: &str) -> bool {
|
||||
let names = list_macros();
|
||||
names.contains(&name.to_string())
|
||||
}
|
||||
|
||||
pub fn local_models_override() -> Result<Vec<ProviderModels>> {
|
||||
let model_override_path = models_override_file();
|
||||
let err = || {
|
||||
format!(
|
||||
"Failed to load models at '{}'",
|
||||
model_override_path.display()
|
||||
)
|
||||
};
|
||||
let content = read_to_string(&model_override_path).with_context(err)?;
|
||||
let models_override: ModelsOverride = serde_yaml::from_str(&content).with_context(err)?;
|
||||
if models_override.version != env!("CARGO_PKG_VERSION") {
|
||||
bail!("Incompatible version")
|
||||
}
|
||||
Ok(models_override.list)
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
//! Per-process RAG instance cache with weak-reference sharing.
|
||||
//!
|
||||
//! `RagCache` lives on [`AppState`](super::AppState) and serves both
|
||||
//! standalone RAGs (attached via `.rag <name>`) and agent-owned RAGs
|
||||
//! (loaded from an agent's `documents:` field). The cache keys with
|
||||
//! [`RagKey`] so that agent RAGs and standalone RAGs occupy distinct
|
||||
//! namespaces even if they share a name.
|
||||
//!
|
||||
//! Entries are held as `Weak<Rag>` so the cache never keeps a RAG
|
||||
//! alive on its own — once all active scopes drop their `Arc<Rag>`,
|
||||
//! the cache entry becomes unupgradable and the next `load()` falls
|
||||
//! through to a fresh disk read.
|
||||
//!
|
||||
//! # Phase 1 Step 6.5 scope
|
||||
//!
|
||||
//! This file introduces the type scaffolding. Actual cache population
|
||||
//! (i.e., routing `use_rag`, `use_agent`, and sub-agent spawning
|
||||
//! through the cache) is deferred to Step 8 when the entry points get
|
||||
//! rewritten. During the bridge window, `Config.rag` keeps serving
|
||||
//! today's callers via direct `Rag::load` / `Rag::init` calls and
|
||||
//! `RagCache` sits on `AppState` as an unused-but-ready service.
|
||||
//!
|
||||
//! See `docs/REST-API-ARCHITECTURE.md` section 5 ("RAG Cache") for
|
||||
//! the full design including concurrent first-load serialization and
|
||||
//! invalidation semantics.
|
||||
|
||||
use crate::rag::Rag;
|
||||
|
||||
use anyhow::Result;
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub enum RagKey {
|
||||
Named(String),
|
||||
Agent(String),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct RagCache {
|
||||
entries: RwLock<HashMap<RagKey, Weak<Rag>>>,
|
||||
}
|
||||
|
||||
impl RagCache {
|
||||
pub fn try_get(&self, key: &RagKey) -> Option<Arc<Rag>> {
|
||||
let map = self.entries.read();
|
||||
map.get(key).and_then(|weak| weak.upgrade())
|
||||
}
|
||||
|
||||
pub fn insert(&self, key: RagKey, rag: &Arc<Rag>) {
|
||||
let mut map = self.entries.write();
|
||||
map.insert(key, Arc::downgrade(rag));
|
||||
}
|
||||
|
||||
pub fn invalidate(&self, key: &RagKey) {
|
||||
let mut map = self.entries.write();
|
||||
map.remove(key);
|
||||
}
|
||||
|
||||
pub async fn load_with<F, Fut>(&self, key: RagKey, loader: F) -> Result<Arc<Rag>>
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: Future<Output = Result<Rag>>,
|
||||
{
|
||||
if let Some(existing) = self.try_get(&key) {
|
||||
return Ok(existing);
|
||||
}
|
||||
let rag = loader().await?;
|
||||
let arc = Arc::new(rag);
|
||||
self.insert(key, &arc);
|
||||
Ok(arc)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -374,6 +374,100 @@ fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn role_new_parses_prompt() {
|
||||
let role = Role::new("test", "You are a helpful assistant");
|
||||
assert_eq!(role.name(), "test");
|
||||
assert_eq!(role.prompt(), "You are a helpful assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_new_parses_metadata() {
|
||||
let content =
|
||||
"---\nmodel: openai:gpt-4\ntemperature: 0.7\ntop_p: 0.9\n---\nYou are helpful";
|
||||
let role = Role::new("test", content);
|
||||
assert_eq!(role.model_id(), Some("openai:gpt-4"));
|
||||
assert_eq!(role.temperature(), Some(0.7));
|
||||
assert_eq!(role.top_p(), Some(0.9));
|
||||
assert_eq!(role.prompt(), "You are helpful");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_new_parses_enabled_tools() {
|
||||
let content = "---\nenabled_tools: tool1,tool2\n---\nPrompt";
|
||||
let role = Role::new("test", content);
|
||||
assert_eq!(role.enabled_tools(), Some("tool1,tool2".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_new_parses_enabled_mcp_servers() {
|
||||
let content = "---\nenabled_mcp_servers: github,jira\n---\nPrompt";
|
||||
let role = Role::new("test", content);
|
||||
assert_eq!(role.enabled_mcp_servers(), Some("github,jira".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_new_no_metadata_has_none_fields() {
|
||||
let role = Role::new("test", "Just a prompt");
|
||||
assert_eq!(role.model_id(), None);
|
||||
assert_eq!(role.temperature(), None);
|
||||
assert_eq!(role.top_p(), None);
|
||||
assert_eq!(role.enabled_tools(), None);
|
||||
assert_eq!(role.enabled_mcp_servers(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_builtin_shell_loads() {
|
||||
let role = Role::builtin("shell").unwrap();
|
||||
assert_eq!(role.name(), "shell");
|
||||
assert!(!role.prompt().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_builtin_code_loads() {
|
||||
let role = Role::builtin("code").unwrap();
|
||||
assert_eq!(role.name(), "code");
|
||||
assert!(!role.prompt().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_builtin_nonexistent_errors() {
|
||||
let result = Role::builtin("nonexistent_role_xyz");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_default_has_empty_fields() {
|
||||
let role = Role::default();
|
||||
assert_eq!(role.name(), "");
|
||||
assert_eq!(role.prompt(), "");
|
||||
assert_eq!(role.model_id(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_set_model_updates_model() {
|
||||
let mut role = Role::new("test", "prompt");
|
||||
let model = Model::default();
|
||||
role.set_model(model.clone());
|
||||
assert_eq!(role.model().id(), model.id());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_set_temperature_works() {
|
||||
let mut role = Role::new("test", "prompt");
|
||||
role.set_temperature(Some(0.5));
|
||||
assert_eq!(role.temperature(), Some(0.5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn role_export_includes_metadata() {
|
||||
let content = "---\ntemperature: 0.8\n---\nMy prompt";
|
||||
let role = Role::new("test", content);
|
||||
let exported = role.export();
|
||||
assert!(exported.contains("temperature"));
|
||||
assert!(exported.contains("My prompt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_structure_prompt1() {
|
||||
let prompt = r#"
|
||||
|
||||
+180
-6
@@ -67,11 +67,11 @@ pub struct Session {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(config: &Config, name: &str) -> Self {
|
||||
let role = config.extract_role();
|
||||
pub fn new_from_ctx(ctx: &RequestContext, app: &AppConfig, name: &str) -> Self {
|
||||
let role = ctx.extract_role(app);
|
||||
let mut session = Self {
|
||||
name: name.to_string(),
|
||||
save_session: config.save_session,
|
||||
save_session: app.save_session,
|
||||
..Default::default()
|
||||
};
|
||||
session.set_role(role);
|
||||
@@ -79,13 +79,18 @@ impl Session {
|
||||
session
|
||||
}
|
||||
|
||||
pub fn load(config: &Config, name: &str, path: &Path) -> Result<Self> {
|
||||
pub fn load_from_ctx(
|
||||
ctx: &RequestContext,
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
path: &Path,
|
||||
) -> Result<Self> {
|
||||
let content = read_to_string(path)
|
||||
.with_context(|| format!("Failed to load session {} at {}", name, path.display()))?;
|
||||
let mut session: Self =
|
||||
serde_yaml::from_str(&content).with_context(|| format!("Invalid session {name}"))?;
|
||||
|
||||
session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?;
|
||||
session.model = Model::retrieve_model(app, &session.model_id, ModelType::Chat)?;
|
||||
|
||||
if let Some(autoname) = name.strip_prefix("_/") {
|
||||
session.name = TEMP_SESSION_NAME.to_string();
|
||||
@@ -99,7 +104,7 @@ impl Session {
|
||||
}
|
||||
|
||||
if let Some(role_name) = &session.role_name
|
||||
&& let Ok(role) = config.retrieve_role(role_name)
|
||||
&& let Ok(role) = ctx.retrieve_role(app, role_name)
|
||||
{
|
||||
session.role_prompt = role.prompt().to_string();
|
||||
}
|
||||
@@ -664,3 +669,172 @@ impl AutoName {
|
||||
!self.naming && self.chat_history.is_some() && self.name.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client::{Message, MessageContent, MessageRole, Model};
|
||||
use crate::config::{AppState, Config};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn session_default_is_empty() {
|
||||
let session = Session::default();
|
||||
assert!(session.is_empty());
|
||||
assert_eq!(session.name(), "");
|
||||
assert_eq!(session.role_name(), None);
|
||||
assert!(!session.dirty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_new_from_ctx_captures_save_session() {
|
||||
let cfg = Config::default();
|
||||
let app_config = Arc::new(cfg.to_app_config());
|
||||
let app_state = Arc::new(AppState {
|
||||
config: app_config.clone(),
|
||||
vault: cfg.vault.clone(),
|
||||
mcp_factory: Arc::new(mcp_factory::McpFactory::default()),
|
||||
rag_cache: Arc::new(rag_cache::RagCache::default()),
|
||||
mcp_config: None,
|
||||
mcp_log_path: None,
|
||||
});
|
||||
let ctx = cfg.to_request_context(app_state);
|
||||
let session = Session::new_from_ctx(&ctx, &app_config, "test-session");
|
||||
|
||||
assert_eq!(session.name(), "test-session");
|
||||
assert_eq!(session.save_session(), app_config.save_session);
|
||||
assert!(session.is_empty());
|
||||
assert!(!session.dirty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_set_role_captures_role_info() {
|
||||
let mut session = Session::default();
|
||||
let content = "---\ntemperature: 0.5\n---\nYou are a coder";
|
||||
let mut role = Role::new("coder", content);
|
||||
role.set_model(Model::default());
|
||||
|
||||
session.set_role(role);
|
||||
|
||||
assert_eq!(session.role_name(), Some("coder"));
|
||||
assert_eq!(session.temperature(), Some(0.5));
|
||||
assert!(session.dirty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_clear_role() {
|
||||
let mut session = Session::default();
|
||||
let mut role = Role::new("test", "prompt");
|
||||
role.set_model(Model::default());
|
||||
session.set_role(role);
|
||||
|
||||
assert_eq!(session.role_name(), Some("test"));
|
||||
|
||||
session.clear_role();
|
||||
|
||||
assert_eq!(session.role_name(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_guard_empty_passes_when_empty() {
|
||||
let session = Session::default();
|
||||
assert!(session.guard_empty().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_needs_compression_threshold() {
|
||||
let session = Session::default();
|
||||
assert!(!session.needs_compression(4000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_needs_compression_returns_false_when_compressing() {
|
||||
let mut session = Session::default();
|
||||
session.set_compressing(true);
|
||||
assert!(!session.needs_compression(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_needs_compression_returns_false_when_threshold_zero() {
|
||||
let session = Session::default();
|
||||
assert!(!session.needs_compression(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_set_compressing_flag() {
|
||||
let mut session = Session::default();
|
||||
assert!(!session.compressing());
|
||||
session.set_compressing(true);
|
||||
assert!(session.compressing());
|
||||
session.set_compressing(false);
|
||||
assert!(!session.compressing());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_set_save_session_this_time() {
|
||||
let mut session = Session::default();
|
||||
assert!(!session.save_session_this_time);
|
||||
session.set_save_session_this_time();
|
||||
assert!(session.save_session_this_time);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_save_session_returns_configured_value() {
|
||||
let mut session = Session::default();
|
||||
assert_eq!(session.save_session(), None);
|
||||
session.set_save_session(Some(true));
|
||||
assert_eq!(session.save_session(), Some(true));
|
||||
session.set_save_session(Some(false));
|
||||
assert_eq!(session.save_session(), Some(false));
|
||||
session.set_save_session(None);
|
||||
assert_eq!(session.save_session(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_compress_moves_messages() {
|
||||
let mut session = Session::default();
|
||||
session.messages.push(Message::new(
|
||||
MessageRole::System,
|
||||
MessageContent::Text("system prompt".to_string()),
|
||||
));
|
||||
session.messages.push(Message::new(
|
||||
MessageRole::User,
|
||||
MessageContent::Text("hello".to_string()),
|
||||
));
|
||||
|
||||
assert_eq!(session.messages.len(), 2);
|
||||
assert!(session.compressed_messages.is_empty());
|
||||
|
||||
session.compress("Summary of conversation".to_string());
|
||||
|
||||
assert!(!session.compressed_messages.is_empty());
|
||||
assert_eq!(session.messages.len(), 1);
|
||||
assert!(session.dirty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_is_not_empty_after_compress() {
|
||||
let mut session = Session::default();
|
||||
session.messages.push(Message::new(
|
||||
MessageRole::User,
|
||||
MessageContent::Text("hello".to_string()),
|
||||
));
|
||||
|
||||
session.compress("Summary".to_string());
|
||||
|
||||
assert!(!session.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_need_autoname_default_false() {
|
||||
let session = Session::default();
|
||||
assert!(!session.need_autoname());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_set_autonaming_doesnt_panic_without_autoname() {
|
||||
let mut session = Session::default();
|
||||
session.set_autonaming(true);
|
||||
assert!(!session.need_autoname());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
//! Per-scope tool runtime: resolved functions + live MCP handles +
|
||||
//! call tracker.
|
||||
//!
|
||||
//! `ToolScope` is the unit of tool availability for a single request.
|
||||
//! Every active `RoleLike` (role, session, agent) conceptually owns one.
|
||||
//! The contents are:
|
||||
//!
|
||||
//! * `functions` — the `Functions` declarations visible to the LLM for
|
||||
//! this scope (global tools + role/session/agent filters applied)
|
||||
//! * `mcp_runtime` — live MCP subprocess handles for the servers this
|
||||
//! scope has enabled, keyed by server name
|
||||
//! * `tool_tracker` — per-scope tool call history for auto-continuation
|
||||
//! and looping detection
|
||||
//!
|
||||
//! # Phase 1 Step 6.5 scope
|
||||
//!
|
||||
//! This file introduces the type scaffolding. Scope transitions
|
||||
//! (`use_role`, `use_session`, `use_agent`, `exit_*`) that actually
|
||||
//! build and swap `ToolScope` instances are deferred to Step 8 when
|
||||
//! the entry points (`main.rs`, `repl/mod.rs`) get rewritten to thread
|
||||
//! `RequestContext` through the pipeline. During the bridge window,
|
||||
//! `Config.functions` / `Config.mcp_registry` keep serving today's
|
||||
//! callers and `ToolScope` sits alongside them on `RequestContext` as
|
||||
//! an unused (but compiling and tested) parallel structure.
|
||||
//!
|
||||
//! The fields mirror the plan in `docs/REST-API-ARCHITECTURE.md`
|
||||
//! section 5 and `docs/PHASE-1-IMPLEMENTATION-PLAN.md` Step 6.5.
|
||||
|
||||
use crate::function::{Functions, ToolCallTracker};
|
||||
use crate::mcp::{CatalogItem, ConnectedServer, McpRegistry};
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use bm25::{Document, Language, SearchEngineBuilder};
|
||||
use rmcp::model::{CallToolRequestParams, CallToolResult};
|
||||
use serde_json::{Value, json};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ToolScope {
|
||||
pub functions: Functions,
|
||||
pub mcp_runtime: McpRuntime,
|
||||
pub tool_tracker: ToolCallTracker,
|
||||
}
|
||||
|
||||
impl Default for ToolScope {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
functions: Functions::default(),
|
||||
mcp_runtime: McpRuntime::default(),
|
||||
tool_tracker: ToolCallTracker::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct McpRuntime {
|
||||
pub servers: HashMap<String, Arc<ConnectedServer>>,
|
||||
}
|
||||
|
||||
impl McpRuntime {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, name: String, handle: Arc<ConnectedServer>) {
|
||||
self.servers.insert(name, handle);
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<&Arc<ConnectedServer>> {
|
||||
self.servers.get(name)
|
||||
}
|
||||
|
||||
pub fn server_names(&self) -> Vec<String> {
|
||||
self.servers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn sync_from_registry(&mut self, registry: &McpRegistry) {
|
||||
self.servers.clear();
|
||||
for (name, handle) in registry.running_servers() {
|
||||
self.servers.insert(name.clone(), Arc::clone(handle));
|
||||
}
|
||||
}
|
||||
|
||||
async fn catalog_items(&self, server: &str) -> Result<HashMap<String, CatalogItem>> {
|
||||
let server_handle = self
|
||||
.get(server)
|
||||
.cloned()
|
||||
.with_context(|| format!("{server} MCP server not found in runtime"))?;
|
||||
let tools = server_handle.list_tools(None).await?;
|
||||
let mut items = HashMap::new();
|
||||
|
||||
for tool in tools.tools {
|
||||
let item = CatalogItem {
|
||||
name: tool.name.to_string(),
|
||||
server: server.to_string(),
|
||||
description: tool.description.unwrap_or_default().to_string(),
|
||||
};
|
||||
items.insert(item.name.clone(), item);
|
||||
}
|
||||
|
||||
Ok(items)
|
||||
}
|
||||
|
||||
pub async fn search(
|
||||
&self,
|
||||
server: &str,
|
||||
query: &str,
|
||||
top_k: usize,
|
||||
) -> Result<Vec<CatalogItem>> {
|
||||
let items = self.catalog_items(server).await?;
|
||||
let docs = items.values().map(|item| Document {
|
||||
id: item.name.clone(),
|
||||
contents: format!(
|
||||
"{}\n{}\nserver:{}",
|
||||
item.name, item.description, item.server
|
||||
),
|
||||
});
|
||||
let engine = SearchEngineBuilder::<String>::with_documents(Language::English, docs).build();
|
||||
|
||||
Ok(engine
|
||||
.search(query, top_k.min(20))
|
||||
.into_iter()
|
||||
.filter_map(|result| items.get(&result.document.id))
|
||||
.take(top_k)
|
||||
.cloned()
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn describe(&self, server: &str, tool: &str) -> Result<Value> {
|
||||
let server_handle = self
|
||||
.get(server)
|
||||
.cloned()
|
||||
.with_context(|| format!("{server} MCP server not found in runtime"))?;
|
||||
|
||||
let tool_schema = server_handle
|
||||
.list_tools(None)
|
||||
.await?
|
||||
.tools
|
||||
.into_iter()
|
||||
.find(|item| item.name == tool)
|
||||
.ok_or_else(|| anyhow!("{tool} not found in {server} MCP server catalog"))?
|
||||
.input_schema;
|
||||
|
||||
Ok(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tool": {
|
||||
"type": "string",
|
||||
},
|
||||
"arguments": tool_schema
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn invoke(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Value,
|
||||
) -> Result<CallToolResult> {
|
||||
let server_handle = self
|
||||
.get(server)
|
||||
.cloned()
|
||||
.with_context(|| format!("Invoked MCP server does not exist: {server}"))?;
|
||||
|
||||
let request = CallToolRequestParams {
|
||||
name: Cow::Owned(tool.to_owned()),
|
||||
arguments: arguments.as_object().cloned(),
|
||||
meta: None,
|
||||
task: None,
|
||||
};
|
||||
|
||||
server_handle.call_tool(request).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
+91
-106
@@ -3,11 +3,12 @@ pub(crate) mod todo;
|
||||
pub(crate) mod user_interaction;
|
||||
|
||||
use crate::{
|
||||
config::{Agent, Config, GlobalConfig},
|
||||
config::{Agent, RequestContext},
|
||||
utils::*,
|
||||
};
|
||||
|
||||
use crate::config::ensure_parent_exists;
|
||||
use crate::config::paths;
|
||||
use crate::mcp::{
|
||||
MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, MCP_INVOKE_META_FUNCTION_NAME_PREFIX,
|
||||
MCP_SEARCH_META_FUNCTION_NAME_PREFIX,
|
||||
@@ -110,7 +111,7 @@ fn extract_shebang_runtime(path: &Path) -> Option<String> {
|
||||
}
|
||||
|
||||
pub async fn eval_tool_calls(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
mut calls: Vec<ToolCall>,
|
||||
) -> Result<Vec<ToolResult>> {
|
||||
let mut output = vec![];
|
||||
@@ -123,9 +124,7 @@ pub async fn eval_tool_calls(
|
||||
}
|
||||
let mut is_all_null = true;
|
||||
for call in calls {
|
||||
if let Some(checker) = &config.read().tool_call_tracker
|
||||
&& let Some(msg) = checker.check_loop(&call.clone())
|
||||
{
|
||||
if let Some(msg) = ctx.tool_scope.tool_tracker.check_loop(&call.clone()) {
|
||||
let dup_msg = format!("{{\"tool_call_loop_alert\":{}}}", &msg.trim());
|
||||
println!(
|
||||
"{}",
|
||||
@@ -136,7 +135,7 @@ pub async fn eval_tool_calls(
|
||||
is_all_null = false;
|
||||
continue;
|
||||
}
|
||||
let mut result = call.eval(config).await?;
|
||||
let mut result = call.eval(ctx).await?;
|
||||
if result.is_null() {
|
||||
result = json!("DONE");
|
||||
} else {
|
||||
@@ -149,16 +148,13 @@ pub async fn eval_tool_calls(
|
||||
}
|
||||
|
||||
if !output.is_empty() {
|
||||
let (has_escalations, summary) = {
|
||||
let cfg = config.read();
|
||||
if cfg.current_depth == 0
|
||||
&& let Some(ref queue) = cfg.root_escalation_queue
|
||||
&& queue.has_pending()
|
||||
{
|
||||
(true, queue.pending_summary())
|
||||
} else {
|
||||
(false, vec![])
|
||||
}
|
||||
let (has_escalations, summary) = if ctx.current_depth == 0
|
||||
&& let Some(queue) = ctx.root_escalation_queue()
|
||||
&& queue.has_pending()
|
||||
{
|
||||
(true, queue.pending_summary())
|
||||
} else {
|
||||
(false, vec![])
|
||||
};
|
||||
|
||||
if has_escalations {
|
||||
@@ -199,7 +195,7 @@ impl Functions {
|
||||
fn install_global_tools() -> Result<()> {
|
||||
info!(
|
||||
"Installing global built-in functions in {}",
|
||||
Config::functions_dir().display()
|
||||
paths::functions_dir().display()
|
||||
);
|
||||
|
||||
for file in FunctionAssets::iter() {
|
||||
@@ -213,7 +209,7 @@ impl Functions {
|
||||
anyhow!("Failed to load embedded function file: {}", file.as_ref())
|
||||
})?;
|
||||
let content = unsafe { std::str::from_utf8_unchecked(&embedded_file.data) };
|
||||
let file_path = Config::functions_dir().join(file.as_ref());
|
||||
let file_path = paths::functions_dir().join(file.as_ref());
|
||||
let file_extension = file_path
|
||||
.extension()
|
||||
.and_then(OsStr::to_str)
|
||||
@@ -254,7 +250,7 @@ impl Functions {
|
||||
|
||||
info!(
|
||||
"Building global function binaries in {}",
|
||||
Config::functions_bin_dir().display()
|
||||
paths::functions_bin_dir().display()
|
||||
);
|
||||
Self::build_global_function_binaries(visible_tools, None)?;
|
||||
|
||||
@@ -271,7 +267,7 @@ impl Functions {
|
||||
|
||||
info!(
|
||||
"Building global function binaries required by agent: {name} in {}",
|
||||
Config::functions_bin_dir().display()
|
||||
paths::functions_bin_dir().display()
|
||||
);
|
||||
Self::build_global_function_binaries(global_tools, Some(name))?;
|
||||
tools_declarations
|
||||
@@ -279,7 +275,7 @@ impl Functions {
|
||||
debug!("No global tools found for agent: {}", name);
|
||||
Vec::new()
|
||||
};
|
||||
let agent_script_declarations = match Config::agent_functions_file(name) {
|
||||
let agent_script_declarations = match paths::agent_functions_file(name) {
|
||||
Ok(path) if path.exists() => {
|
||||
info!(
|
||||
"Loading functions script for agent: {name} from {}",
|
||||
@@ -290,7 +286,7 @@ impl Functions {
|
||||
|
||||
info!(
|
||||
"Building function binary for agent: {name} in {}",
|
||||
Config::agent_bin_dir(name).display()
|
||||
paths::agent_bin_dir(name).display()
|
||||
);
|
||||
Self::build_agent_tool_binaries(name)?;
|
||||
script_declarations
|
||||
@@ -342,14 +338,6 @@ impl Functions {
|
||||
.extend(user_interaction::user_interaction_function_declarations());
|
||||
}
|
||||
|
||||
pub fn clear_mcp_meta_functions(&mut self) {
|
||||
self.declarations.retain(|d| {
|
||||
!d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX)
|
||||
&& !d.name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX)
|
||||
&& !d.name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX)
|
||||
});
|
||||
}
|
||||
|
||||
pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec<String>) {
|
||||
let mut invoke_function_properties = IndexMap::new();
|
||||
invoke_function_properties.insert(
|
||||
@@ -453,7 +441,7 @@ impl Functions {
|
||||
fn build_global_tool_declarations(
|
||||
enabled_tools: &[String],
|
||||
) -> Result<Vec<FunctionDeclaration>> {
|
||||
let global_tools_directory = Config::global_tools_dir();
|
||||
let global_tools_directory = paths::global_tools_dir();
|
||||
let mut function_declarations = Vec::new();
|
||||
|
||||
for tool in enabled_tools {
|
||||
@@ -542,7 +530,7 @@ impl Functions {
|
||||
bail!("Unsupported tool file extension: {}", language.as_ref());
|
||||
}
|
||||
|
||||
let tool_path = Config::global_tools_dir().join(tool);
|
||||
let tool_path = paths::global_tools_dir().join(tool);
|
||||
let custom_runtime = extract_shebang_runtime(&tool_path);
|
||||
Self::build_binaries(
|
||||
binary_name,
|
||||
@@ -556,7 +544,7 @@ impl Functions {
|
||||
}
|
||||
|
||||
fn clear_agent_bin_dir(name: &str) -> Result<()> {
|
||||
let agent_bin_directory = Config::agent_bin_dir(name);
|
||||
let agent_bin_directory = paths::agent_bin_dir(name);
|
||||
if !agent_bin_directory.exists() {
|
||||
debug!(
|
||||
"Creating agent bin directory: {}",
|
||||
@@ -575,7 +563,7 @@ impl Functions {
|
||||
}
|
||||
|
||||
fn clear_global_functions_bin_dir() -> Result<()> {
|
||||
let bin_dir = Config::functions_bin_dir();
|
||||
let bin_dir = paths::functions_bin_dir();
|
||||
if !bin_dir.exists() {
|
||||
fs::create_dir_all(&bin_dir)?;
|
||||
}
|
||||
@@ -590,7 +578,7 @@ impl Functions {
|
||||
}
|
||||
|
||||
fn build_agent_tool_binaries(name: &str) -> Result<()> {
|
||||
let tools_file = Config::agent_functions_file(name)?;
|
||||
let tools_file = paths::agent_functions_file(name)?;
|
||||
let language = Language::from(
|
||||
&tools_file
|
||||
.extension()
|
||||
@@ -619,18 +607,18 @@ impl Functions {
|
||||
use native::runtime;
|
||||
let (binary_file, binary_script_file) = match binary_type {
|
||||
BinaryType::Tool(None) => (
|
||||
Config::functions_bin_dir().join(format!("{binary_name}.cmd")),
|
||||
Config::functions_bin_dir()
|
||||
paths::functions_bin_dir().join(format!("{binary_name}.cmd")),
|
||||
paths::functions_bin_dir()
|
||||
.join(format!("run-{binary_name}.{}", language.to_extension())),
|
||||
),
|
||||
BinaryType::Tool(Some(agent_name)) => (
|
||||
Config::agent_bin_dir(agent_name).join(format!("{binary_name}.cmd")),
|
||||
Config::agent_bin_dir(agent_name)
|
||||
paths::agent_bin_dir(agent_name).join(format!("{binary_name}.cmd")),
|
||||
paths::agent_bin_dir(agent_name)
|
||||
.join(format!("run-{binary_name}.{}", language.to_extension())),
|
||||
),
|
||||
BinaryType::Agent => (
|
||||
Config::agent_bin_dir(binary_name).join(format!("{binary_name}.cmd")),
|
||||
Config::agent_bin_dir(binary_name)
|
||||
paths::agent_bin_dir(binary_name).join(format!("{binary_name}.cmd")),
|
||||
paths::agent_bin_dir(binary_name)
|
||||
.join(format!("run-{binary_name}.{}", language.to_extension())),
|
||||
),
|
||||
};
|
||||
@@ -655,7 +643,7 @@ impl Functions {
|
||||
let to_script_path = |p: &str| -> String { p.replace('\\', "/") };
|
||||
let content = match binary_type {
|
||||
BinaryType::Tool(None) => {
|
||||
let root_dir = Config::functions_dir();
|
||||
let root_dir = paths::functions_dir();
|
||||
let tool_path = format!(
|
||||
"{}/{binary_name}",
|
||||
&Config::global_tools_dir().to_string_lossy()
|
||||
@@ -666,7 +654,7 @@ impl Functions {
|
||||
.replace("{tool_path}", &to_script_path(&tool_path))
|
||||
}
|
||||
BinaryType::Tool(Some(agent_name)) => {
|
||||
let root_dir = Config::agent_data_dir(agent_name);
|
||||
let root_dir = paths::agent_data_dir(agent_name);
|
||||
let tool_path = format!(
|
||||
"{}/{binary_name}",
|
||||
&Config::global_tools_dir().to_string_lossy()
|
||||
@@ -680,12 +668,12 @@ impl Functions {
|
||||
.replace("{agent_name}", binary_name)
|
||||
.replace(
|
||||
"{config_dir}",
|
||||
&to_script_path(&Config::config_dir().to_string_lossy()),
|
||||
&to_script_path(&paths::config_dir().to_string_lossy()),
|
||||
),
|
||||
}
|
||||
.replace(
|
||||
"{prompt_utils_file}",
|
||||
&to_script_path(&Config::bash_prompt_utils_file().to_string_lossy()),
|
||||
&to_script_path(&paths::bash_prompt_utils_file().to_string_lossy()),
|
||||
);
|
||||
if binary_script_file.exists() {
|
||||
fs::remove_file(&binary_script_file)?;
|
||||
@@ -769,11 +757,11 @@ impl Functions {
|
||||
use std::os::unix::prelude::PermissionsExt;
|
||||
|
||||
let binary_file = match binary_type {
|
||||
BinaryType::Tool(None) => Config::functions_bin_dir().join(binary_name),
|
||||
BinaryType::Tool(None) => paths::functions_bin_dir().join(binary_name),
|
||||
BinaryType::Tool(Some(agent_name)) => {
|
||||
Config::agent_bin_dir(agent_name).join(binary_name)
|
||||
paths::agent_bin_dir(agent_name).join(binary_name)
|
||||
}
|
||||
BinaryType::Agent => Config::agent_bin_dir(binary_name).join(binary_name),
|
||||
BinaryType::Agent => paths::agent_bin_dir(binary_name).join(binary_name),
|
||||
};
|
||||
info!(
|
||||
"Building binary for function: {} ({})",
|
||||
@@ -795,10 +783,10 @@ impl Functions {
|
||||
let content_template = unsafe { std::str::from_utf8_unchecked(&embedded_file.data) };
|
||||
let mut content = match binary_type {
|
||||
BinaryType::Tool(None) => {
|
||||
let root_dir = Config::functions_dir();
|
||||
let root_dir = paths::functions_dir();
|
||||
let tool_path = format!(
|
||||
"{}/{binary_name}",
|
||||
&Config::global_tools_dir().to_string_lossy()
|
||||
&paths::global_tools_dir().to_string_lossy()
|
||||
);
|
||||
content_template
|
||||
.replace("{function_name}", binary_name)
|
||||
@@ -806,10 +794,10 @@ impl Functions {
|
||||
.replace("{tool_path}", &tool_path)
|
||||
}
|
||||
BinaryType::Tool(Some(agent_name)) => {
|
||||
let root_dir = Config::agent_data_dir(agent_name);
|
||||
let root_dir = paths::agent_data_dir(agent_name);
|
||||
let tool_path = format!(
|
||||
"{}/{binary_name}",
|
||||
&Config::global_tools_dir().to_string_lossy()
|
||||
&paths::global_tools_dir().to_string_lossy()
|
||||
);
|
||||
content_template
|
||||
.replace("{function_name}", binary_name)
|
||||
@@ -818,11 +806,11 @@ impl Functions {
|
||||
}
|
||||
BinaryType::Agent => content_template
|
||||
.replace("{agent_name}", binary_name)
|
||||
.replace("{config_dir}", &Config::config_dir().to_string_lossy()),
|
||||
.replace("{config_dir}", &paths::config_dir().to_string_lossy()),
|
||||
}
|
||||
.replace(
|
||||
"{prompt_utils_file}",
|
||||
&Config::bash_prompt_utils_file().to_string_lossy(),
|
||||
&paths::bash_prompt_utils_file().to_string_lossy(),
|
||||
);
|
||||
|
||||
if let Some(rt) = custom_runtime
|
||||
@@ -952,16 +940,15 @@ impl ToolCall {
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn eval(&self, config: &GlobalConfig) -> Result<Value> {
|
||||
let (call_name, cmd_name, mut cmd_args, envs) = match &config.read().agent {
|
||||
Some(agent) => self.extract_call_config_from_agent(config, agent)?,
|
||||
None => self.extract_call_config_from_config(config)?,
|
||||
pub async fn eval(&self, ctx: &mut RequestContext) -> Result<Value> {
|
||||
let agent = ctx.agent.clone();
|
||||
let functions = ctx.tool_scope.functions.clone();
|
||||
let current_depth = ctx.current_depth;
|
||||
let agent_name = agent.as_ref().map(|agent| agent.name().to_owned());
|
||||
let (call_name, cmd_name, mut cmd_args, envs) = match agent.as_ref() {
|
||||
Some(agent) => self.extract_call_config_from_agent(&functions, agent)?,
|
||||
None => self.extract_call_config_from_ctx(&functions)?,
|
||||
};
|
||||
let agent_name = config
|
||||
.read()
|
||||
.agent
|
||||
.as_ref()
|
||||
.map(|agent| agent.name().to_owned());
|
||||
|
||||
let json_data = if self.arguments.is_object() {
|
||||
self.arguments.clone()
|
||||
@@ -981,20 +968,22 @@ impl ToolCall {
|
||||
|
||||
let prompt = format!("Call {cmd_name} {}", cmd_args.join(" "));
|
||||
|
||||
if *IS_STDOUT_TERMINAL && config.read().current_depth == 0 {
|
||||
if *IS_STDOUT_TERMINAL && current_depth == 0 {
|
||||
println!("{}", dimmed_text(&prompt));
|
||||
}
|
||||
|
||||
let output = match cmd_name.as_str() {
|
||||
_ if cmd_name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX) => {
|
||||
Self::search_mcp_tools(config, &cmd_name, &json_data).unwrap_or_else(|e| {
|
||||
let error_msg = format!("MCP search failed: {e}");
|
||||
eprintln!("{}", warning_text(&format!("⚠️ {error_msg} ⚠️")));
|
||||
json!({"tool_call_error": error_msg})
|
||||
})
|
||||
Self::search_mcp_tools(ctx, &cmd_name, &json_data)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
let error_msg = format!("MCP search failed: {e}");
|
||||
eprintln!("{}", warning_text(&format!("⚠️ {error_msg} ⚠️")));
|
||||
json!({"tool_call_error": error_msg})
|
||||
})
|
||||
}
|
||||
_ if cmd_name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX) => {
|
||||
Self::describe_mcp_tool(config, &cmd_name, json_data)
|
||||
Self::describe_mcp_tool(ctx, &cmd_name, json_data)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
let error_msg = format!("MCP describe failed: {e}");
|
||||
@@ -1003,7 +992,7 @@ impl ToolCall {
|
||||
})
|
||||
}
|
||||
_ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => {
|
||||
Self::invoke_mcp_tool(config, &cmd_name, &json_data)
|
||||
Self::invoke_mcp_tool(ctx, &cmd_name, &json_data)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
let error_msg = format!("MCP tool invocation failed: {e}");
|
||||
@@ -1012,14 +1001,14 @@ impl ToolCall {
|
||||
})
|
||||
}
|
||||
_ if cmd_name.starts_with(TODO_FUNCTION_PREFIX) => {
|
||||
todo::handle_todo_tool(config, &cmd_name, &json_data).unwrap_or_else(|e| {
|
||||
todo::handle_todo_tool(ctx, &cmd_name, &json_data).unwrap_or_else(|e| {
|
||||
let error_msg = format!("Todo tool failed: {e}");
|
||||
eprintln!("{}", warning_text(&format!("⚠️ {error_msg} ⚠️")));
|
||||
json!({"tool_call_error": error_msg})
|
||||
})
|
||||
}
|
||||
_ if cmd_name.starts_with(SUPERVISOR_FUNCTION_PREFIX) => {
|
||||
supervisor::handle_supervisor_tool(config, &cmd_name, &json_data)
|
||||
supervisor::handle_supervisor_tool(ctx, &cmd_name, &json_data)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
let error_msg = format!("Supervisor tool failed: {e}");
|
||||
@@ -1028,7 +1017,7 @@ impl ToolCall {
|
||||
})
|
||||
}
|
||||
_ if cmd_name.starts_with(USER_FUNCTION_PREFIX) => {
|
||||
user_interaction::handle_user_tool(config, &cmd_name, &json_data)
|
||||
user_interaction::handle_user_tool(ctx, &cmd_name, &json_data)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
let error_msg = format!("User interaction failed: {e}");
|
||||
@@ -1051,7 +1040,7 @@ impl ToolCall {
|
||||
}
|
||||
|
||||
async fn describe_mcp_tool(
|
||||
config: &GlobalConfig,
|
||||
ctx: &RequestContext,
|
||||
cmd_name: &str,
|
||||
json_data: Value,
|
||||
) -> Result<Value> {
|
||||
@@ -1061,18 +1050,19 @@ impl ToolCall {
|
||||
.ok_or_else(|| anyhow!("Missing 'tool' in arguments"))?
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Invalid 'tool' in arguments"))?;
|
||||
let registry_arc = {
|
||||
let cfg = config.read();
|
||||
cfg.mcp_registry
|
||||
.clone()
|
||||
.with_context(|| "MCP is not configured")?
|
||||
};
|
||||
|
||||
let result = registry_arc.describe(&server_id, tool).await?;
|
||||
let result = ctx
|
||||
.tool_scope
|
||||
.mcp_runtime
|
||||
.describe(&server_id, tool)
|
||||
.await?;
|
||||
Ok(serde_json::to_value(result)?)
|
||||
}
|
||||
|
||||
fn search_mcp_tools(config: &GlobalConfig, cmd_name: &str, json_data: &Value) -> Result<Value> {
|
||||
async fn search_mcp_tools(
|
||||
ctx: &RequestContext,
|
||||
cmd_name: &str,
|
||||
json_data: &Value,
|
||||
) -> Result<Value> {
|
||||
let server = cmd_name.replace(&format!("{MCP_SEARCH_META_FUNCTION_NAME_PREFIX}_"), "");
|
||||
let query = json_data
|
||||
.get("query")
|
||||
@@ -1085,15 +1075,12 @@ impl ToolCall {
|
||||
.unwrap_or_else(|| Value::from(8u64))
|
||||
.as_u64()
|
||||
.ok_or_else(|| anyhow!("Invalid 'top_k' in arguments"))? as usize;
|
||||
let registry_arc = {
|
||||
let cfg = config.read();
|
||||
cfg.mcp_registry
|
||||
.clone()
|
||||
.with_context(|| "MCP is not configured")?
|
||||
};
|
||||
|
||||
let catalog_items = registry_arc
|
||||
.search_tools_server(&server, query, top_k)
|
||||
let catalog_items = ctx
|
||||
.tool_scope
|
||||
.mcp_runtime
|
||||
.search(&server, query, top_k)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|it| serde_json::to_value(&it).unwrap_or_default())
|
||||
.collect();
|
||||
@@ -1101,7 +1088,7 @@ impl ToolCall {
|
||||
}
|
||||
|
||||
async fn invoke_mcp_tool(
|
||||
config: &GlobalConfig,
|
||||
ctx: &RequestContext,
|
||||
cmd_name: &str,
|
||||
json_data: &Value,
|
||||
) -> Result<Value> {
|
||||
@@ -1115,20 +1102,18 @@ impl ToolCall {
|
||||
.get("arguments")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| json!({}));
|
||||
let registry_arc = {
|
||||
let cfg = config.read();
|
||||
cfg.mcp_registry
|
||||
.clone()
|
||||
.with_context(|| "MCP is not configured")?
|
||||
};
|
||||
|
||||
let result = registry_arc.invoke(&server, tool, arguments).await?;
|
||||
let result = ctx
|
||||
.tool_scope
|
||||
.mcp_runtime
|
||||
.invoke(&server, tool, arguments)
|
||||
.await?;
|
||||
Ok(serde_json::to_value(result)?)
|
||||
}
|
||||
|
||||
fn extract_call_config_from_agent(
|
||||
&self,
|
||||
config: &GlobalConfig,
|
||||
functions: &Functions,
|
||||
agent: &Agent,
|
||||
) -> Result<CallConfig> {
|
||||
let function_name = self.name.clone();
|
||||
@@ -1151,13 +1136,13 @@ impl ToolCall {
|
||||
))
|
||||
}
|
||||
}
|
||||
None => self.extract_call_config_from_config(config),
|
||||
None => self.extract_call_config_from_ctx(functions),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_call_config_from_config(&self, config: &GlobalConfig) -> Result<CallConfig> {
|
||||
fn extract_call_config_from_ctx(&self, functions: &Functions) -> Result<CallConfig> {
|
||||
let function_name = self.name.clone();
|
||||
match config.read().functions.contains(&function_name) {
|
||||
match functions.contains(&function_name) {
|
||||
true => Ok((
|
||||
function_name.clone(),
|
||||
function_name,
|
||||
@@ -1179,12 +1164,12 @@ pub fn run_llm_function(
|
||||
let mut command_name = cmd_name.clone();
|
||||
if let Some(agent_name) = agent_name {
|
||||
command_name = cmd_args[0].clone();
|
||||
let dir = Config::agent_bin_dir(&agent_name);
|
||||
let dir = paths::agent_bin_dir(&agent_name);
|
||||
if dir.exists() {
|
||||
bin_dirs.push(dir);
|
||||
}
|
||||
} else {
|
||||
bin_dirs.push(Config::functions_bin_dir());
|
||||
bin_dirs.push(paths::functions_bin_dir());
|
||||
}
|
||||
let current_path = env::var("PATH").context("No PATH environment variable")?;
|
||||
let prepend_path = bin_dirs
|
||||
|
||||
+201
-131
@@ -1,12 +1,11 @@
|
||||
use super::{FunctionDeclaration, JsonSchema};
|
||||
use crate::client::{Model, ModelType, call_chat_completions};
|
||||
use crate::config::{Config, GlobalConfig, Input, Role, RoleLike};
|
||||
use crate::supervisor::escalation::EscalationQueue;
|
||||
use crate::config::{Agent, AppState, Input, RequestContext, Role, RoleLike};
|
||||
use crate::supervisor::mailbox::{Envelope, EnvelopePayload, Inbox};
|
||||
use crate::supervisor::{AgentExitStatus, AgentHandle, AgentResult};
|
||||
use crate::supervisor::{AgentExitStatus, AgentHandle, AgentResult, Supervisor};
|
||||
use crate::utils::{AbortSignal, create_abort_signal};
|
||||
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use chrono::Utc;
|
||||
use indexmap::IndexMap;
|
||||
use log::debug;
|
||||
@@ -300,7 +299,7 @@ pub fn teammate_function_declarations() -> Vec<FunctionDeclaration> {
|
||||
}
|
||||
|
||||
pub async fn handle_supervisor_tool(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
cmd_name: &str,
|
||||
args: &Value,
|
||||
) -> Result<Value> {
|
||||
@@ -309,42 +308,47 @@ pub async fn handle_supervisor_tool(
|
||||
.unwrap_or(cmd_name);
|
||||
|
||||
match action {
|
||||
"spawn" => handle_spawn(config, args).await,
|
||||
"check" => handle_check(config, args).await,
|
||||
"collect" => handle_collect(config, args).await,
|
||||
"list" => handle_list(config),
|
||||
"cancel" => handle_cancel(config, args),
|
||||
"send_message" => handle_send_message(config, args),
|
||||
"check_inbox" => handle_check_inbox(config),
|
||||
"task_create" => handle_task_create(config, args),
|
||||
"task_list" => handle_task_list(config),
|
||||
"task_complete" => handle_task_complete(config, args).await,
|
||||
"task_fail" => handle_task_fail(config, args),
|
||||
"reply_escalation" => handle_reply_escalation(config, args),
|
||||
"spawn" => handle_spawn(ctx, args).await,
|
||||
"check" => handle_check(ctx, args).await,
|
||||
"collect" => handle_collect(ctx, args).await,
|
||||
"list" => handle_list(ctx),
|
||||
"cancel" => handle_cancel(ctx, args),
|
||||
"send_message" => handle_send_message(ctx, args),
|
||||
"check_inbox" => handle_check_inbox(ctx),
|
||||
"task_create" => handle_task_create(ctx, args),
|
||||
"task_list" => handle_task_list(ctx),
|
||||
"task_complete" => handle_task_complete(ctx, args).await,
|
||||
"task_fail" => handle_task_fail(ctx, args),
|
||||
"reply_escalation" => handle_reply_escalation(ctx, args),
|
||||
_ => bail!("Unknown supervisor action: {action}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_child_agent(
|
||||
child_config: GlobalConfig,
|
||||
mut child_ctx: RequestContext,
|
||||
initial_input: Input,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String>> + Send>> {
|
||||
Box::pin(async move {
|
||||
let mut accumulated_output = String::new();
|
||||
let mut input = initial_input;
|
||||
let app = Arc::clone(&child_ctx.app.config);
|
||||
|
||||
loop {
|
||||
let client = input.create_client()?;
|
||||
child_config.write().before_chat_completion(&input)?;
|
||||
child_ctx.before_chat_completion(&input)?;
|
||||
|
||||
let (output, tool_results) =
|
||||
call_chat_completions(&input, false, false, client.as_ref(), abort_signal.clone())
|
||||
.await?;
|
||||
let (output, tool_results) = call_chat_completions(
|
||||
&input,
|
||||
false,
|
||||
false,
|
||||
client.as_ref(),
|
||||
&mut child_ctx,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
child_config
|
||||
.write()
|
||||
.after_chat_completion(&input, &output, &tool_results)?;
|
||||
child_ctx.after_chat_completion(app.as_ref(), &input, &output, &tool_results)?;
|
||||
|
||||
if !output.is_empty() {
|
||||
if !accumulated_output.is_empty() {
|
||||
@@ -360,7 +364,7 @@ fn run_child_agent(
|
||||
input = input.merge_tool_results(output, tool_results);
|
||||
}
|
||||
|
||||
if let Some(ref supervisor) = child_config.read().supervisor {
|
||||
if let Some(supervisor) = child_ctx.supervisor.clone() {
|
||||
supervisor.read().cancel_all();
|
||||
}
|
||||
|
||||
@@ -368,7 +372,58 @@ fn run_child_agent(
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_spawn(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
async fn populate_agent_mcp_runtime(ctx: &mut RequestContext, server_ids: &[String]) -> Result<()> {
|
||||
if !ctx.app.config.mcp_server_support {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let app = Arc::clone(&ctx.app);
|
||||
let server_specs = app
|
||||
.mcp_config
|
||||
.as_ref()
|
||||
.map(|mcp_config| {
|
||||
server_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
mcp_config
|
||||
.mcp_servers
|
||||
.get(id)
|
||||
.cloned()
|
||||
.map(|spec| (id.clone(), spec))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
for (id, spec) in server_specs {
|
||||
let handle = app
|
||||
.mcp_factory
|
||||
.acquire(&id, &spec, app.mcp_log_path.as_deref())
|
||||
.await?;
|
||||
ctx.tool_scope.mcp_runtime.insert(id, handle);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sync_agent_functions_to_ctx(ctx: &mut RequestContext) -> Result<()> {
|
||||
let server_names = ctx.tool_scope.mcp_runtime.server_names();
|
||||
let functions = {
|
||||
let agent = ctx
|
||||
.agent
|
||||
.as_mut()
|
||||
.with_context(|| "Agent should be initialized")?;
|
||||
if !server_names.is_empty() {
|
||||
agent.append_mcp_meta_functions(server_names);
|
||||
}
|
||||
agent.functions().clone()
|
||||
};
|
||||
|
||||
ctx.tool_scope.functions = functions;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_spawn(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let agent_name = args
|
||||
.get("agent")
|
||||
.and_then(Value::as_str)
|
||||
@@ -385,10 +440,10 @@ async fn handle_spawn(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
let agent_id = format!("agent_{agent_name}_{short_uuid}");
|
||||
|
||||
let (max_depth, current_depth) = {
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active; Agent spawning not enabled"))?;
|
||||
let sup = supervisor.read();
|
||||
if sup.active_count() >= sup.max_concurrent() {
|
||||
@@ -401,7 +456,7 @@ async fn handle_spawn(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
),
|
||||
}));
|
||||
}
|
||||
(sup.max_depth(), cfg.current_depth + 1)
|
||||
(sup.max_depth(), ctx.current_depth + 1)
|
||||
};
|
||||
|
||||
if current_depth > max_depth {
|
||||
@@ -413,37 +468,68 @@ async fn handle_spawn(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
|
||||
let child_inbox = Arc::new(Inbox::new());
|
||||
|
||||
{
|
||||
let mut cfg = config.write();
|
||||
if cfg.root_escalation_queue.is_none() {
|
||||
cfg.root_escalation_queue = Some(Arc::new(EscalationQueue::new()));
|
||||
}
|
||||
}
|
||||
|
||||
let child_config: GlobalConfig = {
|
||||
let mut child_cfg = config.read().clone();
|
||||
|
||||
child_cfg.parent_supervisor = child_cfg.supervisor.clone();
|
||||
child_cfg.agent = None;
|
||||
child_cfg.session = None;
|
||||
child_cfg.rag = None;
|
||||
child_cfg.supervisor = None;
|
||||
child_cfg.last_message = None;
|
||||
child_cfg.tool_call_tracker = None;
|
||||
|
||||
child_cfg.stream = false;
|
||||
child_cfg.save = false;
|
||||
child_cfg.current_depth = current_depth;
|
||||
child_cfg.inbox = Some(Arc::clone(&child_inbox));
|
||||
child_cfg.self_agent_id = Some(agent_id.clone());
|
||||
|
||||
Arc::new(RwLock::new(child_cfg))
|
||||
};
|
||||
ctx.ensure_root_escalation_queue();
|
||||
|
||||
let child_abort = create_abort_signal();
|
||||
Config::use_agent(&child_config, &agent_name, None, child_abort.clone()).await?;
|
||||
|
||||
let input = Input::from_str(&child_config, &prompt, None);
|
||||
if !ctx.app.config.function_calling_support {
|
||||
bail!("Please enable function calling support before using the agent.");
|
||||
}
|
||||
|
||||
let app_config = Arc::clone(&ctx.app.config);
|
||||
let current_model = ctx.current_model().clone();
|
||||
let info_flag = ctx.info_flag;
|
||||
let child_app_state = Arc::new(AppState {
|
||||
config: Arc::new(app_config.as_ref().clone()),
|
||||
vault: ctx.app.vault.clone(),
|
||||
mcp_factory: ctx.app.mcp_factory.clone(),
|
||||
rag_cache: ctx.app.rag_cache.clone(),
|
||||
mcp_config: ctx.app.mcp_config.clone(),
|
||||
mcp_log_path: ctx.app.mcp_log_path.clone(),
|
||||
});
|
||||
let agent = Agent::init(
|
||||
app_config.as_ref(),
|
||||
child_app_state.as_ref(),
|
||||
¤t_model,
|
||||
info_flag,
|
||||
&agent_name,
|
||||
child_abort.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let agent_mcp_servers = agent.mcp_server_names().to_vec();
|
||||
let session = agent.agent_session().map(|v| v.to_string());
|
||||
let should_init_supervisor = agent.can_spawn_agents();
|
||||
let max_concurrent = agent.max_concurrent_agents();
|
||||
let max_depth = agent.max_agent_depth();
|
||||
let mut child_ctx = RequestContext::new_for_child(
|
||||
Arc::clone(&child_app_state),
|
||||
ctx,
|
||||
current_depth,
|
||||
Arc::clone(&child_inbox),
|
||||
agent_id.clone(),
|
||||
);
|
||||
child_ctx.rag = agent.rag();
|
||||
child_ctx.agent = Some(agent);
|
||||
if should_init_supervisor {
|
||||
child_ctx.supervisor = Some(Arc::new(RwLock::new(Supervisor::new(
|
||||
max_concurrent,
|
||||
max_depth,
|
||||
))));
|
||||
}
|
||||
|
||||
if let Some(session) = session {
|
||||
child_ctx
|
||||
.use_session(app_config.as_ref(), Some(&session), child_abort.clone())
|
||||
.await?;
|
||||
sync_agent_functions_to_ctx(&mut child_ctx)?;
|
||||
} else {
|
||||
populate_agent_mcp_runtime(&mut child_ctx, &agent_mcp_servers).await?;
|
||||
sync_agent_functions_to_ctx(&mut child_ctx)?;
|
||||
child_ctx.init_agent_shared_variables()?;
|
||||
}
|
||||
|
||||
let input = Input::from_str(&child_ctx, &prompt, None);
|
||||
|
||||
debug!("Spawning child agent '{agent_name}' as '{agent_id}'");
|
||||
|
||||
@@ -452,7 +538,7 @@ async fn handle_spawn(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
let spawn_abort = child_abort.clone();
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
let result = run_child_agent(child_config, input, spawn_abort).await;
|
||||
let result = run_child_agent(child_ctx, input, spawn_abort).await;
|
||||
|
||||
match result {
|
||||
Ok(output) => Ok(AgentResult {
|
||||
@@ -479,15 +565,13 @@ async fn handle_spawn(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
join_handle,
|
||||
};
|
||||
|
||||
{
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
sup.register(handle)?;
|
||||
}
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
sup.register(handle)?;
|
||||
|
||||
Ok(json!({
|
||||
"status": "ok",
|
||||
@@ -497,24 +581,24 @@ async fn handle_spawn(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_check(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
async fn handle_check(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let id = args
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("'id' is required"))?;
|
||||
|
||||
let is_finished = {
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let sup = supervisor.read();
|
||||
sup.is_finished(id)
|
||||
};
|
||||
|
||||
match is_finished {
|
||||
Some(true) => handle_collect(config, args).await,
|
||||
Some(true) => handle_collect(ctx, args).await,
|
||||
Some(false) => Ok(json!({
|
||||
"status": "pending",
|
||||
"id": id,
|
||||
@@ -527,17 +611,17 @@ async fn handle_check(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_collect(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
async fn handle_collect(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let id = args
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("'id' is required"))?;
|
||||
|
||||
let handle = {
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
sup.take(id)
|
||||
@@ -551,7 +635,7 @@ async fn handle_collect(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
.map_err(|e| anyhow!("Agent task panicked: {e}"))?
|
||||
.map_err(|e| anyhow!("Agent failed: {e}"))?;
|
||||
|
||||
let output = summarize_output(config, &result.agent_name, &result.output).await?;
|
||||
let output = summarize_output(ctx, &result.agent_name, &result.output).await?;
|
||||
|
||||
Ok(json!({
|
||||
"status": "completed",
|
||||
@@ -568,11 +652,11 @@ async fn handle_collect(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_list(config: &GlobalConfig) -> Result<Value> {
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
fn handle_list(ctx: &mut RequestContext) -> Result<Value> {
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let sup = supervisor.read();
|
||||
|
||||
@@ -596,16 +680,16 @@ fn handle_list(config: &GlobalConfig) -> Result<Value> {
|
||||
}))
|
||||
}
|
||||
|
||||
fn handle_cancel(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
fn handle_cancel(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let id = args
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("'id' is required"))?;
|
||||
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
|
||||
@@ -624,7 +708,7 @@ fn handle_cancel(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_send_message(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
fn handle_send_message(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let id = args
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
@@ -634,24 +718,19 @@ fn handle_send_message(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("'message' is required"))?;
|
||||
|
||||
let cfg = config.read();
|
||||
|
||||
// Determine sender identity: self_agent_id (child), agent name (parent), or "parent"
|
||||
let sender = cfg
|
||||
let sender = ctx
|
||||
.self_agent_id
|
||||
.clone()
|
||||
.or_else(|| cfg.agent.as_ref().map(|a| a.name().to_string()))
|
||||
.or_else(|| ctx.agent.as_ref().map(|a| a.name().to_string()))
|
||||
.unwrap_or_else(|| "parent".to_string());
|
||||
|
||||
// Try local supervisor first (parent -> child routing)
|
||||
let inbox = cfg
|
||||
let inbox = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.and_then(|sup| sup.read().inbox(id).cloned());
|
||||
|
||||
// Fall back to parent_supervisor (sibling -> sibling routing)
|
||||
let inbox = inbox.or_else(|| {
|
||||
cfg.parent_supervisor
|
||||
ctx.parent_supervisor
|
||||
.as_ref()
|
||||
.and_then(|sup| sup.read().inbox(id).cloned())
|
||||
});
|
||||
@@ -679,9 +758,8 @@ fn handle_send_message(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_check_inbox(config: &GlobalConfig) -> Result<Value> {
|
||||
let cfg = config.read();
|
||||
match &cfg.inbox {
|
||||
fn handle_check_inbox(ctx: &mut RequestContext) -> Result<Value> {
|
||||
match ctx.inbox.as_ref() {
|
||||
Some(inbox) => {
|
||||
let messages: Vec<Value> = inbox
|
||||
.drain()
|
||||
@@ -707,7 +785,7 @@ fn handle_check_inbox(config: &GlobalConfig) -> Result<Value> {
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_reply_escalation(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
fn handle_reply_escalation(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let escalation_id = args
|
||||
.get("escalation_id")
|
||||
.and_then(Value::as_str)
|
||||
@@ -717,12 +795,10 @@ fn handle_reply_escalation(config: &GlobalConfig, args: &Value) -> Result<Value>
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("'reply' is required"))?;
|
||||
|
||||
let queue = {
|
||||
let cfg = config.read();
|
||||
cfg.root_escalation_queue
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("No escalation queue available"))?
|
||||
};
|
||||
let queue = ctx
|
||||
.escalation_queue
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("No escalation queue available"))?;
|
||||
|
||||
match queue.take(escalation_id) {
|
||||
Some(request) => {
|
||||
@@ -742,7 +818,7 @@ fn handle_reply_escalation(config: &GlobalConfig, args: &Value) -> Result<Value>
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_task_create(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
fn handle_task_create(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let subject = args
|
||||
.get("subject")
|
||||
.and_then(Value::as_str)
|
||||
@@ -768,10 +844,10 @@ fn handle_task_create(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
bail!("'prompt' is required when 'agent' is set");
|
||||
}
|
||||
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
|
||||
@@ -805,11 +881,11 @@ fn handle_task_create(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn handle_task_list(config: &GlobalConfig) -> Result<Value> {
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
fn handle_task_list(ctx: &mut RequestContext) -> Result<Value> {
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let sup = supervisor.read();
|
||||
|
||||
@@ -834,17 +910,17 @@ fn handle_task_list(config: &GlobalConfig) -> Result<Value> {
|
||||
Ok(json!({ "tasks": tasks }))
|
||||
}
|
||||
|
||||
async fn handle_task_complete(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
async fn handle_task_complete(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let task_id = args
|
||||
.get("task_id")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("'task_id' is required"))?;
|
||||
|
||||
let (newly_runnable, dispatchable) = {
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
|
||||
@@ -884,7 +960,7 @@ async fn handle_task_complete(config: &GlobalConfig, args: &Value) -> Result<Val
|
||||
"agent": agent,
|
||||
"prompt": prompt,
|
||||
});
|
||||
match handle_spawn(config, &spawn_args).await {
|
||||
match handle_spawn(ctx, &spawn_args).await {
|
||||
Ok(result) => {
|
||||
let agent_id = result
|
||||
.get("id")
|
||||
@@ -916,16 +992,16 @@ async fn handle_task_complete(config: &GlobalConfig, args: &Value) -> Result<Val
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn handle_task_fail(config: &GlobalConfig, args: &Value) -> Result<Value> {
|
||||
fn handle_task_fail(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let task_id = args
|
||||
.get("task_id")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("'task_id' is required"))?;
|
||||
|
||||
let cfg = config.read();
|
||||
let supervisor = cfg
|
||||
let supervisor = ctx
|
||||
.supervisor
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
|
||||
@@ -958,17 +1034,12 @@ Rules:
|
||||
- Use bullet points for multiple findings
|
||||
- If the output contains a final answer or conclusion, lead with it"#;
|
||||
|
||||
async fn summarize_output(config: &GlobalConfig, agent_name: &str, output: &str) -> Result<String> {
|
||||
let (threshold, summarization_model_id) = {
|
||||
let cfg = config.read();
|
||||
match cfg.agent.as_ref() {
|
||||
Some(agent) => (
|
||||
agent.summarization_threshold(),
|
||||
agent.summarization_model().map(|s| s.to_string()),
|
||||
),
|
||||
None => return Ok(output.to_string()),
|
||||
}
|
||||
async fn summarize_output(ctx: &RequestContext, agent_name: &str, output: &str) -> Result<String> {
|
||||
let Some(agent) = ctx.agent.as_ref() else {
|
||||
return Ok(output.to_string());
|
||||
};
|
||||
let threshold = agent.summarization_threshold();
|
||||
let summarization_model_id = agent.summarization_model().map(|s| s.to_string());
|
||||
|
||||
if output.len() < threshold {
|
||||
debug!(
|
||||
@@ -987,12 +1058,11 @@ async fn summarize_output(config: &GlobalConfig, agent_name: &str, output: &str)
|
||||
threshold
|
||||
);
|
||||
|
||||
let model = {
|
||||
let cfg = config.read();
|
||||
match summarization_model_id {
|
||||
Some(ref model_id) => Model::retrieve_model(&cfg, model_id, ModelType::Chat)?,
|
||||
None => cfg.current_model().clone(),
|
||||
let model = match summarization_model_id {
|
||||
Some(ref model_id) => {
|
||||
Model::retrieve_model(ctx.app.config.as_ref(), model_id, ModelType::Chat)?
|
||||
}
|
||||
None => ctx.current_model().clone(),
|
||||
};
|
||||
|
||||
let mut role = Role::new("summarizer", SUMMARIZATION_PROMPT);
|
||||
@@ -1002,7 +1072,7 @@ async fn summarize_output(config: &GlobalConfig, agent_name: &str, output: &str)
|
||||
"Summarize the following sub-agent output from '{}':\n\n{}",
|
||||
agent_name, output
|
||||
);
|
||||
let input = Input::from_str(config, &user_message, Some(role));
|
||||
let input = Input::from_str(ctx, &user_message, Some(role));
|
||||
|
||||
let summary = input.fetch_chat_text().await?;
|
||||
|
||||
|
||||
+21
-55
@@ -1,5 +1,5 @@
|
||||
use super::{FunctionDeclaration, JsonSchema};
|
||||
use crate::config::GlobalConfig;
|
||||
use crate::config::RequestContext;
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
use indexmap::IndexMap;
|
||||
@@ -89,38 +89,28 @@ pub fn todo_function_declarations() -> Vec<FunctionDeclaration> {
|
||||
]
|
||||
}
|
||||
|
||||
pub fn handle_todo_tool(config: &GlobalConfig, cmd_name: &str, args: &Value) -> Result<Value> {
|
||||
pub fn handle_todo_tool(ctx: &mut RequestContext, cmd_name: &str, args: &Value) -> Result<Value> {
|
||||
let action = cmd_name
|
||||
.strip_prefix(TODO_FUNCTION_PREFIX)
|
||||
.unwrap_or(cmd_name);
|
||||
|
||||
if ctx.agent.is_none() {
|
||||
bail!("No active agent");
|
||||
}
|
||||
|
||||
match action {
|
||||
"init" => {
|
||||
let goal = args.get("goal").and_then(Value::as_str).unwrap_or_default();
|
||||
let mut cfg = config.write();
|
||||
let agent = cfg.agent.as_mut();
|
||||
match agent {
|
||||
Some(agent) => {
|
||||
agent.init_todo_list(goal);
|
||||
Ok(json!({"status": "ok", "message": "Initialized new todo list"}))
|
||||
}
|
||||
None => bail!("No active agent"),
|
||||
}
|
||||
ctx.init_todo_list(goal);
|
||||
Ok(json!({"status": "ok", "message": "Initialized new todo list"}))
|
||||
}
|
||||
"add" => {
|
||||
let task = args.get("task").and_then(Value::as_str).unwrap_or_default();
|
||||
if task.is_empty() {
|
||||
return Ok(json!({"error": "task description is required"}));
|
||||
}
|
||||
let mut cfg = config.write();
|
||||
let agent = cfg.agent.as_mut();
|
||||
match agent {
|
||||
Some(agent) => {
|
||||
let id = agent.add_todo(task);
|
||||
Ok(json!({"status": "ok", "id": id}))
|
||||
}
|
||||
None => bail!("No active agent"),
|
||||
}
|
||||
let id = ctx.add_todo(task);
|
||||
Ok(json!({"status": "ok", "id": id}))
|
||||
}
|
||||
"done" => {
|
||||
let id = args
|
||||
@@ -132,50 +122,26 @@ pub fn handle_todo_tool(config: &GlobalConfig, cmd_name: &str, args: &Value) ->
|
||||
.map(|v| v as usize);
|
||||
match id {
|
||||
Some(id) => {
|
||||
let mut cfg = config.write();
|
||||
let agent = cfg.agent.as_mut();
|
||||
match agent {
|
||||
Some(agent) => {
|
||||
if agent.mark_todo_done(id) {
|
||||
Ok(
|
||||
json!({"status": "ok", "message": format!("Marked todo {id} as done")}),
|
||||
)
|
||||
} else {
|
||||
Ok(json!({"error": format!("Todo {id} not found")}))
|
||||
}
|
||||
}
|
||||
None => bail!("No active agent"),
|
||||
if ctx.mark_todo_done(id) {
|
||||
Ok(json!({"status": "ok", "message": format!("Marked todo {id} as done")}))
|
||||
} else {
|
||||
Ok(json!({"error": format!("Todo {id} not found")}))
|
||||
}
|
||||
}
|
||||
None => Ok(json!({"error": "id is required and must be a number"})),
|
||||
}
|
||||
}
|
||||
"list" => {
|
||||
let cfg = config.read();
|
||||
let agent = cfg.agent.as_ref();
|
||||
match agent {
|
||||
Some(agent) => {
|
||||
let list = agent.todo_list();
|
||||
if list.is_empty() {
|
||||
Ok(json!({"goal": "", "todos": []}))
|
||||
} else {
|
||||
Ok(serde_json::to_value(list)
|
||||
.unwrap_or(json!({"error": "serialization failed"})))
|
||||
}
|
||||
}
|
||||
None => bail!("No active agent"),
|
||||
let list = &ctx.todo_list;
|
||||
if list.is_empty() {
|
||||
Ok(json!({"goal": "", "todos": []}))
|
||||
} else {
|
||||
Ok(serde_json::to_value(list).unwrap_or(json!({"error": "serialization failed"})))
|
||||
}
|
||||
}
|
||||
"clear" => {
|
||||
let mut cfg = config.write();
|
||||
let agent = cfg.agent.as_mut();
|
||||
match agent {
|
||||
Some(agent) => {
|
||||
agent.clear_todo_list();
|
||||
Ok(json!({"status": "ok", "message": "Todo list cleared"}))
|
||||
}
|
||||
None => bail!("No active agent"),
|
||||
}
|
||||
ctx.clear_todo_list();
|
||||
Ok(json!({"status": "ok", "message": "Todo list cleared"}))
|
||||
}
|
||||
_ => bail!("Unknown todo action: {action}"),
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::{FunctionDeclaration, JsonSchema};
|
||||
use crate::config::GlobalConfig;
|
||||
use crate::config::RequestContext;
|
||||
use crate::supervisor::escalation::{EscalationRequest, new_escalation_id};
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
@@ -120,7 +120,7 @@ pub fn user_interaction_function_declarations() -> Vec<FunctionDeclaration> {
|
||||
}
|
||||
|
||||
pub async fn handle_user_tool(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
cmd_name: &str,
|
||||
args: &Value,
|
||||
) -> Result<Value> {
|
||||
@@ -128,12 +128,12 @@ pub async fn handle_user_tool(
|
||||
.strip_prefix(USER_FUNCTION_PREFIX)
|
||||
.unwrap_or(cmd_name);
|
||||
|
||||
let depth = config.read().current_depth;
|
||||
let depth = ctx.current_depth;
|
||||
|
||||
if depth == 0 {
|
||||
handle_direct(action, args)
|
||||
} else {
|
||||
handle_escalated(config, action, args).await
|
||||
handle_escalated(ctx, action, args).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ fn handle_direct_checkbox(args: &Value) -> Result<Value> {
|
||||
Ok(json!({ "answers": answers }))
|
||||
}
|
||||
|
||||
async fn handle_escalated(config: &GlobalConfig, action: &str, args: &Value) -> Result<Value> {
|
||||
async fn handle_escalated(ctx: &RequestContext, action: &str, args: &Value) -> Result<Value> {
|
||||
let question = args
|
||||
.get("question")
|
||||
.and_then(Value::as_str)
|
||||
@@ -212,28 +212,24 @@ async fn handle_escalated(config: &GlobalConfig, action: &str, args: &Value) ->
|
||||
.collect()
|
||||
});
|
||||
|
||||
let (from_agent_id, from_agent_name, root_queue, timeout_secs) = {
|
||||
let cfg = config.read();
|
||||
let agent_id = cfg
|
||||
.self_agent_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let agent_name = cfg
|
||||
.agent
|
||||
.as_ref()
|
||||
.map(|a| a.name().to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let queue = cfg
|
||||
.root_escalation_queue
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("No escalation queue available; cannot reach parent agent"))?;
|
||||
let timeout = cfg
|
||||
.agent
|
||||
.as_ref()
|
||||
.map(|a| a.escalation_timeout())
|
||||
.unwrap_or(DEFAULT_ESCALATION_TIMEOUT_SECS);
|
||||
(agent_id, agent_name, queue, timeout)
|
||||
};
|
||||
let from_agent_id = ctx
|
||||
.self_agent_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let from_agent_name = ctx
|
||||
.agent
|
||||
.as_ref()
|
||||
.map(|a| a.name().to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let root_queue = ctx
|
||||
.root_escalation_queue()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No escalation queue available; cannot reach parent agent"))?;
|
||||
let timeout_secs = ctx
|
||||
.agent
|
||||
.as_ref()
|
||||
.map(|a| a.escalation_timeout())
|
||||
.unwrap_or(DEFAULT_ESCALATION_TIMEOUT_SECS);
|
||||
|
||||
let escalation_id = new_escalation_id();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
+128
-105
@@ -15,19 +15,19 @@ mod vault;
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
use crate::cli::Cli;
|
||||
use crate::client::{
|
||||
ModelType, call_chat_completions, call_chat_completions_streaming, list_models, oauth,
|
||||
};
|
||||
use crate::config::paths;
|
||||
use crate::config::{
|
||||
Agent, CODE_ROLE, Config, EXPLAIN_SHELL_ROLE, GlobalConfig, Input, SHELL_ROLE,
|
||||
TEMP_SESSION_NAME, WorkingMode, ensure_parent_exists, list_agents, load_env_file,
|
||||
Agent, AppConfig, AppState, CODE_ROLE, Config, EXPLAIN_SHELL_ROLE, Input, RequestContext,
|
||||
SHELL_ROLE, TEMP_SESSION_NAME, WorkingMode, ensure_parent_exists, list_agents, load_env_file,
|
||||
macro_execute,
|
||||
};
|
||||
use crate::render::{prompt_theme, render_error};
|
||||
use crate::repl::Repl;
|
||||
use crate::utils::*;
|
||||
|
||||
use crate::cli::Cli;
|
||||
use crate::vault::Vault;
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use clap::{CommandFactory, Parser};
|
||||
@@ -40,9 +40,8 @@ use log4rs::append::file::FileAppender;
|
||||
use log4rs::config::{Appender, Logger, Root};
|
||||
use log4rs::encode::pattern::PatternEncoder;
|
||||
use oauth::OAuthProvider;
|
||||
use parking_lot::RwLock;
|
||||
use std::path::PathBuf;
|
||||
use std::{env, mem, process, sync::Arc};
|
||||
use std::{env, process, sync::Arc};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
@@ -96,50 +95,71 @@ async fn main() -> Result<()> {
|
||||
|
||||
let abort_signal = create_abort_signal();
|
||||
let start_mcp_servers = cli.agent.is_none() && cli.role.is_none();
|
||||
let config = Arc::new(RwLock::new(
|
||||
Config::init(
|
||||
working_mode,
|
||||
info_flag,
|
||||
start_mcp_servers,
|
||||
log_path,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?,
|
||||
));
|
||||
let cfg = Config::init(
|
||||
working_mode,
|
||||
info_flag,
|
||||
start_mcp_servers,
|
||||
log_path,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
let app_config: Arc<AppConfig> = Arc::new(cfg.to_app_config());
|
||||
let (mcp_config, mcp_log_path) = match &cfg.mcp_registry {
|
||||
Some(reg) => (reg.mcp_config().cloned(), reg.log_path().cloned()),
|
||||
None => (None, None),
|
||||
};
|
||||
let app_state: Arc<AppState> = Arc::new(AppState {
|
||||
config: app_config,
|
||||
vault: cfg.vault.clone(),
|
||||
mcp_factory: Default::default(),
|
||||
rag_cache: Default::default(),
|
||||
mcp_config,
|
||||
mcp_log_path,
|
||||
});
|
||||
let ctx = cfg.to_request_context(app_state);
|
||||
|
||||
{
|
||||
let cfg = config.read();
|
||||
if cfg.highlight {
|
||||
set_global_render_config(prompt_theme(cfg.render_options()?)?)
|
||||
let app = &*ctx.app.config;
|
||||
if app.highlight {
|
||||
set_global_render_config(prompt_theme(app.render_options()?)?)
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = run(config, cli, text, abort_signal).await {
|
||||
if let Err(err) = run(ctx, cli, text, abort_signal).await {
|
||||
render_error(err);
|
||||
process::exit(1);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn update_app_config(ctx: &mut RequestContext, update: impl FnOnce(&mut AppConfig)) {
|
||||
let mut app_config = (*ctx.app.config).clone();
|
||||
update(&mut app_config);
|
||||
|
||||
let mut app_state = (*ctx.app).clone();
|
||||
app_state.config = Arc::new(app_config);
|
||||
ctx.app = Arc::new(app_state);
|
||||
}
|
||||
|
||||
async fn run(
|
||||
config: GlobalConfig,
|
||||
mut ctx: RequestContext,
|
||||
cli: Cli,
|
||||
text: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<()> {
|
||||
if cli.sync_models {
|
||||
let url = config.read().sync_models_url();
|
||||
let url = ctx.app.config.sync_models_url();
|
||||
return Config::sync_models(&url, abort_signal.clone()).await;
|
||||
}
|
||||
|
||||
if cli.list_models {
|
||||
for model in list_models(&config.read(), ModelType::Chat) {
|
||||
for model in list_models(ctx.app.config.as_ref(), ModelType::Chat) {
|
||||
println!("{}", model.id());
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
if cli.list_roles {
|
||||
let roles = Config::list_roles(true).join("\n");
|
||||
let roles = paths::list_roles(true).join("\n");
|
||||
println!("{roles}");
|
||||
return Ok(());
|
||||
}
|
||||
@@ -149,24 +169,32 @@ async fn run(
|
||||
return Ok(());
|
||||
}
|
||||
if cli.list_rags {
|
||||
let rags = Config::list_rags().join("\n");
|
||||
let rags = paths::list_rags().join("\n");
|
||||
println!("{rags}");
|
||||
return Ok(());
|
||||
}
|
||||
if cli.list_macros {
|
||||
let macros = Config::list_macros().join("\n");
|
||||
let macros = paths::list_macros().join("\n");
|
||||
println!("{macros}");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if cli.dry_run {
|
||||
config.write().dry_run = true;
|
||||
update_app_config(&mut ctx, |app| app.dry_run = true);
|
||||
}
|
||||
|
||||
if let Some(agent) = &cli.agent {
|
||||
if cli.build_tools {
|
||||
info!("Building tools for agent '{agent}'...");
|
||||
Agent::init(&config, agent, abort_signal.clone()).await?;
|
||||
Agent::init(
|
||||
&ctx.app.config,
|
||||
&ctx.app,
|
||||
&ctx.model,
|
||||
ctx.info_flag,
|
||||
agent,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -175,37 +203,40 @@ async fn run(
|
||||
None => TEMP_SESSION_NAME,
|
||||
});
|
||||
if !cli.agent_variable.is_empty() {
|
||||
config.write().agent_variables = Some(
|
||||
ctx.agent_variables = Some(
|
||||
cli.agent_variable
|
||||
.chunks(2)
|
||||
.map(|v| (v[0].to_string(), v[1].to_string()))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
|
||||
let ret = Config::use_agent(&config, agent, session, abort_signal.clone()).await;
|
||||
config.write().agent_variables = None;
|
||||
ret?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.use_agent(app.as_ref(), agent, session, abort_signal.clone())
|
||||
.await?;
|
||||
} else {
|
||||
let app: Arc<AppConfig> = Arc::clone(&ctx.app.config);
|
||||
if let Some(prompt) = &cli.prompt {
|
||||
config.write().use_prompt(prompt)?;
|
||||
ctx.use_prompt(app.as_ref(), prompt)?;
|
||||
} else if let Some(name) = &cli.role {
|
||||
Config::use_role_safely(&config, name, abort_signal.clone()).await?;
|
||||
ctx.use_role(app.as_ref(), name, abort_signal.clone())
|
||||
.await?;
|
||||
} else if cli.execute {
|
||||
Config::use_role_safely(&config, SHELL_ROLE, abort_signal.clone()).await?;
|
||||
ctx.use_role(app.as_ref(), SHELL_ROLE, abort_signal.clone())
|
||||
.await?;
|
||||
} else if cli.code {
|
||||
Config::use_role_safely(&config, CODE_ROLE, abort_signal.clone()).await?;
|
||||
ctx.use_role(app.as_ref(), CODE_ROLE, abort_signal.clone())
|
||||
.await?;
|
||||
}
|
||||
if let Some(session) = &cli.session {
|
||||
Config::use_session_safely(
|
||||
&config,
|
||||
ctx.use_session(
|
||||
app.as_ref(),
|
||||
session.as_ref().map(|v| v.as_str()),
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
if let Some(rag) = &cli.rag {
|
||||
Config::use_rag(&config, Some(rag), abort_signal.clone()).await?;
|
||||
ctx.use_rag(Some(rag), abort_signal.clone()).await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,106 +245,96 @@ async fn run(
|
||||
}
|
||||
|
||||
if cli.list_sessions {
|
||||
let sessions = config.read().list_sessions().join("\n");
|
||||
let sessions = ctx.list_sessions().join("\n");
|
||||
println!("{sessions}");
|
||||
return Ok(());
|
||||
}
|
||||
if let Some(model_id) = &cli.model {
|
||||
config.write().set_model(model_id)?;
|
||||
let app: Arc<AppConfig> = Arc::clone(&ctx.app.config);
|
||||
ctx.set_model_on_role_like(app.as_ref(), model_id)?;
|
||||
}
|
||||
if cli.no_stream {
|
||||
config.write().stream = false;
|
||||
update_app_config(&mut ctx, |app| app.stream = false);
|
||||
}
|
||||
if cli.empty_session {
|
||||
config.write().empty_session()?;
|
||||
ctx.empty_session()?;
|
||||
}
|
||||
if cli.save_session {
|
||||
config.write().set_save_session_this_time()?;
|
||||
ctx.set_save_session_this_time()?;
|
||||
}
|
||||
if cli.info {
|
||||
let info = config.read().info()?;
|
||||
let app: Arc<AppConfig> = Arc::clone(&ctx.app.config);
|
||||
let info = ctx.info(app.as_ref())?;
|
||||
println!("{info}");
|
||||
return Ok(());
|
||||
}
|
||||
let is_repl = config.read().working_mode.is_repl();
|
||||
let is_repl = ctx.working_mode.is_repl();
|
||||
if cli.rebuild_rag {
|
||||
Config::rebuild_rag(&config, abort_signal.clone()).await?;
|
||||
ctx.rebuild_rag(abort_signal.clone()).await?;
|
||||
if is_repl {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
if let Some(name) = &cli.macro_name {
|
||||
macro_execute(&config, name, text.as_deref(), abort_signal.clone()).await?;
|
||||
macro_execute(&mut ctx, name, text.as_deref(), abort_signal.clone()).await?;
|
||||
return Ok(());
|
||||
}
|
||||
if cli.execute && !is_repl {
|
||||
let input = create_input(&config, text, &cli.file, abort_signal.clone()).await?;
|
||||
shell_execute(&config, &SHELL, input, abort_signal.clone()).await?;
|
||||
let input = create_input(&ctx, text, &cli.file, abort_signal.clone()).await?;
|
||||
shell_execute(&mut ctx, &SHELL, input, abort_signal.clone()).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
apply_prelude_safely(&config, abort_signal.clone()).await?;
|
||||
{
|
||||
let app: Arc<AppConfig> = Arc::clone(&ctx.app.config);
|
||||
ctx.apply_prelude(app.as_ref(), abort_signal.clone())
|
||||
.await?;
|
||||
}
|
||||
|
||||
match is_repl {
|
||||
false => {
|
||||
let mut input = create_input(&config, text, &cli.file, abort_signal.clone()).await?;
|
||||
let mut input = create_input(&ctx, text, &cli.file, abort_signal.clone()).await?;
|
||||
input.use_embeddings(abort_signal.clone()).await?;
|
||||
start_directive(&config, input, cli.code, abort_signal).await
|
||||
start_directive(&mut ctx, input, cli.code, abort_signal).await
|
||||
}
|
||||
true => {
|
||||
if !*IS_STDOUT_TERMINAL {
|
||||
bail!("No TTY for REPL")
|
||||
}
|
||||
start_interactive(&config).await
|
||||
start_interactive(ctx).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_prelude_safely(config: &RwLock<Config>, abort_signal: AbortSignal) -> Result<()> {
|
||||
let mut cfg = {
|
||||
let mut guard = config.write();
|
||||
mem::take(&mut *guard)
|
||||
};
|
||||
|
||||
cfg.apply_prelude(abort_signal.clone()).await?;
|
||||
|
||||
{
|
||||
let mut guard = config.write();
|
||||
*guard = cfg;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_recursion::async_recursion]
|
||||
async fn start_directive(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
input: Input,
|
||||
code_mode: bool,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<()> {
|
||||
let app: Arc<AppConfig> = Arc::clone(&ctx.app.config);
|
||||
let client = input.create_client()?;
|
||||
let extract_code = !*IS_STDOUT_TERMINAL && code_mode;
|
||||
config.write().before_chat_completion(&input)?;
|
||||
ctx.before_chat_completion(&input)?;
|
||||
let (output, tool_results) = if !input.stream() || extract_code {
|
||||
call_chat_completions(
|
||||
&input,
|
||||
true,
|
||||
extract_code,
|
||||
client.as_ref(),
|
||||
ctx,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
call_chat_completions_streaming(&input, client.as_ref(), abort_signal.clone()).await?
|
||||
call_chat_completions_streaming(&input, client.as_ref(), ctx, abort_signal.clone()).await?
|
||||
};
|
||||
config
|
||||
.write()
|
||||
.after_chat_completion(&input, &output, &tool_results)?;
|
||||
ctx.after_chat_completion(app.as_ref(), &input, &output, &tool_results)?;
|
||||
|
||||
if !tool_results.is_empty() {
|
||||
start_directive(
|
||||
config,
|
||||
ctx,
|
||||
input.merge_tool_results(output, tool_results),
|
||||
code_mode,
|
||||
abort_signal,
|
||||
@@ -321,35 +342,41 @@ async fn start_directive(
|
||||
.await?;
|
||||
}
|
||||
|
||||
config.write().exit_session()?;
|
||||
ctx.exit_session()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_interactive(config: &GlobalConfig) -> Result<()> {
|
||||
let mut repl: Repl = Repl::init(config)?;
|
||||
async fn start_interactive(ctx: RequestContext) -> Result<()> {
|
||||
let mut repl: Repl = Repl::init(ctx)?;
|
||||
repl.run().await
|
||||
}
|
||||
|
||||
#[async_recursion::async_recursion]
|
||||
async fn shell_execute(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
shell: &Shell,
|
||||
mut input: Input,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<()> {
|
||||
let app: Arc<AppConfig> = Arc::clone(&ctx.app.config);
|
||||
let client = input.create_client()?;
|
||||
config.write().before_chat_completion(&input)?;
|
||||
let (eval_str, _) =
|
||||
call_chat_completions(&input, false, true, client.as_ref(), abort_signal.clone()).await?;
|
||||
ctx.before_chat_completion(&input)?;
|
||||
let (eval_str, _) = call_chat_completions(
|
||||
&input,
|
||||
false,
|
||||
true,
|
||||
client.as_ref(),
|
||||
ctx,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
config
|
||||
.write()
|
||||
.after_chat_completion(&input, &eval_str, &[])?;
|
||||
ctx.after_chat_completion(app.as_ref(), &input, &eval_str, &[])?;
|
||||
if eval_str.is_empty() {
|
||||
bail!("No command generated");
|
||||
}
|
||||
if config.read().dry_run {
|
||||
config.read().print_markdown(&eval_str)?;
|
||||
if app.dry_run {
|
||||
app.print_markdown(&eval_str)?;
|
||||
return Ok(());
|
||||
}
|
||||
if *IS_STDOUT_TERMINAL {
|
||||
@@ -370,7 +397,7 @@ async fn shell_execute(
|
||||
'e' => {
|
||||
debug!("{} {:?}", shell.cmd, &[&shell.arg, &eval_str]);
|
||||
let code = run_command(&shell.cmd, &[&shell.arg, &eval_str], None)?;
|
||||
if code == 0 && config.read().save_shell_history {
|
||||
if code == 0 && app.save_shell_history {
|
||||
let _ = append_to_shell_history(&shell.name, &eval_str, code);
|
||||
}
|
||||
process::exit(code);
|
||||
@@ -379,15 +406,16 @@ async fn shell_execute(
|
||||
let revision = Text::new("Enter your revision:").prompt()?;
|
||||
let text = format!("{}\n{revision}", input.text());
|
||||
input.set_text(text);
|
||||
return shell_execute(config, shell, input, abort_signal.clone()).await;
|
||||
return shell_execute(ctx, shell, input, abort_signal.clone()).await;
|
||||
}
|
||||
'd' => {
|
||||
let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?;
|
||||
let input = Input::from_str(config, &eval_str, Some(role));
|
||||
let role = ctx.retrieve_role(app.as_ref(), EXPLAIN_SHELL_ROLE)?;
|
||||
let input = Input::from_str(ctx, &eval_str, Some(role));
|
||||
if input.stream() {
|
||||
call_chat_completions_streaming(
|
||||
&input,
|
||||
client.as_ref(),
|
||||
ctx,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -397,6 +425,7 @@ async fn shell_execute(
|
||||
true,
|
||||
false,
|
||||
client.as_ref(),
|
||||
ctx,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -419,22 +448,16 @@ async fn shell_execute(
|
||||
}
|
||||
|
||||
async fn create_input(
|
||||
config: &GlobalConfig,
|
||||
ctx: &RequestContext,
|
||||
text: Option<String>,
|
||||
file: &[String],
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Input> {
|
||||
let text = text.unwrap_or_default();
|
||||
let input = if file.is_empty() {
|
||||
Input::from_str(config, &text.unwrap_or_default(), None)
|
||||
Input::from_str(ctx, &text, None)
|
||||
} else {
|
||||
Input::from_files_with_spinner(
|
||||
config,
|
||||
&text.unwrap_or_default(),
|
||||
file.to_vec(),
|
||||
None,
|
||||
abort_signal,
|
||||
)
|
||||
.await?
|
||||
Input::from_files_with_spinner(ctx, &text, file.to_vec(), None, abort_signal).await?
|
||||
};
|
||||
if input.is_empty() {
|
||||
bail!("No input");
|
||||
@@ -443,7 +466,7 @@ async fn create_input(
|
||||
}
|
||||
|
||||
fn setup_logger() -> Result<Option<PathBuf>> {
|
||||
let (log_level, log_path) = Config::log_config()?;
|
||||
let (log_level, log_path) = paths::log_config()?;
|
||||
if log_level == LevelFilter::Off {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
+73
-208
@@ -1,21 +1,17 @@
|
||||
use crate::config::Config;
|
||||
use crate::config::paths;
|
||||
use crate::utils::{AbortSignal, abortable_run_with_spinner};
|
||||
use crate::vault::interpolate_secrets;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use bm25::{Document, Language, SearchEngine, SearchEngineBuilder};
|
||||
use futures_util::future::BoxFuture;
|
||||
use futures_util::{StreamExt, TryStreamExt, stream};
|
||||
use indoc::formatdoc;
|
||||
use rmcp::model::{CallToolRequestParams, CallToolResult};
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::transport::TokioChildProcess;
|
||||
use rmcp::{RoleClient, ServiceExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs::OpenOptions;
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use tokio::process::Command;
|
||||
@@ -24,7 +20,7 @@ pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
|
||||
pub const MCP_SEARCH_META_FUNCTION_NAME_PREFIX: &str = "mcp_search";
|
||||
pub const MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX: &str = "mcp_describe";
|
||||
|
||||
type ConnectedServer = RunningService<RoleClient, ()>;
|
||||
pub type ConnectedServer = RunningService<RoleClient, ()>;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize)]
|
||||
pub struct CatalogItem {
|
||||
@@ -35,49 +31,34 @@ pub struct CatalogItem {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ServerCatalog {
|
||||
engine: SearchEngine<String>,
|
||||
items: HashMap<String, CatalogItem>,
|
||||
}
|
||||
|
||||
impl ServerCatalog {
|
||||
pub fn build_bm25(items: &HashMap<String, CatalogItem>) -> SearchEngine<String> {
|
||||
let docs = items.values().map(|it| {
|
||||
let contents = format!("{}\n{}\nserver:{}", it.name, it.description, it.server);
|
||||
Document {
|
||||
id: it.name.clone(),
|
||||
contents,
|
||||
}
|
||||
});
|
||||
SearchEngineBuilder::<String>::with_documents(Language::English, docs).build()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for ServerCatalog {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
engine: Self::build_bm25(&self.items),
|
||||
items: self.items.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct McpServersConfig {
|
||||
pub(crate) struct McpServersConfig {
|
||||
#[serde(rename = "mcpServers")]
|
||||
mcp_servers: HashMap<String, McpServer>,
|
||||
pub mcp_servers: HashMap<String, McpServer>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct McpServer {
|
||||
command: String,
|
||||
args: Option<Vec<String>>,
|
||||
env: Option<HashMap<String, JsonField>>,
|
||||
cwd: Option<String>,
|
||||
pub(crate) struct McpServer {
|
||||
pub command: String,
|
||||
pub args: Option<Vec<String>>,
|
||||
pub env: Option<HashMap<String, JsonField>>,
|
||||
pub cwd: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum JsonField {
|
||||
pub(crate) enum JsonField {
|
||||
Str(String),
|
||||
Bool(bool),
|
||||
Int(i64),
|
||||
@@ -103,25 +84,25 @@ impl McpRegistry {
|
||||
log_path,
|
||||
..Default::default()
|
||||
};
|
||||
if !Config::mcp_config_file().try_exists().with_context(|| {
|
||||
if !paths::mcp_config_file().try_exists().with_context(|| {
|
||||
format!(
|
||||
"Failed to check MCP config file at {}",
|
||||
Config::mcp_config_file().display()
|
||||
paths::mcp_config_file().display()
|
||||
)
|
||||
})? {
|
||||
debug!(
|
||||
"MCP config file does not exist at {}, skipping MCP initialization",
|
||||
Config::mcp_config_file().display()
|
||||
paths::mcp_config_file().display()
|
||||
);
|
||||
return Ok(registry);
|
||||
}
|
||||
let err = || {
|
||||
format!(
|
||||
"Failed to load MCP config file at {}",
|
||||
Config::mcp_config_file().display()
|
||||
paths::mcp_config_file().display()
|
||||
)
|
||||
};
|
||||
let content = tokio::fs::read_to_string(Config::mcp_config_file())
|
||||
let content = tokio::fs::read_to_string(paths::mcp_config_file())
|
||||
.await
|
||||
.with_context(err)?;
|
||||
|
||||
@@ -157,34 +138,6 @@ impl McpRegistry {
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
pub async fn reinit(
|
||||
mut registry: McpRegistry,
|
||||
enabled_mcp_servers: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
debug!("Reinitializing MCP registry");
|
||||
|
||||
let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone());
|
||||
let desired_set: HashSet<String> = desired_ids.iter().cloned().collect();
|
||||
|
||||
debug!("Stopping unused MCP servers");
|
||||
abortable_run_with_spinner(
|
||||
registry.stop_unused_servers(&desired_set),
|
||||
"Stopping unused MCP servers",
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
abortable_run_with_spinner(
|
||||
registry.start_select_mcp_servers(enabled_mcp_servers),
|
||||
"Loading MCP servers",
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
async fn start_select_mcp_servers(
|
||||
&mut self,
|
||||
enabled_mcp_servers: Option<String>,
|
||||
@@ -229,48 +182,14 @@ impl McpRegistry {
|
||||
&self,
|
||||
id: String,
|
||||
) -> Result<(String, Arc<ConnectedServer>, ServerCatalog)> {
|
||||
let server = self
|
||||
let spec = self
|
||||
.config
|
||||
.as_ref()
|
||||
.and_then(|c| c.mcp_servers.get(&id))
|
||||
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
||||
let mut cmd = Command::new(&server.command);
|
||||
if let Some(args) = &server.args {
|
||||
cmd.args(args);
|
||||
}
|
||||
if let Some(env) = &server.env {
|
||||
let env: HashMap<String, String> = env
|
||||
.iter()
|
||||
.map(|(k, v)| match v {
|
||||
JsonField::Str(s) => (k.clone(), s.clone()),
|
||||
JsonField::Bool(b) => (k.clone(), b.to_string()),
|
||||
JsonField::Int(i) => (k.clone(), i.to_string()),
|
||||
})
|
||||
.collect();
|
||||
cmd.envs(env);
|
||||
}
|
||||
if let Some(cwd) = &server.cwd {
|
||||
cmd.current_dir(cwd);
|
||||
}
|
||||
|
||||
let transport = if let Some(log_path) = self.log_path.as_ref() {
|
||||
cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
|
||||
let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?;
|
||||
|
||||
let log_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(log_path)?;
|
||||
let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
|
||||
transport
|
||||
} else {
|
||||
TokioChildProcess::new(cmd)?
|
||||
};
|
||||
|
||||
let service = Arc::new(
|
||||
().serve(transport)
|
||||
.await
|
||||
.with_context(|| format!("Failed to start MCP server: {}", &server.command))?,
|
||||
);
|
||||
let tools = service.list_tools(None).await?;
|
||||
debug!("Available tools for MCP server {id}: {tools:?}");
|
||||
|
||||
@@ -290,10 +209,7 @@ impl McpRegistry {
|
||||
items_map.insert(it.name.clone(), it);
|
||||
});
|
||||
|
||||
let catalog = ServerCatalog {
|
||||
engine: ServerCatalog::build_bm25(&items_map),
|
||||
items: items_map,
|
||||
};
|
||||
let catalog = ServerCatalog { items: items_map };
|
||||
|
||||
info!("Started MCP server: {id}");
|
||||
|
||||
@@ -321,118 +237,67 @@ impl McpRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet<String>) -> Result<()> {
|
||||
let mut ids_to_remove = Vec::new();
|
||||
for (id, _) in self.servers.iter() {
|
||||
if !keep_ids.contains(id) {
|
||||
ids_to_remove.push(id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for id in ids_to_remove {
|
||||
if let Some(server) = self.servers.remove(&id) {
|
||||
match Arc::try_unwrap(server) {
|
||||
Ok(server_inner) => {
|
||||
server_inner
|
||||
.cancel()
|
||||
.await
|
||||
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
||||
info!("Stopped MCP server: {id}");
|
||||
}
|
||||
Err(_) => {
|
||||
info!("Detaching from MCP server: {id} (still in use)");
|
||||
}
|
||||
}
|
||||
self.catalogs.remove(&id);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
pub fn running_servers(&self) -> &HashMap<String, Arc<ConnectedServer>> {
|
||||
&self.servers
|
||||
}
|
||||
|
||||
pub fn list_started_servers(&self) -> Vec<String> {
|
||||
self.servers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn list_configured_servers(&self) -> Vec<String> {
|
||||
if let Some(config) = &self.config {
|
||||
config.mcp_servers.keys().cloned().collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn search_tools_server(&self, server: &str, query: &str, top_k: usize) -> Vec<CatalogItem> {
|
||||
let Some(catalog) = self.catalogs.get(server) else {
|
||||
return vec![];
|
||||
};
|
||||
let engine = &catalog.engine;
|
||||
let raw = engine.search(query, top_k.min(20));
|
||||
|
||||
raw.into_iter()
|
||||
.filter_map(|r| catalog.items.get(&r.document.id))
|
||||
.take(top_k)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn describe(&self, server_id: &str, tool: &str) -> Result<Value> {
|
||||
let server = self
|
||||
.servers
|
||||
.iter()
|
||||
.filter(|(id, _)| &server_id == id)
|
||||
.map(|(_, s)| s.clone())
|
||||
.next()
|
||||
.ok_or(anyhow!("{server_id} MCP server not found in config"))?;
|
||||
|
||||
let tool_schema = server
|
||||
.list_tools(None)
|
||||
.await?
|
||||
.tools
|
||||
.into_iter()
|
||||
.find(|it| it.name == tool)
|
||||
.ok_or(anyhow!(
|
||||
"{tool} not found in {server_id} MCP server catalog"
|
||||
))?
|
||||
.input_schema;
|
||||
Ok(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tool": {
|
||||
"type": "string",
|
||||
},
|
||||
"arguments": tool_schema
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn invoke(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Value,
|
||||
) -> BoxFuture<'static, Result<CallToolResult>> {
|
||||
let server = self
|
||||
.servers
|
||||
.get(server)
|
||||
.cloned()
|
||||
.with_context(|| format!("Invoked MCP server does not exist: {server}"));
|
||||
|
||||
let tool = tool.to_owned();
|
||||
Box::pin(async move {
|
||||
let server = server?;
|
||||
let call_tool_request = CallToolRequestParams {
|
||||
name: Cow::Owned(tool.to_owned()),
|
||||
arguments: arguments.as_object().cloned(),
|
||||
meta: None,
|
||||
task: None,
|
||||
};
|
||||
|
||||
let result = server.call_tool(call_tool_request).await?;
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
|
||||
pub fn mcp_config(&self) -> Option<&McpServersConfig> {
|
||||
self.config.as_ref()
|
||||
}
|
||||
|
||||
pub fn log_path(&self) -> Option<&PathBuf> {
|
||||
self.log_path.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn spawn_mcp_server(
|
||||
spec: &McpServer,
|
||||
log_path: Option<&Path>,
|
||||
) -> Result<Arc<ConnectedServer>> {
|
||||
let mut cmd = Command::new(&spec.command);
|
||||
if let Some(args) = &spec.args {
|
||||
cmd.args(args);
|
||||
}
|
||||
if let Some(env) = &spec.env {
|
||||
let env: HashMap<String, String> = env
|
||||
.iter()
|
||||
.map(|(k, v)| match v {
|
||||
JsonField::Str(s) => (k.clone(), s.clone()),
|
||||
JsonField::Bool(b) => (k.clone(), b.to_string()),
|
||||
JsonField::Int(i) => (k.clone(), i.to_string()),
|
||||
})
|
||||
.collect();
|
||||
cmd.envs(env);
|
||||
}
|
||||
if let Some(cwd) = &spec.cwd {
|
||||
cmd.current_dir(cwd);
|
||||
}
|
||||
|
||||
let transport = if let Some(log_path) = log_path {
|
||||
cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
|
||||
|
||||
let log_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(log_path)?;
|
||||
let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
|
||||
transport
|
||||
} else {
|
||||
TokioChildProcess::new(cmd)?
|
||||
};
|
||||
|
||||
let service = Arc::new(
|
||||
().serve(transport)
|
||||
.await
|
||||
.with_context(|| format!("Failed to start MCP server: {}", &spec.command))?,
|
||||
);
|
||||
Ok(service)
|
||||
}
|
||||
|
||||
+40
-33
@@ -15,11 +15,13 @@ use inquire::{Confirm, Select, Text, required, validator::Validation};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::{collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, time::Duration};
|
||||
use std::{
|
||||
collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, sync::Arc, time::Duration,
|
||||
};
|
||||
use tokio::time::sleep;
|
||||
|
||||
pub struct Rag {
|
||||
config: GlobalConfig,
|
||||
app_config: Arc<AppConfig>,
|
||||
name: String,
|
||||
path: String,
|
||||
embedding_model: Model,
|
||||
@@ -43,7 +45,7 @@ impl Debug for Rag {
|
||||
impl Clone for Rag {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
config: self.config.clone(),
|
||||
app_config: self.app_config.clone(),
|
||||
name: self.name.clone(),
|
||||
path: self.path.clone(),
|
||||
embedding_model: self.embedding_model.clone(),
|
||||
@@ -56,8 +58,12 @@ impl Clone for Rag {
|
||||
}
|
||||
|
||||
impl Rag {
|
||||
fn create_embeddings_client(&self, model: Model) -> Result<Box<dyn Client>> {
|
||||
init_client(&self.app_config, model)
|
||||
}
|
||||
|
||||
pub async fn init(
|
||||
config: &GlobalConfig,
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
save_path: &Path,
|
||||
doc_paths: &[String],
|
||||
@@ -67,11 +73,9 @@ impl Rag {
|
||||
bail!("Failed to init rag in non-interactive mode");
|
||||
}
|
||||
println!("⚙ Initializing RAG...");
|
||||
let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(config)?;
|
||||
let (reranker_model, top_k) = {
|
||||
let config = config.read();
|
||||
(config.rag_reranker_model.clone(), config.rag_top_k)
|
||||
};
|
||||
let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(app)?;
|
||||
let reranker_model = app.rag_reranker_model.clone();
|
||||
let top_k = app.rag_top_k;
|
||||
let data = RagData::new(
|
||||
embedding_model.id(),
|
||||
chunk_size,
|
||||
@@ -80,12 +84,12 @@ impl Rag {
|
||||
top_k,
|
||||
embedding_model.max_batch_size(),
|
||||
);
|
||||
let mut rag = Self::create(config, name, save_path, data)?;
|
||||
let mut rag = Self::create(app, name, save_path, data)?;
|
||||
let mut paths = doc_paths.to_vec();
|
||||
if paths.is_empty() {
|
||||
paths = add_documents()?;
|
||||
};
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let loaders = app.document_loaders.clone();
|
||||
let (spinner, spinner_rx) = Spinner::create("");
|
||||
abortable_run_with_spinner_rx(
|
||||
rag.sync_documents(&paths, true, loaders, Some(spinner)),
|
||||
@@ -99,20 +103,29 @@ impl Rag {
|
||||
Ok(rag)
|
||||
}
|
||||
|
||||
pub fn load(config: &GlobalConfig, name: &str, path: &Path) -> Result<Self> {
|
||||
pub fn load(
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
path: &Path,
|
||||
) -> Result<Self> {
|
||||
let err = || format!("Failed to load rag '{name}' at '{}'", path.display());
|
||||
let content = fs::read_to_string(path).with_context(err)?;
|
||||
let data: RagData = serde_yaml::from_str(&content).with_context(err)?;
|
||||
Self::create(config, name, path, data)
|
||||
Self::create(app, name, path, data)
|
||||
}
|
||||
|
||||
pub fn create(config: &GlobalConfig, name: &str, path: &Path, data: RagData) -> Result<Self> {
|
||||
pub fn create(
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
path: &Path,
|
||||
data: RagData,
|
||||
) -> Result<Self> {
|
||||
let hnsw = data.build_hnsw();
|
||||
let bm25 = data.build_bm25();
|
||||
let embedding_model =
|
||||
Model::retrieve_model(&config.read(), &data.embedding_model, ModelType::Embedding)?;
|
||||
Model::retrieve_model(app, &data.embedding_model, ModelType::Embedding)?;
|
||||
let rag = Rag {
|
||||
config: config.clone(),
|
||||
app_config: Arc::new(app.clone()),
|
||||
name: name.to_string(),
|
||||
path: path.display().to_string(),
|
||||
data,
|
||||
@@ -132,10 +145,10 @@ impl Rag {
|
||||
&mut self,
|
||||
document_paths: &[String],
|
||||
refresh: bool,
|
||||
config: &GlobalConfig,
|
||||
app: &AppConfig,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<()> {
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let loaders = app.document_loaders.clone();
|
||||
let (spinner, spinner_rx) = Spinner::create("");
|
||||
abortable_run_with_spinner_rx(
|
||||
self.sync_documents(document_paths, refresh, loaders, Some(spinner)),
|
||||
@@ -149,22 +162,17 @@ impl Rag {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn create_config(config: &GlobalConfig) -> Result<(Model, usize, usize)> {
|
||||
let (embedding_model_id, chunk_size, chunk_overlap) = {
|
||||
let config = config.read();
|
||||
(
|
||||
config.rag_embedding_model.clone(),
|
||||
config.rag_chunk_size,
|
||||
config.rag_chunk_overlap,
|
||||
)
|
||||
};
|
||||
pub fn create_config(app: &AppConfig) -> Result<(Model, usize, usize)> {
|
||||
let embedding_model_id = app.rag_embedding_model.clone();
|
||||
let chunk_size = app.rag_chunk_size;
|
||||
let chunk_overlap = app.rag_chunk_overlap;
|
||||
let embedding_model_id = match embedding_model_id {
|
||||
Some(value) => {
|
||||
println!("Select embedding model: {value}");
|
||||
value
|
||||
}
|
||||
None => {
|
||||
let models = list_models(&config.read(), ModelType::Embedding);
|
||||
let models = list_models(app, ModelType::Embedding);
|
||||
if models.is_empty() {
|
||||
bail!("No available embedding model");
|
||||
}
|
||||
@@ -172,7 +180,7 @@ impl Rag {
|
||||
}
|
||||
};
|
||||
let embedding_model =
|
||||
Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;
|
||||
Model::retrieve_model(app, &embedding_model_id, ModelType::Embedding)?;
|
||||
|
||||
let chunk_size = match chunk_size {
|
||||
Some(value) => {
|
||||
@@ -560,9 +568,8 @@ impl Rag {
|
||||
|
||||
let ids = match rerank_model {
|
||||
Some(model_id) => {
|
||||
let model =
|
||||
Model::retrieve_model(&self.config.read(), model_id, ModelType::Reranker)?;
|
||||
let client = init_client(&self.config, Some(model))?;
|
||||
let model = Model::retrieve_model(&self.app_config, model_id, ModelType::Reranker)?;
|
||||
let client = self.create_embeddings_client(model)?;
|
||||
let ids: IndexSet<DocumentId> = [vector_search_ids, keyword_search_ids]
|
||||
.concat()
|
||||
.into_iter()
|
||||
@@ -665,7 +672,7 @@ impl Rag {
|
||||
data: EmbeddingsData,
|
||||
spinner: Option<Spinner>,
|
||||
) -> Result<EmbeddingsOutput> {
|
||||
let embedding_client = init_client(&self.config, Some(self.embedding_model.clone()))?;
|
||||
let embedding_client = self.create_embeddings_client(self.embedding_model.clone())?;
|
||||
let EmbeddingsData { texts, query } = data;
|
||||
let batch_size = self
|
||||
.data
|
||||
|
||||
+4
-4
@@ -8,18 +8,18 @@ pub use self::markdown::{MarkdownRender, RenderOptions};
|
||||
use self::stream::{markdown_stream, raw_stream};
|
||||
|
||||
use crate::utils::{AbortSignal, IS_STDOUT_TERMINAL, error_text, pretty_error};
|
||||
use crate::{client::SseEvent, config::GlobalConfig};
|
||||
use crate::{client::SseEvent, config::AppConfig};
|
||||
|
||||
use anyhow::Result;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
|
||||
pub async fn render_stream(
|
||||
rx: UnboundedReceiver<SseEvent>,
|
||||
config: &GlobalConfig,
|
||||
app: &AppConfig,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<()> {
|
||||
let ret = if *IS_STDOUT_TERMINAL && config.read().highlight {
|
||||
let render_options = config.read().render_options()?;
|
||||
let ret = if *IS_STDOUT_TERMINAL && app.highlight {
|
||||
let render_options = app.render_options()?;
|
||||
let mut render = MarkdownRender::init(render_options)?;
|
||||
markdown_stream(rx, &mut render, &abort_signal).await
|
||||
} else {
|
||||
|
||||
+14
-15
@@ -1,9 +1,11 @@
|
||||
use super::{REPL_COMMANDS, ReplCommand};
|
||||
|
||||
use crate::{config::GlobalConfig, utils::fuzzy_filter};
|
||||
use crate::{config::RequestContext, utils::fuzzy_filter};
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use reedline::{Completer, Span, Suggestion};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
impl Completer for ReplCompleter {
|
||||
fn complete(&mut self, line: &str, pos: usize) -> Vec<Suggestion> {
|
||||
@@ -27,7 +29,8 @@ impl Completer for ReplCompleter {
|
||||
return suggestions;
|
||||
}
|
||||
|
||||
let state = self.config.read().state();
|
||||
let ctx = self.ctx.read();
|
||||
let state = ctx.state();
|
||||
|
||||
let command_filter = parts
|
||||
.iter()
|
||||
@@ -49,16 +52,12 @@ impl Completer for ReplCompleter {
|
||||
let span = Span::new(parts[parts_len - 1].1, pos);
|
||||
let args_line = &line[parts[1].1..];
|
||||
let args: Vec<&str> = parts.iter().skip(1).map(|(v, _)| *v).collect();
|
||||
suggestions.extend(
|
||||
self.config
|
||||
.read()
|
||||
.repl_complete(cmd, &args, args_line)
|
||||
.iter()
|
||||
.map(|(value, description)| {
|
||||
let description = description.as_deref().unwrap_or_default();
|
||||
create_suggestion(value, description, span)
|
||||
}),
|
||||
)
|
||||
suggestions.extend(ctx.repl_complete(cmd, &args, args_line).iter().map(
|
||||
|(value, description)| {
|
||||
let description = description.as_deref().unwrap_or_default();
|
||||
create_suggestion(value, description, span)
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
if suggestions.is_empty() {
|
||||
@@ -80,13 +79,13 @@ impl Completer for ReplCompleter {
|
||||
}
|
||||
|
||||
pub struct ReplCompleter {
|
||||
config: GlobalConfig,
|
||||
ctx: Arc<RwLock<RequestContext>>,
|
||||
commands: Vec<ReplCommand>,
|
||||
groups: HashMap<&'static str, usize>,
|
||||
}
|
||||
|
||||
impl ReplCompleter {
|
||||
pub fn new(config: &GlobalConfig) -> Self {
|
||||
pub fn new(ctx: Arc<RwLock<RequestContext>>) -> Self {
|
||||
let mut groups = HashMap::new();
|
||||
|
||||
let commands: Vec<ReplCommand> = REPL_COMMANDS.to_vec();
|
||||
@@ -97,7 +96,7 @@ impl ReplCompleter {
|
||||
}
|
||||
|
||||
Self {
|
||||
config: config.clone(),
|
||||
ctx,
|
||||
commands,
|
||||
groups,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::REPL_COMMANDS;
|
||||
|
||||
use crate::{config::GlobalConfig, utils::NO_COLOR};
|
||||
use crate::utils::NO_COLOR;
|
||||
|
||||
use nu_ansi_term::{Color, Style};
|
||||
use reedline::{Highlighter, StyledText};
|
||||
@@ -11,7 +11,7 @@ const MATCH_COLOR: Color = Color::Green;
|
||||
pub struct ReplHighlighter;
|
||||
|
||||
impl ReplHighlighter {
|
||||
pub fn new(_config: &GlobalConfig) -> Self {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
+211
-265
@@ -7,8 +7,9 @@ use self::highlighter::ReplHighlighter;
|
||||
use self::prompt::ReplPrompt;
|
||||
|
||||
use crate::client::{call_chat_completions, call_chat_completions_streaming, init_client, oauth};
|
||||
use crate::config::paths;
|
||||
use crate::config::{
|
||||
AgentVariables, AssertState, Config, GlobalConfig, Input, LastMessage, StateFlags,
|
||||
AgentVariables, AppConfig, AssertState, Input, LastMessage, RequestContext, StateFlags,
|
||||
macro_execute,
|
||||
};
|
||||
use crate::render::render_error;
|
||||
@@ -16,11 +17,11 @@ use crate::utils::{
|
||||
AbortSignal, abortable_run_with_spinner, create_abort_signal, dimmed_text, set_text, temp_file,
|
||||
};
|
||||
|
||||
use crate::mcp::McpRegistry;
|
||||
use crate::resolve_oauth_client;
|
||||
use anyhow::{Context, Result, bail};
|
||||
use crossterm::cursor::SetCursorStyle;
|
||||
use fancy_regex::Regex;
|
||||
use parking_lot::RwLock;
|
||||
use reedline::CursorConfig;
|
||||
use reedline::{
|
||||
ColumnarMenu, EditCommand, EditMode, Emacs, KeyCode, KeyModifiers, Keybindings, Reedline,
|
||||
@@ -29,7 +30,7 @@ use reedline::{
|
||||
};
|
||||
use reedline::{MenuBuilder, Signal};
|
||||
use std::sync::LazyLock;
|
||||
use std::{env, mem, process};
|
||||
use std::{env, process, sync::Arc};
|
||||
|
||||
const MENU_NAME: &str = "completion_menu";
|
||||
|
||||
@@ -208,31 +209,31 @@ static MULTILINE_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"(?s)^\s*:::\s*(.*)\s*:::\s*$").unwrap());
|
||||
|
||||
pub struct Repl {
|
||||
config: GlobalConfig,
|
||||
ctx: Arc<RwLock<RequestContext>>,
|
||||
editor: Reedline,
|
||||
prompt: ReplPrompt,
|
||||
abort_signal: AbortSignal,
|
||||
}
|
||||
|
||||
impl Repl {
|
||||
pub fn init(config: &GlobalConfig) -> Result<Self> {
|
||||
let editor = Self::create_editor(config)?;
|
||||
|
||||
let prompt = ReplPrompt::new(config);
|
||||
pub fn init(ctx: RequestContext) -> Result<Self> {
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
let ctx = Arc::new(RwLock::new(ctx));
|
||||
let editor = Self::create_editor(Arc::clone(&ctx), app.as_ref())?;
|
||||
let prompt = ReplPrompt::new(Arc::clone(&ctx));
|
||||
let abort_signal = create_abort_signal();
|
||||
|
||||
Ok(Self {
|
||||
config: config.clone(),
|
||||
ctx,
|
||||
editor,
|
||||
prompt,
|
||||
abort_signal,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
pub async fn run(&mut self) -> Result<()> {
|
||||
if AssertState::False(StateFlags::AGENT | StateFlags::RAG)
|
||||
.assert(self.config.read().state())
|
||||
{
|
||||
if AssertState::False(StateFlags::AGENT | StateFlags::RAG).assert(self.ctx.read().state()) {
|
||||
print!(
|
||||
r#"Welcome to {} {}
|
||||
Type ".help" for additional help.
|
||||
@@ -250,7 +251,11 @@ Type ".help" for additional help.
|
||||
match sig {
|
||||
Ok(Signal::Success(line)) => {
|
||||
self.abort_signal.reset();
|
||||
match run_repl_command(&self.config, self.abort_signal.clone(), &line).await {
|
||||
let result = {
|
||||
let mut ctx = self.ctx.write();
|
||||
run_repl_command(&mut ctx, self.abort_signal.clone(), &line).await
|
||||
};
|
||||
match result {
|
||||
Ok(exit) => {
|
||||
if exit {
|
||||
break;
|
||||
@@ -273,15 +278,15 @@ Type ".help" for additional help.
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
self.config.write().exit_session()?;
|
||||
self.ctx.write().exit_session()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_editor(config: &GlobalConfig) -> Result<Reedline> {
|
||||
let completer = ReplCompleter::new(config);
|
||||
let highlighter = ReplHighlighter::new(config);
|
||||
fn create_editor(ctx: Arc<RwLock<RequestContext>>, app: &AppConfig) -> Result<Reedline> {
|
||||
let completer = ReplCompleter::new(Arc::clone(&ctx));
|
||||
let highlighter = ReplHighlighter::new();
|
||||
let menu = Self::create_menu();
|
||||
let edit_mode = Self::create_edit_mode(config);
|
||||
let edit_mode = Self::create_edit_mode(app);
|
||||
let cursor_config = CursorConfig {
|
||||
vi_insert: Some(SetCursorStyle::BlinkingBar),
|
||||
vi_normal: Some(SetCursorStyle::SteadyBlock),
|
||||
@@ -299,7 +304,7 @@ Type ".help" for additional help.
|
||||
.with_validator(Box::new(ReplValidator))
|
||||
.with_ansi_colors(true);
|
||||
|
||||
if let Ok(cmd) = config.read().editor() {
|
||||
if let Ok(cmd) = app.editor() {
|
||||
let temp_file = temp_file("-repl-", ".md");
|
||||
let command = process::Command::new(cmd);
|
||||
editor = editor.with_buffer_editor(command, temp_file);
|
||||
@@ -334,8 +339,8 @@ Type ".help" for additional help.
|
||||
);
|
||||
}
|
||||
|
||||
fn create_edit_mode(config: &GlobalConfig) -> Box<dyn EditMode> {
|
||||
let edit_mode: Box<dyn EditMode> = if config.read().keybindings == "vi" {
|
||||
fn create_edit_mode(app: &AppConfig) -> Box<dyn EditMode> {
|
||||
let edit_mode: Box<dyn EditMode> = if app.keybindings == "vi" {
|
||||
let mut insert_keybindings = default_vi_insert_keybindings();
|
||||
Self::extra_keybindings(&mut insert_keybindings);
|
||||
Box::new(Vi::new(insert_keybindings, default_vi_normal_keybindings()))
|
||||
@@ -389,7 +394,7 @@ impl Validator for ReplValidator {
|
||||
}
|
||||
|
||||
pub async fn run_repl_command(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
abort_signal: AbortSignal,
|
||||
mut line: &str,
|
||||
) -> Result<bool> {
|
||||
@@ -405,66 +410,74 @@ pub async fn run_repl_command(
|
||||
}
|
||||
".info" => match args {
|
||||
Some("role") => {
|
||||
let info = config.read().role_info()?;
|
||||
let info = ctx.role_info()?;
|
||||
print!("{info}");
|
||||
}
|
||||
Some("session") => {
|
||||
let info = config.read().session_info()?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
let info = ctx.session_info(app.as_ref())?;
|
||||
print!("{info}");
|
||||
}
|
||||
Some("rag") => {
|
||||
let info = config.read().rag_info()?;
|
||||
let info = ctx.rag_info()?;
|
||||
print!("{info}");
|
||||
}
|
||||
Some("agent") => {
|
||||
let info = config.read().agent_info()?;
|
||||
let info = ctx.agent_info()?;
|
||||
print!("{info}");
|
||||
}
|
||||
Some(_) => unknown_command()?,
|
||||
None => {
|
||||
let output = config.read().sysinfo()?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
let output = ctx.sysinfo(app.as_ref())?;
|
||||
print!("{output}");
|
||||
}
|
||||
},
|
||||
".model" => match args {
|
||||
Some(name) => {
|
||||
config.write().set_model(name)?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.set_model_on_role_like(app.as_ref(), name)?;
|
||||
}
|
||||
None => println!("Usage: .model <name>"),
|
||||
},
|
||||
".authenticate" => {
|
||||
let current_model = config.read().current_model().clone();
|
||||
let client = init_client(config, Some(current_model))?;
|
||||
let current_model = ctx.current_model().clone();
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
let client = init_client(&app, current_model)?;
|
||||
if !client.supports_oauth() {
|
||||
bail!(
|
||||
"Client '{}' doesn't either support OAuth or isn't configured to use it (i.e. uses an API key instead)",
|
||||
client.name()
|
||||
);
|
||||
}
|
||||
let clients = config.read().clients.clone();
|
||||
let clients = ctx.app.config.clients.clone();
|
||||
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
|
||||
oauth::run_oauth_flow(&*provider, &client_name).await?;
|
||||
}
|
||||
".prompt" => match args {
|
||||
Some(text) => {
|
||||
config.write().use_prompt(text)?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.use_prompt(app.as_ref(), text)?;
|
||||
}
|
||||
None => println!("Usage: .prompt <text>..."),
|
||||
},
|
||||
".role" => match args {
|
||||
Some(args) => match args.split_once(['\n', ' ']) {
|
||||
Some((name, text)) => {
|
||||
let role = config.read().retrieve_role(name.trim())?;
|
||||
let input = Input::from_str(config, text, Some(role));
|
||||
ask(config, abort_signal.clone(), input, false).await?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
let role = ctx.retrieve_role(app.as_ref(), name.trim())?;
|
||||
let input = Input::from_str(ctx, text, Some(role));
|
||||
ask(ctx, abort_signal.clone(), input, false).await?;
|
||||
}
|
||||
None => {
|
||||
let name = args;
|
||||
if !Config::has_role(name) {
|
||||
config.write().new_role(name)?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
if !paths::has_role(name) {
|
||||
ctx.new_role(app.as_ref(), name)?;
|
||||
}
|
||||
|
||||
Config::use_role_safely(config, name, abort_signal.clone()).await?;
|
||||
ctx.use_role(app.as_ref(), name, abort_signal.clone())
|
||||
.await?;
|
||||
}
|
||||
},
|
||||
None => println!(
|
||||
@@ -474,11 +487,26 @@ pub async fn run_repl_command(
|
||||
),
|
||||
},
|
||||
".session" => {
|
||||
Config::use_session_safely(config, args, abort_signal.clone()).await?;
|
||||
Config::maybe_autoname_session(config.clone());
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.use_session(app.as_ref(), args, abort_signal.clone())
|
||||
.await?;
|
||||
if ctx.maybe_autoname_session() {
|
||||
let color = if app.light_theme() {
|
||||
nu_ansi_term::Color::LightGray
|
||||
} else {
|
||||
nu_ansi_term::Color::DarkGray
|
||||
};
|
||||
eprintln!("\n📢 {}", color.italic().paint("Autonaming the session."),);
|
||||
if let Err(err) = ctx.autoname_session(app.as_ref()).await {
|
||||
log::warn!("Failed to autonaming the session: {err}");
|
||||
}
|
||||
if let Some(session) = ctx.session.as_mut() {
|
||||
session.set_autonaming(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
".rag" => {
|
||||
Config::use_rag(config, args, abort_signal.clone()).await?;
|
||||
ctx.use_rag(args, abort_signal.clone()).await?;
|
||||
}
|
||||
".agent" => match split_first_arg(args) {
|
||||
Some((agent_name, args)) => {
|
||||
@@ -497,13 +525,11 @@ pub async fn run_repl_command(
|
||||
bail!("Some variable values are not key=value pairs");
|
||||
}
|
||||
if !variables.is_empty() {
|
||||
config.write().agent_variables = Some(variables);
|
||||
ctx.agent_variables = Some(variables.clone());
|
||||
}
|
||||
let ret =
|
||||
Config::use_agent(config, agent_name, session_name, abort_signal.clone())
|
||||
.await;
|
||||
config.write().agent_variables = None;
|
||||
ret?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.use_agent(app.as_ref(), agent_name, session_name, abort_signal.clone())
|
||||
.await?;
|
||||
}
|
||||
None => {
|
||||
println!(r#"Usage: .agent <agent-name> [session-name] [key=value]..."#)
|
||||
@@ -512,7 +538,7 @@ pub async fn run_repl_command(
|
||||
".starter" => match args {
|
||||
Some(id) => {
|
||||
let mut text = None;
|
||||
if let Some(agent) = config.read().agent.as_ref() {
|
||||
if let Some(agent) = ctx.agent.as_ref() {
|
||||
for (i, value) in agent.conversation_starters().iter().enumerate() {
|
||||
if (i + 1).to_string() == id {
|
||||
text = Some(value.clone());
|
||||
@@ -522,8 +548,8 @@ pub async fn run_repl_command(
|
||||
match text {
|
||||
Some(text) => {
|
||||
println!("{}", dimmed_text(&format!(">> {text}")));
|
||||
let input = Input::from_str(config, &text, None);
|
||||
ask(config, abort_signal.clone(), input, true).await?;
|
||||
let input = Input::from_str(ctx, &text, None);
|
||||
ask(ctx, abort_signal.clone(), input, true).await?;
|
||||
}
|
||||
None => {
|
||||
bail!("Invalid starter value");
|
||||
@@ -531,50 +557,43 @@ pub async fn run_repl_command(
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let banner = config.read().agent_banner()?;
|
||||
config.read().print_markdown(&banner)?;
|
||||
let banner = ctx.agent_banner()?;
|
||||
ctx.app.config.print_markdown(&banner)?;
|
||||
}
|
||||
},
|
||||
".save" => match split_first_arg(args) {
|
||||
Some(("role", name)) => {
|
||||
config.write().save_role(name)?;
|
||||
ctx.save_role(name)?;
|
||||
}
|
||||
Some(("session", name)) => {
|
||||
config.write().save_session(name)?;
|
||||
ctx.save_session(name)?;
|
||||
}
|
||||
_ => {
|
||||
println!(r#"Usage: .save <role|session> [name]"#)
|
||||
}
|
||||
},
|
||||
".edit" => {
|
||||
if config.read().macro_flag {
|
||||
if ctx.macro_flag {
|
||||
bail!("Cannot perform this operation because you are in a macro")
|
||||
}
|
||||
match args {
|
||||
Some("config") => {
|
||||
config.read().edit_config()?;
|
||||
ctx.edit_config()?;
|
||||
}
|
||||
Some("role") => {
|
||||
let mut cfg = {
|
||||
let mut guard = config.write();
|
||||
mem::take(&mut *guard)
|
||||
};
|
||||
|
||||
cfg.edit_role(abort_signal.clone()).await?;
|
||||
|
||||
{
|
||||
let mut guard = config.write();
|
||||
*guard = cfg;
|
||||
}
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.edit_role(app.as_ref(), abort_signal.clone()).await?;
|
||||
}
|
||||
Some("session") => {
|
||||
config.write().edit_session()?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.edit_session(app.as_ref())?;
|
||||
}
|
||||
Some("rag-docs") => {
|
||||
Config::edit_rag_docs(config, abort_signal.clone()).await?;
|
||||
ctx.edit_rag_docs(abort_signal.clone()).await?;
|
||||
}
|
||||
Some("agent-config") => {
|
||||
config.write().edit_agent_config()?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.edit_agent_config(app.as_ref())?;
|
||||
}
|
||||
_ => {
|
||||
println!(r#"Usage: .edit <config|role|session|rag-docs|agent-config>"#)
|
||||
@@ -584,7 +603,7 @@ pub async fn run_repl_command(
|
||||
".compress" => match args {
|
||||
Some("session") => {
|
||||
abortable_run_with_spinner(
|
||||
Config::compress_session(config),
|
||||
ctx.compress_session(),
|
||||
"Compressing",
|
||||
abort_signal.clone(),
|
||||
)
|
||||
@@ -597,7 +616,7 @@ pub async fn run_repl_command(
|
||||
},
|
||||
".empty" => match args {
|
||||
Some("session") => {
|
||||
config.write().empty_session()?;
|
||||
ctx.empty_session()?;
|
||||
}
|
||||
_ => {
|
||||
println!(r#"Usage: .empty session"#)
|
||||
@@ -605,7 +624,7 @@ pub async fn run_repl_command(
|
||||
},
|
||||
".rebuild" => match args {
|
||||
Some("rag") => {
|
||||
Config::rebuild_rag(config, abort_signal.clone()).await?;
|
||||
ctx.rebuild_rag(abort_signal.clone()).await?;
|
||||
}
|
||||
_ => {
|
||||
println!(r#"Usage: .rebuild rag"#)
|
||||
@@ -613,7 +632,7 @@ pub async fn run_repl_command(
|
||||
},
|
||||
".sources" => match args {
|
||||
Some("rag") => {
|
||||
let output = Config::rag_sources(config)?;
|
||||
let output = ctx.rag_sources()?;
|
||||
println!("{output}");
|
||||
}
|
||||
_ => {
|
||||
@@ -622,10 +641,11 @@ pub async fn run_repl_command(
|
||||
},
|
||||
".macro" => match split_first_arg(args) {
|
||||
Some((name, extra)) => {
|
||||
if !Config::has_macro(name) && extra.is_none() {
|
||||
config.write().new_macro(name)?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
if !paths::has_macro(name) && extra.is_none() {
|
||||
ctx.new_macro(app.as_ref(), name)?;
|
||||
} else {
|
||||
macro_execute(config, name, extra, abort_signal.clone()).await?;
|
||||
macro_execute(ctx, name, extra, abort_signal.clone()).await?;
|
||||
}
|
||||
}
|
||||
None => println!("Usage: .macro <name> <text>..."),
|
||||
@@ -634,14 +654,14 @@ pub async fn run_repl_command(
|
||||
Some(args) => {
|
||||
let (files, text) = split_args_text(args, cfg!(windows));
|
||||
let input = Input::from_files_with_spinner(
|
||||
config,
|
||||
ctx,
|
||||
text,
|
||||
files,
|
||||
None,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
ask(config, abort_signal.clone(), input, true).await?;
|
||||
ask(ctx, abort_signal.clone(), input, true).await?;
|
||||
}
|
||||
None => println!(
|
||||
r#"Usage: .file <file|dir|url|cmd|loader:resource|%%>... [-- <text>...]
|
||||
@@ -658,8 +678,7 @@ pub async fn run_repl_command(
|
||||
".continue" => {
|
||||
let LastMessage {
|
||||
mut input, output, ..
|
||||
} = match config
|
||||
.read()
|
||||
} = match ctx
|
||||
.last_message
|
||||
.as_ref()
|
||||
.filter(|v| v.continuous && !v.output.is_empty())
|
||||
@@ -669,25 +688,21 @@ pub async fn run_repl_command(
|
||||
None => bail!("Unable to continue the response"),
|
||||
};
|
||||
input.set_continue_output(&output);
|
||||
ask(config, abort_signal.clone(), input, true).await?;
|
||||
ask(ctx, abort_signal.clone(), input, true).await?;
|
||||
}
|
||||
".regenerate" => {
|
||||
let LastMessage { mut input, .. } = match config
|
||||
.read()
|
||||
.last_message
|
||||
.as_ref()
|
||||
.filter(|v| v.continuous)
|
||||
.cloned()
|
||||
{
|
||||
Some(v) => v,
|
||||
None => bail!("Unable to regenerate the response"),
|
||||
};
|
||||
input.set_regenerate();
|
||||
ask(config, abort_signal.clone(), input, true).await?;
|
||||
let LastMessage { mut input, .. } =
|
||||
match ctx.last_message.as_ref().filter(|v| v.continuous).cloned() {
|
||||
Some(v) => v,
|
||||
None => bail!("Unable to regenerate the response"),
|
||||
};
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
input.set_regenerate(ctx.extract_role(&app));
|
||||
ask(ctx, abort_signal.clone(), input, true).await?;
|
||||
}
|
||||
".set" => match args {
|
||||
Some(args) => {
|
||||
Config::update(config, args, abort_signal).await?;
|
||||
ctx.update(args, abort_signal).await?;
|
||||
}
|
||||
_ => {
|
||||
println!("Usage: .set <key> <value>...")
|
||||
@@ -695,15 +710,14 @@ pub async fn run_repl_command(
|
||||
},
|
||||
".delete" => match args {
|
||||
Some(args) => {
|
||||
Config::delete(config, args)?;
|
||||
ctx.delete(args)?;
|
||||
}
|
||||
_ => {
|
||||
println!("Usage: .delete <role|session|rag|macro|agent-data>")
|
||||
}
|
||||
},
|
||||
".copy" => {
|
||||
let output = match config
|
||||
.read()
|
||||
let output = match ctx
|
||||
.last_message
|
||||
.as_ref()
|
||||
.filter(|v| !v.output.is_empty())
|
||||
@@ -716,89 +730,29 @@ pub async fn run_repl_command(
|
||||
}
|
||||
".exit" => match args {
|
||||
Some("role") => {
|
||||
config.write().exit_role()?;
|
||||
config.write().functions.clear_mcp_meta_functions();
|
||||
|
||||
let registry = config
|
||||
.write()
|
||||
.mcp_registry
|
||||
.take()
|
||||
.expect("MCP registry should exist");
|
||||
let enabled_mcp_servers = if config.read().mcp_server_support {
|
||||
config.read().enabled_mcp_servers.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let registry =
|
||||
McpRegistry::reinit(registry, enabled_mcp_servers, abort_signal.clone())
|
||||
.await?;
|
||||
if !registry.is_empty() {
|
||||
config
|
||||
.write()
|
||||
.functions
|
||||
.append_mcp_meta_functions(registry.list_started_servers());
|
||||
}
|
||||
config.write().mcp_registry = Some(registry);
|
||||
ctx.exit_role()?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.bootstrap_tools(app.as_ref(), true, abort_signal.clone())
|
||||
.await?;
|
||||
}
|
||||
Some("session") => {
|
||||
if config.read().agent.is_some() {
|
||||
config.write().exit_agent_session()?;
|
||||
config.write().functions.clear_mcp_meta_functions();
|
||||
|
||||
let registry = config
|
||||
.write()
|
||||
.mcp_registry
|
||||
.take()
|
||||
.expect("MCP registry should exist");
|
||||
let enabled_mcp_servers = if config.read().mcp_server_support {
|
||||
config.read().enabled_mcp_servers.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let registry = McpRegistry::reinit(
|
||||
registry,
|
||||
enabled_mcp_servers,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
if !registry.is_empty() {
|
||||
config
|
||||
.write()
|
||||
.functions
|
||||
.append_mcp_meta_functions(registry.list_started_servers());
|
||||
}
|
||||
config.write().mcp_registry = Some(registry);
|
||||
if ctx.agent.is_some() {
|
||||
ctx.exit_agent_session()?;
|
||||
} else {
|
||||
config.write().exit_session()?;
|
||||
ctx.exit_session()?;
|
||||
}
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.bootstrap_tools(app.as_ref(), true, abort_signal.clone())
|
||||
.await?;
|
||||
}
|
||||
Some("rag") => {
|
||||
config.write().exit_rag()?;
|
||||
ctx.exit_rag()?;
|
||||
}
|
||||
Some("agent") => {
|
||||
config.write().exit_agent()?;
|
||||
config.write().functions.clear_mcp_meta_functions();
|
||||
|
||||
let registry = config
|
||||
.write()
|
||||
.mcp_registry
|
||||
.take()
|
||||
.expect("MCP registry should exist");
|
||||
let enabled_mcp_servers = if config.read().mcp_server_support {
|
||||
config.read().enabled_mcp_servers.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let registry =
|
||||
McpRegistry::reinit(registry, enabled_mcp_servers, abort_signal.clone())
|
||||
.await?;
|
||||
if !registry.is_empty() {
|
||||
config
|
||||
.write()
|
||||
.functions
|
||||
.append_mcp_meta_functions(registry.list_started_servers());
|
||||
}
|
||||
config.write().mcp_registry = Some(registry);
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.exit_agent(app.as_ref())?;
|
||||
ctx.bootstrap_tools(app.as_ref(), true, abort_signal.clone())
|
||||
.await?;
|
||||
}
|
||||
Some(_) => unknown_command()?,
|
||||
None => {
|
||||
@@ -810,57 +764,59 @@ pub async fn run_repl_command(
|
||||
bail!("Use '.empty session' instead");
|
||||
}
|
||||
Some("todo") => {
|
||||
let mut cfg = config.write();
|
||||
match cfg.agent.as_mut() {
|
||||
let cleared = match ctx.agent.as_mut() {
|
||||
Some(agent) => {
|
||||
if !agent.auto_continue_enabled() {
|
||||
bail!(
|
||||
"The todo system is not enabled for this agent. Set 'auto_continue: true' in the agent's config.yaml to enable it."
|
||||
);
|
||||
}
|
||||
if agent.todo_list().is_empty() {
|
||||
if ctx.todo_list.is_empty() {
|
||||
println!("Todo list is already empty.");
|
||||
false
|
||||
} else {
|
||||
agent.clear_todo_list();
|
||||
ctx.clear_todo_list();
|
||||
println!("Todo list cleared.");
|
||||
true
|
||||
}
|
||||
}
|
||||
None => bail!("No active agent"),
|
||||
}
|
||||
};
|
||||
let _ = cleared;
|
||||
}
|
||||
_ => unknown_command()?,
|
||||
},
|
||||
".vault" => match split_first_arg(args) {
|
||||
Some(("add", name)) => {
|
||||
if let Some(name) = name {
|
||||
config.read().vault.add_secret(name)?;
|
||||
ctx.app.vault.add_secret(name)?;
|
||||
} else {
|
||||
println!("Usage: .vault add <name>");
|
||||
}
|
||||
}
|
||||
Some(("get", name)) => {
|
||||
if let Some(name) = name {
|
||||
config.read().vault.get_secret(name, true)?;
|
||||
ctx.app.vault.get_secret(name, true)?;
|
||||
} else {
|
||||
println!("Usage: .vault get <name>");
|
||||
}
|
||||
}
|
||||
Some(("update", name)) => {
|
||||
if let Some(name) = name {
|
||||
config.read().vault.update_secret(name)?;
|
||||
ctx.app.vault.update_secret(name)?;
|
||||
} else {
|
||||
println!("Usage: .vault update <name>");
|
||||
}
|
||||
}
|
||||
Some(("delete", name)) => {
|
||||
if let Some(name) = name {
|
||||
config.read().vault.delete_secret(name)?;
|
||||
ctx.app.vault.delete_secret(name)?;
|
||||
} else {
|
||||
println!("Usage: .vault delete <name>");
|
||||
}
|
||||
}
|
||||
Some(("list", _)) => {
|
||||
config.read().vault.list_secrets(true)?;
|
||||
ctx.app.vault.list_secrets(true)?;
|
||||
}
|
||||
None | Some(_) => {
|
||||
println!("Usage: .vault <add|get|update|delete|list> [name]")
|
||||
@@ -869,20 +825,13 @@ pub async fn run_repl_command(
|
||||
_ => unknown_command()?,
|
||||
},
|
||||
None => {
|
||||
if config
|
||||
.read()
|
||||
.agent
|
||||
.as_ref()
|
||||
.is_some_and(|a| a.continuation_count() > 0)
|
||||
{
|
||||
config.write().agent.as_mut().unwrap().reset_continuation();
|
||||
}
|
||||
let input = Input::from_str(config, line, None);
|
||||
ask(config, abort_signal.clone(), input, true).await?;
|
||||
reset_continuation(ctx);
|
||||
let input = Input::from_str(ctx, line, None);
|
||||
ask(ctx, abort_signal.clone(), input, true).await?;
|
||||
}
|
||||
}
|
||||
|
||||
if !config.read().macro_flag {
|
||||
if !ctx.macro_flag {
|
||||
println!();
|
||||
}
|
||||
|
||||
@@ -891,7 +840,7 @@ pub async fn run_repl_command(
|
||||
|
||||
#[async_recursion::async_recursion]
|
||||
async fn ask(
|
||||
config: &GlobalConfig,
|
||||
ctx: &mut RequestContext,
|
||||
abort_signal: AbortSignal,
|
||||
mut input: Input,
|
||||
with_embeddings: bool,
|
||||
@@ -902,54 +851,51 @@ async fn ask(
|
||||
if with_embeddings {
|
||||
input.use_embeddings(abort_signal.clone()).await?;
|
||||
}
|
||||
while config.read().is_compressing_session() {
|
||||
while ctx.is_compressing_session() {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
let client = input.create_client()?;
|
||||
config.write().before_chat_completion(&input)?;
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
ctx.before_chat_completion(&input)?;
|
||||
let (output, tool_results) = if input.stream() {
|
||||
call_chat_completions_streaming(&input, client.as_ref(), abort_signal.clone()).await?
|
||||
call_chat_completions_streaming(&input, client.as_ref(), ctx, abort_signal.clone()).await?
|
||||
} else {
|
||||
call_chat_completions(&input, true, false, client.as_ref(), abort_signal.clone()).await?
|
||||
call_chat_completions(
|
||||
&input,
|
||||
true,
|
||||
false,
|
||||
client.as_ref(),
|
||||
ctx,
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
config
|
||||
.write()
|
||||
.after_chat_completion(&input, &output, &tool_results)?;
|
||||
ctx.after_chat_completion(app.as_ref(), &input, &output, &tool_results)?;
|
||||
if !tool_results.is_empty() {
|
||||
ask(
|
||||
config,
|
||||
ctx,
|
||||
abort_signal,
|
||||
input.merge_tool_results(output, tool_results),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let should_continue = {
|
||||
let cfg = config.read();
|
||||
if let Some(agent) = &cfg.agent {
|
||||
agent.auto_continue_enabled()
|
||||
&& agent.continuation_count() < agent.max_auto_continues()
|
||||
&& agent.todo_list().has_incomplete()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
let should_continue = agent_should_continue(ctx);
|
||||
|
||||
if should_continue {
|
||||
let full_prompt = {
|
||||
let mut cfg = config.write();
|
||||
let agent = cfg.agent.as_mut().expect("agent checked above");
|
||||
agent.set_last_continuation_response(output.clone());
|
||||
agent.increment_continuation();
|
||||
let count = agent.continuation_count();
|
||||
let todo_state = ctx.todo_list.render_for_model();
|
||||
let remaining = ctx.todo_list.incomplete_count();
|
||||
ctx.set_last_continuation_response(output.clone());
|
||||
ctx.increment_auto_continue_count();
|
||||
let agent = ctx.agent.as_mut().expect("agent checked above");
|
||||
let count = ctx.auto_continue_count;
|
||||
let max = agent.max_auto_continues();
|
||||
|
||||
let todo_state = agent.todo_list().render_for_model();
|
||||
let remaining = agent.todo_list().incomplete_count();
|
||||
let prompt = agent.continuation_prompt();
|
||||
|
||||
let color = if cfg.light_theme() {
|
||||
let color = if app.light_theme() {
|
||||
nu_ansi_term::Color::LightGray
|
||||
} else {
|
||||
nu_ansi_term::Color::DarkGray
|
||||
@@ -963,71 +909,63 @@ async fn ask(
|
||||
|
||||
format!("{prompt}\n\n{todo_state}")
|
||||
};
|
||||
let continuation_input = Input::from_str(config, &full_prompt, None);
|
||||
ask(config, abort_signal, continuation_input, false).await
|
||||
let continuation_input = Input::from_str(ctx, &full_prompt, None);
|
||||
ask(ctx, abort_signal, continuation_input, false).await
|
||||
} else {
|
||||
if config
|
||||
.read()
|
||||
.agent
|
||||
.as_ref()
|
||||
.is_some_and(|a| a.continuation_count() > 0)
|
||||
{
|
||||
config.write().agent.as_mut().unwrap().reset_continuation();
|
||||
reset_continuation(ctx);
|
||||
if ctx.maybe_autoname_session() {
|
||||
let color = if app.light_theme() {
|
||||
nu_ansi_term::Color::LightGray
|
||||
} else {
|
||||
nu_ansi_term::Color::DarkGray
|
||||
};
|
||||
eprintln!("\n📢 {}", color.italic().paint("Autonaming the session."),);
|
||||
if let Err(err) = ctx.autoname_session(app.as_ref()).await {
|
||||
log::warn!("Failed to autonaming the session: {err}");
|
||||
}
|
||||
if let Some(session) = ctx.session.as_mut() {
|
||||
session.set_autonaming(false);
|
||||
}
|
||||
}
|
||||
Config::maybe_autoname_session(config.clone());
|
||||
|
||||
let needs_compression = {
|
||||
let cfg = config.read();
|
||||
let compression_threshold = cfg.compression_threshold;
|
||||
cfg.session
|
||||
.as_ref()
|
||||
.is_some_and(|s| s.needs_compression(compression_threshold))
|
||||
};
|
||||
let needs_compression = ctx
|
||||
.session
|
||||
.as_ref()
|
||||
.is_some_and(|s| s.needs_compression(app.compression_threshold));
|
||||
|
||||
if needs_compression {
|
||||
let agent_can_continue_after_compress = {
|
||||
let cfg = config.read();
|
||||
cfg.agent.as_ref().is_some_and(|agent| {
|
||||
agent.auto_continue_enabled()
|
||||
&& agent.continuation_count() < agent.max_auto_continues()
|
||||
&& agent.todo_list().has_incomplete()
|
||||
})
|
||||
};
|
||||
let agent_can_continue_after_compress = agent_should_continue(ctx);
|
||||
|
||||
{
|
||||
let mut cfg = config.write();
|
||||
if let Some(session) = cfg.session.as_mut() {
|
||||
session.set_compressing(true);
|
||||
}
|
||||
if let Some(session) = ctx.session.as_mut() {
|
||||
session.set_compressing(true);
|
||||
}
|
||||
|
||||
let color = if config.read().light_theme() {
|
||||
let color = if app.light_theme() {
|
||||
nu_ansi_term::Color::LightGray
|
||||
} else {
|
||||
nu_ansi_term::Color::DarkGray
|
||||
};
|
||||
eprintln!("\n📢 {}", color.italic().paint("Compressing the session."),);
|
||||
|
||||
if let Err(err) = Config::compress_session(config).await {
|
||||
if let Err(err) = ctx.compress_session().await {
|
||||
log::warn!("Failed to compress the session: {err}");
|
||||
}
|
||||
if let Some(session) = config.write().session.as_mut() {
|
||||
if let Some(session) = ctx.session.as_mut() {
|
||||
session.set_compressing(false);
|
||||
}
|
||||
|
||||
if agent_can_continue_after_compress {
|
||||
let full_prompt = {
|
||||
let mut cfg = config.write();
|
||||
let agent = cfg.agent.as_mut().expect("agent checked above");
|
||||
agent.increment_continuation();
|
||||
let count = agent.continuation_count();
|
||||
let todo_state = ctx.todo_list.render_for_model();
|
||||
let remaining = ctx.todo_list.incomplete_count();
|
||||
ctx.increment_auto_continue_count();
|
||||
let agent = ctx.agent.as_mut().expect("agent checked above");
|
||||
let count = ctx.auto_continue_count;
|
||||
let max = agent.max_auto_continues();
|
||||
|
||||
let todo_state = agent.todo_list().render_for_model();
|
||||
let remaining = agent.todo_list().incomplete_count();
|
||||
let prompt = agent.continuation_prompt();
|
||||
|
||||
let color = if cfg.light_theme() {
|
||||
let color = if app.light_theme() {
|
||||
nu_ansi_term::Color::LightGray
|
||||
} else {
|
||||
nu_ansi_term::Color::DarkGray
|
||||
@@ -1041,11 +979,9 @@ async fn ask(
|
||||
|
||||
format!("{prompt}\n\n{todo_state}")
|
||||
};
|
||||
let continuation_input = Input::from_str(config, &full_prompt, None);
|
||||
return ask(config, abort_signal, continuation_input, false).await;
|
||||
let continuation_input = Input::from_str(ctx, &full_prompt, None);
|
||||
return ask(ctx, abort_signal, continuation_input, false).await;
|
||||
}
|
||||
} else {
|
||||
Config::maybe_compress_session(config.clone());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -1053,6 +989,16 @@ async fn ask(
|
||||
}
|
||||
}
|
||||
|
||||
fn agent_should_continue(ctx: &RequestContext) -> bool {
|
||||
ctx.agent.as_ref().is_some_and(|agent| {
|
||||
agent.auto_continue_enabled() && ctx.auto_continue_count < agent.max_auto_continues()
|
||||
}) && ctx.todo_list.has_incomplete()
|
||||
}
|
||||
|
||||
fn reset_continuation(ctx: &mut RequestContext) {
|
||||
ctx.reset_continuation_count();
|
||||
}
|
||||
|
||||
fn unknown_command() -> Result<()> {
|
||||
bail!(r#"Unknown command. Type ".help" for additional help."#);
|
||||
}
|
||||
|
||||
+10
-8
@@ -1,28 +1,30 @@
|
||||
use crate::config::GlobalConfig;
|
||||
use crate::config::RequestContext;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use reedline::{Prompt, PromptHistorySearch, PromptHistorySearchStatus};
|
||||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReplPrompt {
|
||||
config: GlobalConfig,
|
||||
ctx: Arc<RwLock<RequestContext>>,
|
||||
}
|
||||
|
||||
impl ReplPrompt {
|
||||
pub fn new(config: &GlobalConfig) -> Self {
|
||||
Self {
|
||||
config: config.clone(),
|
||||
}
|
||||
pub fn new(ctx: Arc<RwLock<RequestContext>>) -> Self {
|
||||
Self { ctx }
|
||||
}
|
||||
}
|
||||
|
||||
impl Prompt for ReplPrompt {
|
||||
fn render_prompt_left(&self) -> Cow<'_, str> {
|
||||
Cow::Owned(self.config.read().render_prompt_left())
|
||||
let ctx = self.ctx.read();
|
||||
Cow::Owned(ctx.render_prompt_left(ctx.app.config.as_ref()))
|
||||
}
|
||||
|
||||
fn render_prompt_right(&self) -> Cow<'_, str> {
|
||||
Cow::Owned(self.config.read().render_prompt_right())
|
||||
let ctx = self.ctx.read();
|
||||
Cow::Owned(ctx.render_prompt_right(ctx.app.config.as_ref()))
|
||||
}
|
||||
|
||||
fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<'_, str> {
|
||||
|
||||
+2
-2
@@ -1,4 +1,4 @@
|
||||
use crate::config::Config;
|
||||
use crate::config::paths;
|
||||
use colored::Colorize;
|
||||
use fancy_regex::Regex;
|
||||
use std::fs::File;
|
||||
@@ -7,7 +7,7 @@ use std::process;
|
||||
|
||||
pub async fn tail_logs(no_color: bool) {
|
||||
let re = Regex::new(r"^(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\s+<(?P<opid>[^\s>]+)>\s+\[(?P<level>[A-Z]+)\]\s+(?P<logger>[^:]+):(?P<line>\d+)\s+-\s+(?P<message>.*)$").unwrap();
|
||||
let file_path = Config::log_path();
|
||||
let file_path = paths::log_path();
|
||||
let file = File::open(&file_path).expect("Cannot open file");
|
||||
let mut reader = BufReader::new(file);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user