feat: Full support for map node types
This commit is contained in:
+46
-18
@@ -1,6 +1,7 @@
|
||||
use super::agent::AgentNodeExecutor;
|
||||
use super::llm::LlmNodeExecutor;
|
||||
use super::logging::GraphLogger;
|
||||
use super::map::MapNodeExecutor;
|
||||
use super::rag::RagNodeExecutor;
|
||||
use super::script::ScriptExecutor;
|
||||
use super::staging::BranchWrites;
|
||||
@@ -74,6 +75,10 @@ impl GraphExecutor {
|
||||
let max_iterations = graph.settings.max_loop_iterations;
|
||||
let graph_timeout = graph.settings.timeout.map(Duration::from_secs);
|
||||
let max_concurrency = graph.settings.max_concurrency;
|
||||
// Wrap in Arc so spawned branch tasks can cheaply share the Graph for
|
||||
// node lookup (especially the map executor, which needs to resolve its
|
||||
// `branch:` target from inside a spawned task).
|
||||
let graph = Arc::new(graph);
|
||||
let start = Instant::now();
|
||||
|
||||
let mut frontier: HashSet<String> = HashSet::from([graph.start.clone()]);
|
||||
@@ -150,7 +155,7 @@ impl GraphExecutor {
|
||||
let branch_state = state.fork_for_branch_state();
|
||||
let branch_ctx = ctx.fork_for_branch();
|
||||
let script_exec_clone = script_executor.clone();
|
||||
let graph_name = graph.name.clone();
|
||||
let graph_clone = Arc::clone(&graph);
|
||||
let current = node_id.clone();
|
||||
let sem_clone = semaphore.clone();
|
||||
let abort_clone = abort_signal.clone();
|
||||
@@ -171,15 +176,13 @@ impl GraphExecutor {
|
||||
let node_start = Instant::now();
|
||||
let mut state = branch_state;
|
||||
let mut ctx = branch_ctx;
|
||||
let result = step(
|
||||
&node,
|
||||
&mut state,
|
||||
&mut ctx,
|
||||
&script_exec_clone,
|
||||
&graph_name,
|
||||
&current,
|
||||
)
|
||||
.await;
|
||||
let step_ctx = StepContext {
|
||||
graph: graph_clone.as_ref(),
|
||||
script_executor: &script_exec_clone,
|
||||
max_concurrency,
|
||||
abort_signal: &abort_clone,
|
||||
};
|
||||
let result = step(&node, &mut state, &mut ctx, &step_ctx, &current).await;
|
||||
let elapsed = node_start.elapsed();
|
||||
(current, state, result, elapsed)
|
||||
});
|
||||
@@ -263,6 +266,23 @@ fn sorted_frontier(frontier: &HashSet<String>) -> Vec<String> {
|
||||
v
|
||||
}
|
||||
|
||||
// Bundles the engine-config refs that every `step()` call needs to thread
|
||||
// through. Constructed once per spawned branch task (or once at the call site
|
||||
// for sequential paths) so step() and downstream executors (MapNodeExecutor)
|
||||
// take one parameter instead of five.
|
||||
pub(super) struct StepContext<'a> {
|
||||
pub graph: &'a Graph,
|
||||
pub script_executor: &'a ScriptExecutor,
|
||||
pub max_concurrency: usize,
|
||||
pub abort_signal: &'a AbortSignal,
|
||||
}
|
||||
|
||||
impl StepContext<'_> {
|
||||
pub fn graph_name(&self) -> &str {
|
||||
&self.graph.name
|
||||
}
|
||||
}
|
||||
|
||||
enum StepResult {
|
||||
Continue(String),
|
||||
End(String),
|
||||
@@ -272,8 +292,7 @@ async fn step(
|
||||
node: &Node,
|
||||
state: &mut StateManager,
|
||||
ctx: &mut RequestContext,
|
||||
script_executor: &ScriptExecutor,
|
||||
graph_name: &str,
|
||||
step_ctx: &StepContext<'_>,
|
||||
current: &str,
|
||||
) -> Result<StepResult> {
|
||||
match &node.node_type {
|
||||
@@ -288,13 +307,16 @@ async fn step(
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::Script(script_node) => {
|
||||
let dynamic = match script_executor.execute(script_node, state).await {
|
||||
let dynamic = match step_ctx.script_executor.execute(script_node, state).await {
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
if let Some(fallback) = &script_node.fallback {
|
||||
warn!(
|
||||
"[graph:{}] script '{}' failed, routing to fallback '{}': {}",
|
||||
graph_name, current, fallback, e
|
||||
step_ctx.graph_name(),
|
||||
current,
|
||||
fallback,
|
||||
e
|
||||
);
|
||||
return Ok(StepResult::Continue(fallback.clone()));
|
||||
}
|
||||
@@ -334,10 +356,16 @@ async fn step(
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
||||
NodeType::Map(_) => bail!(
|
||||
"Map nodes are not yet supported in this build \
|
||||
(parallel branch execution lands in Phase D/E)."
|
||||
),
|
||||
NodeType::Map(map_node) => {
|
||||
let next = node
|
||||
.next_single()?
|
||||
.ok_or_else(|| {
|
||||
anyhow!("map node '{current}' has no `next` and is not an end node")
|
||||
})?
|
||||
.to_string();
|
||||
MapNodeExecutor::execute(map_node, state, ctx, step_ctx, current).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
use super::agent::AgentNodeExecutor;
|
||||
use super::executor::StepContext;
|
||||
use super::llm::LlmNodeExecutor;
|
||||
use super::rag::RagNodeExecutor;
|
||||
use super::state::StateManager;
|
||||
use super::types::{MapNode, NodeType};
|
||||
use crate::config::RequestContext;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use futures_util::future::join_all;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
use crate::graph::type_name;
|
||||
|
||||
/// Placeholder `next` target handed to LLM/RAG executors when they run as a
/// map sub-branch.
///
/// Map sub-branches are atomic: the branch node has no `next` (enforced by
/// validator rule C.5). The LLM/RAG node executors nevertheless take an
/// `Option<&str>` routing argument and error when it is `None`, so the map
/// executor passes this sentinel to satisfy their contract and then discards
/// whatever routing they return.
const MAP_BRANCH_SENTINEL_NEXT: &str = "__map_branch_continuation__";
|
||||
|
||||
pub(super) struct MapNodeExecutor;
|
||||
|
||||
impl MapNodeExecutor {
|
||||
pub(super) async fn execute(
|
||||
node: &MapNode,
|
||||
state: &mut StateManager,
|
||||
ctx: &mut RequestContext,
|
||||
step_ctx: &StepContext<'_>,
|
||||
node_id: &str,
|
||||
) -> Result<()> {
|
||||
let over_value = state
|
||||
.interpolate_raw(&node.over)
|
||||
.with_context(|| format!("map node '{node_id}': evaluating `over` template"))?;
|
||||
|
||||
let items = over_value.as_array().ok_or_else(|| {
|
||||
anyhow!(
|
||||
"map node '{}': `over` template '{}' must resolve to an array, got {}",
|
||||
node_id,
|
||||
node.over,
|
||||
type_name(&over_value)
|
||||
)
|
||||
})?;
|
||||
let items = items.clone();
|
||||
|
||||
let branch_node = step_ctx
|
||||
.graph
|
||||
.get_node(&node.branch)
|
||||
.ok_or_else(|| {
|
||||
anyhow!(
|
||||
"map node '{node_id}': branch '{}' not found in graph",
|
||||
node.branch
|
||||
)
|
||||
})?
|
||||
.clone();
|
||||
|
||||
let max_conc = node
|
||||
.max_concurrency
|
||||
.unwrap_or(step_ctx.max_concurrency)
|
||||
.max(1);
|
||||
let semaphore = Arc::new(Semaphore::new(max_conc));
|
||||
let mut sub_tasks = Vec::with_capacity(items.len());
|
||||
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
let item = item.clone();
|
||||
let as_name = node.as_name.clone();
|
||||
let branch_clone = branch_node.clone();
|
||||
let mut sub_state = state.fork_for_branch_state();
|
||||
let sub_ctx = ctx.fork_for_branch();
|
||||
let script_clone = step_ctx.script_executor.clone();
|
||||
let sub_branch_id = node.branch.clone();
|
||||
let sem = semaphore.clone();
|
||||
let abort = step_ctx.abort_signal.clone();
|
||||
|
||||
sub_state.state_mut().set(as_name, item);
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
let _permit = sem
|
||||
.acquire()
|
||||
.await
|
||||
.expect("map semaphore should not be closed");
|
||||
if abort.aborted() {
|
||||
return (
|
||||
idx,
|
||||
sub_state,
|
||||
Err(anyhow!("map sub-branch [{idx}] aborted")),
|
||||
);
|
||||
}
|
||||
let mut state = sub_state;
|
||||
let mut ctx = sub_ctx;
|
||||
|
||||
let exec_result: Result<()> = match &branch_clone.node_type {
|
||||
NodeType::Llm(n) => LlmNodeExecutor::execute(
|
||||
n,
|
||||
Some(MAP_BRANCH_SENTINEL_NEXT),
|
||||
&mut state,
|
||||
&mut ctx,
|
||||
)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
NodeType::Rag(n) => RagNodeExecutor::execute(
|
||||
n,
|
||||
&sub_branch_id,
|
||||
Some(MAP_BRANCH_SENTINEL_NEXT),
|
||||
&mut state,
|
||||
&mut ctx,
|
||||
)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()),
|
||||
_ => Err(anyhow!(
|
||||
"map branch '{}' has type that cannot run inside a map \
|
||||
(validator should have caught this; internal error)",
|
||||
branch_clone.id
|
||||
)),
|
||||
};
|
||||
|
||||
(idx, state, exec_result)
|
||||
});
|
||||
sub_tasks.push(task);
|
||||
}
|
||||
|
||||
let joined = join_all(sub_tasks).await;
|
||||
|
||||
// Collect outputs keyed by input index so order is preserved regardless
|
||||
// of finish order. This is the user-facing contract from plan E.2.
|
||||
let mut outputs: HashMap<usize, Value> = HashMap::new();
|
||||
for join_result in joined {
|
||||
let (idx, sub_state, exec_result) =
|
||||
join_result.map_err(|e| anyhow!("map sub-branch panicked: {e}"))?;
|
||||
|
||||
exec_result
|
||||
.with_context(|| format!("map node '{node_id}': sub-branch [{idx}] failed"))?;
|
||||
|
||||
let output_value = sub_state
|
||||
.state()
|
||||
.get(&node.output_key)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
anyhow!(
|
||||
"map node '{node_id}': sub-branch [{idx}] did not write \
|
||||
`output_key` '{}'",
|
||||
node.output_key
|
||||
)
|
||||
})?;
|
||||
|
||||
outputs.insert(idx, output_value);
|
||||
}
|
||||
|
||||
let mut collected = Vec::with_capacity(items.len());
|
||||
for idx in 0..items.len() {
|
||||
let value = outputs.remove(&idx).ok_or_else(|| {
|
||||
anyhow!(
|
||||
"map node '{node_id}': internal error: missing result for sub-branch [{idx}]"
|
||||
)
|
||||
})?;
|
||||
collected.push(value);
|
||||
}
|
||||
|
||||
state
|
||||
.state_mut()
|
||||
.set(node.collect_into.clone(), Value::Array(collected));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ pub mod dispatch;
|
||||
pub mod executor;
|
||||
pub mod llm;
|
||||
pub mod logging;
|
||||
pub mod map;
|
||||
pub mod parser;
|
||||
pub mod rag;
|
||||
pub mod reducer;
|
||||
@@ -14,6 +15,7 @@ pub mod types;
|
||||
pub mod user_interaction;
|
||||
pub mod validator;
|
||||
|
||||
use serde_json::Value;
|
||||
pub use dispatch::{active_agent_graph_name, run_active_agent_graph};
|
||||
pub use executor::GraphExecutor;
|
||||
pub use parser::{GraphParser, agent_has_graph};
|
||||
@@ -24,3 +26,14 @@ pub const GRAPH_SCHEMA_VERSION: &str = "1.0";
|
||||
pub const DEFAULT_MAX_LOOP_ITERATIONS: usize = 100;
|
||||
|
||||
pub const MAX_STATE_SIZE_BYTES: usize = 32 * 1024;
|
||||
|
||||
pub (in crate::graph) fn type_name(value: &Value) -> &'static str {
|
||||
match value {
|
||||
Value::Null => "null",
|
||||
Value::Bool(_) => "bool",
|
||||
Value::Number(_) => "number",
|
||||
Value::String(_) => "string",
|
||||
Value::Array(_) => "array",
|
||||
Value::Object(_) => "object",
|
||||
}
|
||||
}
|
||||
|
||||
+1
-11
@@ -1,6 +1,7 @@
|
||||
use super::types::Reducer;
|
||||
use anyhow::{Result, bail};
|
||||
use serde_json::{Number, Value};
|
||||
use crate::graph::type_name;
|
||||
|
||||
/// Combines a branch's incoming write with the current state value (if any)
|
||||
/// via the specified reducer. The result is what gets written back to live
|
||||
@@ -147,17 +148,6 @@ fn number_or_error(value: &Value, reducer_name: &str, position: &str) -> Result<
|
||||
}
|
||||
}
|
||||
|
||||
fn type_name(value: &Value) -> &'static str {
|
||||
match value {
|
||||
Value::Null => "null",
|
||||
Value::Bool(_) => "bool",
|
||||
Value::Number(_) => "number",
|
||||
Value::String(_) => "string",
|
||||
Value::Array(_) => "array",
|
||||
Value::Object(_) => "object",
|
||||
}
|
||||
}
|
||||
|
||||
// Numeric reducers compute in f64 for simplicity. We preserve integer typing
|
||||
// when the result is losslessly representable as i64 so `count: sum` stays an
|
||||
// integer rather than degrading to a float. Non-finite values (NaN, Inf) can't
|
||||
|
||||
Reference in New Issue
Block a user