feat: improved UX for parallel graph execution
This commit is contained in:
@@ -3,7 +3,6 @@ use super::structured;
|
||||
use super::types::AgentNode;
|
||||
use crate::config::RequestContext;
|
||||
use crate::function::supervisor::run_agent_for_graph;
|
||||
use crate::utils::dimmed_text;
|
||||
use anyhow::{Context, Result};
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
@@ -24,14 +23,6 @@ impl AgentNodeExecutor {
|
||||
.interpolate(&node.prompt)
|
||||
.with_context(|| format!("Failed to interpolate prompt for agent '{}'", node.agent))?;
|
||||
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text(&format!("▸ spawning agent '{}' with prompt:", node.agent))
|
||||
);
|
||||
for line in indent_prompt(&prompt, 6) {
|
||||
eprintln!("{}", dimmed_text(&line));
|
||||
}
|
||||
|
||||
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));
|
||||
|
||||
let raw = timeout(
|
||||
@@ -66,21 +57,6 @@ impl AgentNodeExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
fn indent_prompt(prompt: &str, prefix_spaces: usize) -> Vec<String> {
|
||||
const MAX_LINES: usize = 12;
|
||||
let pad = " ".repeat(prefix_spaces);
|
||||
let mut out: Vec<String> = prompt
|
||||
.lines()
|
||||
.take(MAX_LINES)
|
||||
.map(|line| format!("{pad}{line}"))
|
||||
.collect();
|
||||
let total = prompt.lines().count();
|
||||
if total > MAX_LINES {
|
||||
out.push(format!("{pad}... ({} more lines)", total - MAX_LINES));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) {
|
||||
if node.output_schema.is_some()
|
||||
&& let Some(obj) = output.as_object()
|
||||
|
||||
+19
-15
@@ -1,6 +1,6 @@
|
||||
use super::agent::AgentNodeExecutor;
|
||||
use super::llm::{LlmExecutionOutcome, LlmNodeExecutor};
|
||||
use super::logging::GraphLogger;
|
||||
use super::logging::{GraphLogger, node_type_label};
|
||||
use super::map::MapNodeExecutor;
|
||||
use super::progress::{BranchProgressHandle, BranchProgressTracker};
|
||||
use super::rag::RagNodeExecutor;
|
||||
@@ -146,11 +146,12 @@ impl GraphExecutor {
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrency));
|
||||
|
||||
let frontier_size = frontier.len();
|
||||
let progress_tracker = if frontier_size > 1 {
|
||||
Some(BranchProgressTracker::new())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let has_progress_nodes = frontier.iter().any(|nid| {
|
||||
graph.get_node(nid).is_some_and(|n| {
|
||||
!matches!(n.node_type, NodeType::Approval(_) | NodeType::Input(_))
|
||||
})
|
||||
});
|
||||
let progress_tracker = has_progress_nodes.then(BranchProgressTracker::new);
|
||||
let mut branch_tasks = Vec::with_capacity(frontier_size);
|
||||
for node_id in &frontier {
|
||||
let node = graph
|
||||
@@ -161,19 +162,24 @@ impl GraphExecutor {
|
||||
.clone();
|
||||
let branch_state = state.fork_for_branch_state();
|
||||
let mut branch_ctx = ctx.fork_for_branch();
|
||||
if frontier_size > 1 {
|
||||
branch_ctx.render_mode = RenderMode::Silent;
|
||||
}
|
||||
branch_ctx.render_mode = RenderMode::Silent;
|
||||
let script_exec_clone = script_executor.clone();
|
||||
let graph_clone = Arc::clone(&graph);
|
||||
let current = node_id.clone();
|
||||
let sem_clone = semaphore.clone();
|
||||
let abort_clone = abort_signal.clone();
|
||||
let progress_handle: Option<BranchProgressHandle> =
|
||||
progress_tracker.as_ref().map(|t| t.add_branch(node_id));
|
||||
let progress_handle = match (
|
||||
matches!(node.node_type, NodeType::Approval(_) | NodeType::Input(_)),
|
||||
&progress_tracker,
|
||||
) {
|
||||
(false, Some(tracker)) => {
|
||||
tracker.add_branch(&format!("{} ({})", node_id, node_type_label(&node)))
|
||||
}
|
||||
_ => BranchProgressHandle::disabled(),
|
||||
};
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
let mut progress_handle = progress_handle;
|
||||
let mut progress_handle = Some(progress_handle);
|
||||
let _permit = sem_clone
|
||||
.acquire()
|
||||
.await
|
||||
@@ -212,9 +218,7 @@ impl GraphExecutor {
|
||||
}
|
||||
|
||||
let joined = join_all(branch_tasks).await;
|
||||
if let Some(t) = &progress_tracker {
|
||||
t.clear();
|
||||
}
|
||||
drop(progress_tracker);
|
||||
|
||||
let mut branch_writes: Vec<BranchWrites> = Vec::new();
|
||||
let mut next_frontier: HashSet<String> = HashSet::new();
|
||||
|
||||
+1
-25
@@ -3,7 +3,7 @@ use super::structured;
|
||||
use super::types::LlmNode;
|
||||
use crate::client::{Model, ModelType, call_chat_completions};
|
||||
use crate::config::{Input, RequestContext, Role, RoleLike};
|
||||
use crate::utils::{create_abort_signal, dimmed_text};
|
||||
use crate::utils::create_abort_signal;
|
||||
use anyhow::{Context, Error, Result, anyhow, bail};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
@@ -101,15 +101,6 @@ async fn run(
|
||||
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
|
||||
validate_tools_subset(®ular_tools, &mcp_servers, parent_ctx)?;
|
||||
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text(&format!(
|
||||
"▸ llm call: model={} tools={}",
|
||||
node.model.as_deref().unwrap_or("<active>"),
|
||||
describe_tools_filter(node.tools.as_deref())
|
||||
))
|
||||
);
|
||||
|
||||
let role = build_inline_role(
|
||||
node,
|
||||
instructions.as_deref(),
|
||||
@@ -363,13 +354,6 @@ fn format_schema_hint(schema: &Value) -> String {
|
||||
)
|
||||
}
|
||||
|
||||
fn describe_tools_filter(tools: Option<&[String]>) -> String {
|
||||
match tools {
|
||||
Some(t) if !t.is_empty() => t.join(","),
|
||||
_ => "<none>".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::types::*;
|
||||
@@ -571,14 +555,6 @@ mod tests {
|
||||
assert!(hint.contains("ONLY"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn describe_tools_filter_renders_each_case() {
|
||||
assert_eq!(describe_tools_filter(None), "<none>");
|
||||
assert_eq!(describe_tools_filter(Some(&[])), "<none>");
|
||||
let tools = vec!["a".to_string(), "b".to_string()];
|
||||
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn categorize_tools_splits_mcp_and_regular() {
|
||||
let entries = vec![
|
||||
|
||||
@@ -72,10 +72,6 @@ impl GraphLogger {
|
||||
"[graph:{}] entering '{}' (visit {visit})",
|
||||
self.graph_name, node.id
|
||||
);
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text(&format!("▸ {} ({})", node.id, node_type_label(node)))
|
||||
);
|
||||
}
|
||||
|
||||
pub fn record_timing(&mut self, node_id: &str, elapsed: Duration) {
|
||||
@@ -142,7 +138,7 @@ impl GraphLogger {
|
||||
}
|
||||
}
|
||||
|
||||
fn node_type_label(node: &Node) -> &'static str {
|
||||
pub(super) fn node_type_label(node: &Node) -> &'static str {
|
||||
match &node.node_type {
|
||||
NodeType::Agent(_) => "agent",
|
||||
NodeType::Script(_) => "script",
|
||||
|
||||
+1
-1
@@ -123,7 +123,7 @@ impl MapNodeExecutor {
|
||||
}
|
||||
|
||||
let joined = join_all(sub_tasks).await;
|
||||
progress_tracker.clear();
|
||||
drop(progress_tracker);
|
||||
|
||||
// Collect outputs keyed by input index so order is preserved regardless
|
||||
// of finish order. This is the user-facing contract from plan E.2.
|
||||
|
||||
@@ -46,12 +46,6 @@ impl BranchProgressTracker {
|
||||
started: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear(&self) {
|
||||
if let Some(multi) = &self.multi {
|
||||
let _ = multi.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct BranchProgressHandle {
|
||||
@@ -60,7 +54,7 @@ pub(super) struct BranchProgressHandle {
|
||||
}
|
||||
|
||||
impl BranchProgressHandle {
|
||||
fn disabled() -> Self {
|
||||
pub fn disabled() -> Self {
|
||||
Self {
|
||||
bar: None,
|
||||
started: Instant::now(),
|
||||
|
||||
+1
-6
@@ -1,7 +1,7 @@
|
||||
use super::state::StateManager;
|
||||
use super::types::RagNode;
|
||||
use crate::config::RequestContext;
|
||||
use crate::utils::{create_abort_signal, dimmed_text};
|
||||
use crate::utils::create_abort_signal;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use serde_json::{Map, Value};
|
||||
use std::time::Duration;
|
||||
@@ -34,11 +34,6 @@ impl RagNodeExecutor {
|
||||
let top_k = node.top_k.unwrap_or_else(|| rag.configured_top_k());
|
||||
let rerank = rag.configured_reranker();
|
||||
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text(&format!("▸ rag lookup: node={node_id} top_k={top_k}"))
|
||||
);
|
||||
|
||||
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_RAG_TIMEOUT_SECS));
|
||||
let abort = create_abort_signal();
|
||||
let (context, sources_str, _ids) =
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use super::state::{StateManager, StateRepresentation};
|
||||
use super::types::ScriptNode;
|
||||
use crate::function::Language;
|
||||
use crate::utils::dimmed_text;
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use serde_json::Value;
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -32,11 +31,6 @@ impl ScriptExecutor {
|
||||
bail!("Script file not found: '{}'", script_path.display());
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text(&format!("▸ running script '{}'", node.script))
|
||||
);
|
||||
|
||||
let language = detect_language(&script_path)?;
|
||||
let state_repr = state_manager.serialize_state()?;
|
||||
|
||||
@@ -98,23 +92,6 @@ impl ScriptExecutor {
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Map<String, Value>>(json_output) {
|
||||
let keys: Vec<&str> = parsed
|
||||
.keys()
|
||||
.filter(|k| k.as_str() != "_next")
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
if !keys.is_empty() {
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text(&format!("▸ merged: {}", keys.join(", ")))
|
||||
);
|
||||
}
|
||||
if let Some(n) = &next {
|
||||
eprintln!("{}", dimmed_text(&format!("▸ script set _next = '{n}'")));
|
||||
}
|
||||
}
|
||||
|
||||
apply_state_updates(node, state_manager);
|
||||
|
||||
Ok(next)
|
||||
|
||||
+2
-13
@@ -1,6 +1,6 @@
|
||||
use crate::client::call_chat_completions;
|
||||
use crate::config::{Input, RequestContext, Role, RoleLike};
|
||||
use crate::utils::{create_abort_signal, dimmed_text};
|
||||
use crate::utils::create_abort_signal;
|
||||
use anyhow::{Context, Result, bail};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
@@ -24,10 +24,6 @@ pub async fn extract(raw: &str, schema: &Value, parent_ctx: &mut RequestContext)
|
||||
return Ok(parsed);
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text("▸ structured-output: parsing raw output failed, invoking extractor")
|
||||
);
|
||||
extract_via_extractor(raw, schema, parent_ctx, false).await
|
||||
}
|
||||
|
||||
@@ -53,14 +49,7 @@ async fn extract_via_extractor(
|
||||
"Structured-output extractor failed to produce valid JSON after repair retry. \
|
||||
Last response:\n{output}"
|
||||
),
|
||||
None => {
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text("▸ structured-output: extractor returned invalid JSON, retrying")
|
||||
);
|
||||
|
||||
Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await
|
||||
}
|
||||
None => Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user