// coyote/src/graph/map.rs

use super::agent::AgentNodeExecutor;
use super::executor::StepContext;
use super::llm::LlmNodeExecutor;
use super::progress::{BranchProgressHandle, BranchProgressTracker};
use super::rag::RagNodeExecutor;
use super::state::StateManager;
use super::types::{MapNode, NodeType};
use crate::config::{RenderMode, RequestContext};
use crate::graph::type_name;
use anyhow::{Context, Result, anyhow};
use futures_util::future::join_all;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Semaphore;
/// Executor for map nodes.
///
/// A unit struct that carries no data; it only namespaces the associated
/// `execute` function, which runs a map node's branch once per element of
/// an input array and collects the per-element outputs.
pub(super) struct MapNodeExecutor;
impl MapNodeExecutor {
    /// Runs a map node: executes the node's branch once for every element of
    /// the array that `node.over` resolves to, then stores the per-element
    /// outputs, in input order, under `node.collect_into` in `state`.
    ///
    /// For each element a forked state and a forked request context are
    /// created, the element is written into the forked state under
    /// `node.as_name`, and the branch node is executed in its own spawned
    /// task. After the branch finishes, the value found under
    /// `node.output_key` in that forked state becomes the element's result.
    ///
    /// All tasks are spawned up front; a semaphore sized by
    /// `node.max_concurrency` (falling back to `step_ctx.max_concurrency`,
    /// and never below 1) limits how many run the branch at the same time.
    /// Every task is awaited before any result is inspected.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the `over` template cannot be evaluated or does not resolve to an
    ///   array;
    /// - `node.branch` is not present in the graph;
    /// - a sub-branch task panicked, observed the abort signal, or its
    ///   executor returned an error;
    /// - a sub-branch finished without writing `node.output_key`.
    ///
    /// `collect_into` is written only when every sub-branch succeeded.
    pub(super) async fn execute(
        node: &MapNode,
        state: &mut StateManager,
        ctx: &mut RequestContext,
        step_ctx: &StepContext<'_>,
        node_id: &str,
    ) -> Result<()> {
        let over_value = state
            .interpolate_raw(&node.over)
            .with_context(|| format!("map node '{node_id}': evaluating `over` template"))?;
        let items = over_value.as_array().ok_or_else(|| {
            anyhow!(
                "map node '{}': `over` template '{}' must resolve to an array, got {}",
                node_id,
                node.over,
                type_name(&over_value)
            )
        })?;
        // Take one owned copy of the input array. The elements are moved out
        // of it below, so each element is cloned exactly once instead of once
        // here and once more per loop iteration.
        let items = items.clone();
        let item_count = items.len();

        let branch_node = step_ctx
            .graph
            .get_node(&node.branch)
            .ok_or_else(|| {
                anyhow!(
                    "map node '{node_id}': branch '{}' not found in graph",
                    node.branch
                )
            })?
            .clone();

        // A configured limit of 0 is clamped to 1 so the semaphore always has
        // at least one permit to hand out.
        let max_conc = node
            .max_concurrency
            .unwrap_or(step_ctx.max_concurrency)
            .max(1);
        let semaphore = Arc::new(Semaphore::new(max_conc));
        let progress_tracker = BranchProgressTracker::new();

        let mut sub_tasks = Vec::with_capacity(item_count);
        for (idx, item) in items.into_iter().enumerate() {
            let as_name = node.as_name.clone();
            let branch_clone = branch_node.clone();
            let mut sub_state = state.fork_for_branch_state();
            let mut sub_ctx = ctx.fork_for_branch();
            // Sub-branches run with rendering switched to silent.
            sub_ctx.render_mode = RenderMode::Silent;
            let script_clone = step_ctx.script_executor.clone();
            let sub_branch_id = node.branch.clone();
            let sem = semaphore.clone();
            let abort = step_ctx.abort_signal.clone();
            let progress_handle: BranchProgressHandle =
                progress_tracker.add_branch(&format!("{}[{idx}]", node.branch));
            // Expose the current element to the branch under `as_name`.
            sub_state.state_mut().set(as_name, item);

            let task = tokio::spawn(async move {
                // Wrapped in `Option` so the handle can be moved out exactly
                // once, on whichever exit path is taken.
                let mut progress_handle = Some(progress_handle);
                // Holding the permit for the rest of the task is what bounds
                // concurrency; it is released when `_permit` is dropped.
                let _permit = sem
                    .acquire()
                    .await
                    .expect("map semaphore should not be closed");
                // Checked after the permit is obtained, so tasks that were
                // queued behind the semaphore bail out without running the
                // branch once an abort has been signalled.
                if abort.aborted() {
                    if let Some(h) = progress_handle.take() {
                        h.fail("aborted");
                    }
                    return (
                        idx,
                        sub_state,
                        Err(anyhow!("map sub-branch [{idx}] aborted")),
                    );
                }
                let mut state = sub_state;
                let mut ctx = sub_ctx;
                // Only these four node kinds are dispatched; any other kind
                // is reported as an internal error.
                let exec_result: Result<()> = match &branch_clone.node_type {
                    NodeType::Llm(n) => LlmNodeExecutor::execute(n, &mut state, &mut ctx)
                        .await
                        .map(|_| ()),
                    NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx)
                        .await
                        .map(|_| ()),
                    NodeType::Rag(n) => {
                        RagNodeExecutor::execute(n, &sub_branch_id, &mut state, &mut ctx).await
                    }
                    NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()),
                    _ => Err(anyhow!(
                        "map branch '{}' has type that cannot run inside a map \
                         (validator should have caught this; internal error)",
                        branch_clone.id
                    )),
                };
                if let Some(h) = progress_handle.take() {
                    match &exec_result {
                        Ok(_) => h.complete(),
                        Err(e) => h.fail(&e.to_string()),
                    }
                }
                // Hand the forked state back so the caller can read the
                // branch's output from it.
                (idx, state, exec_result)
            });
            sub_tasks.push(task);
        }

        // Wait for every task, including those that will turn out to have
        // failed, before looking at any individual result.
        let joined = join_all(sub_tasks).await;
        drop(progress_tracker);

        // Collect outputs keyed by input index so order is preserved regardless
        // of finish order. This is the user-facing contract from plan E.2.
        let mut outputs: HashMap<usize, Value> = HashMap::with_capacity(item_count);
        for join_result in joined {
            let (idx, sub_state, exec_result) =
                join_result.map_err(|e| anyhow!("map sub-branch panicked: {e}"))?;
            exec_result
                .with_context(|| format!("map node '{node_id}': sub-branch [{idx}] failed"))?;
            let output_value = sub_state
                .state()
                .get(&node.output_key)
                .cloned()
                .ok_or_else(|| {
                    anyhow!(
                        "map node '{node_id}': sub-branch [{idx}] did not write \
                         `output_key` '{}'",
                        node.output_key
                    )
                })?;
            outputs.insert(idx, output_value);
        }

        // Reassemble the results in input order; a missing index would mean a
        // task result was lost, which is reported as an internal error.
        let collected = (0..item_count)
            .map(|idx| {
                outputs.remove(&idx).ok_or_else(|| {
                    anyhow!(
                        "map node '{node_id}': internal error: missing result for sub-branch [{idx}]"
                    )
                })
            })
            .collect::<Result<Vec<Value>>>()?;

        state
            .state_mut()
            .set(node.collect_into.clone(), Value::Array(collected));
        Ok(())
    }
}