From dfd1334dece611b69cb64792c7671375879432d0 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Tue, 12 May 2026 14:13:03 -0600 Subject: [PATCH] feat: initial agent graph scaffolding --- src/config/mod.rs | 1 + src/config/paths.rs | 10 +- src/function/supervisor.rs | 10 +- src/graph/mod.rs | 19 ++ src/graph/parser.rs | 456 ++++++++++++++++++++++++++++++ src/graph/types.rs | 548 +++++++++++++++++++++++++++++++++++++ src/main.rs | 1 + 7 files changed, 1036 insertions(+), 9 deletions(-) create mode 100644 src/graph/mod.rs create mode 100644 src/graph/parser.rs create mode 100644 src/graph/types.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 7e3bb9f..32dfa32 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -66,6 +66,7 @@ const DARK_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bi const LIGHT_THEME: &[u8] = include_bytes!("../../assets/monokai-extended-light.theme.bin"); const CONFIG_FILE_NAME: &str = "config.yaml"; +const AGENT_GRAPH_FILE_NAME: &str = "graph.yaml"; const ROLES_DIR_NAME: &str = "roles"; const MACROS_DIR_NAME: &str = "macros"; const ENV_FILE_NAME: &str = ".env"; diff --git a/src/config/paths.rs b/src/config/paths.rs index 947fea9..aa2ce14 100644 --- a/src/config/paths.rs +++ b/src/config/paths.rs @@ -1,9 +1,5 @@ use super::role::Role; -use super::{ - AGENTS_DIR_NAME, BASH_PROMPT_UTILS_FILE_NAME, CONFIG_FILE_NAME, ENV_FILE_NAME, - FUNCTIONS_BIN_DIR_NAME, FUNCTIONS_DIR_NAME, GLOBAL_TOOLS_DIR_NAME, GLOBAL_TOOLS_UTILS_DIR_NAME, - MACROS_DIR_NAME, MCP_FILE_NAME, ModelsOverride, RAGS_DIR_NAME, ROLES_DIR_NAME, -}; +use super::{AGENTS_DIR_NAME, BASH_PROMPT_UTILS_FILE_NAME, CONFIG_FILE_NAME, ENV_FILE_NAME, FUNCTIONS_BIN_DIR_NAME, FUNCTIONS_DIR_NAME, GLOBAL_TOOLS_DIR_NAME, GLOBAL_TOOLS_UTILS_DIR_NAME, MACROS_DIR_NAME, MCP_FILE_NAME, ModelsOverride, RAGS_DIR_NAME, ROLES_DIR_NAME, paths, AGENT_GRAPH_FILE_NAME}; use crate::client::ProviderModels; use crate::utils::{get_env_name, list_file_names, normalize_env_name}; @@ -127,6 +123,10 @@ pub fn agent_data_dir(name: &str) -> PathBuf { } } +pub fn agent_graph_path(agent_name: &str) -> PathBuf { + agent_data_dir(agent_name).join(AGENT_GRAPH_FILE_NAME) +} + pub fn agent_config_file(name: &str) -> PathBuf { match env::var(format!("{}_CONFIG_FILE", normalize_env_name(name))) { Ok(value) => PathBuf::from(value), diff --git a/src/function/supervisor.rs b/src/function/supervisor.rs index 56b8455..00f3940 100644 --- a/src/function/supervisor.rs +++ b/src/function/supervisor.rs @@ -610,8 +610,9 @@ async fn handle_check(ctx: &mut RequestContext, args: &Value) -> Result { "message": "Agent is still running" }); - if let Some(queue) = ctx.root_escalation_queue() && - queue.has_pending() { + if let Some(queue) = ctx.root_escalation_queue() + && queue.has_pending() + { let summary = queue.pending_summary(); result["pending_escalations"] = json!(summary); result["message"] = json!( @@ -660,8 +661,9 @@ async fn handle_collect(ctx: &mut RequestContext, args: &Value) -> Result break; } - if let Some(queue) = ctx.root_escalation_queue() && - queue.has_pending() { + if let Some(queue) = ctx.root_escalation_queue() + && queue.has_pending() + { let summary = queue.pending_summary(); return Ok(json!({ "status": "pending", diff --git a/src/graph/mod.rs b/src/graph/mod.rs new file mode 100644 index 0000000..95cea4b --- /dev/null +++ b/src/graph/mod.rs @@ -0,0 +1,19 @@ +//! Graph-based agent orchestration. Declarative YAML workflows over a shared +//! JSON state, composed of agent/script/approval/input/end nodes. + +pub mod parser; +pub mod types; + +pub use parser::{GraphParser, agent_has_graph, load_agent_graph}; +pub use types::{ + AgentNode, ApprovalNode, EndNode, Graph, GraphSettings, GraphState, InputNode, Node, NodeType, + ScriptNode, +}; + +pub const GRAPH_SCHEMA_VERSION: &str = "1.0"; + +pub const DEFAULT_MAX_LOOP_ITERATIONS: usize = 100; + +/// Serialized-state size above which scripts receive state via a temp file +/// instead of an env var. +pub const MAX_STATE_SIZE_BYTES: usize = 32 * 1024; diff --git a/src/graph/parser.rs b/src/graph/parser.rs new file mode 100644 index 0000000..f215653 --- /dev/null +++ b/src/graph/parser.rs @@ -0,0 +1,456 @@ +//! YAML parsing for graph definitions. + +use super::types::Graph; +use crate::config::paths; +use anyhow::{Context, Result, bail, Error, anyhow}; +use std::fs::read_to_string; +use std::path::{Path, PathBuf}; + +const SUPPORTED_VERSIONS: &[&str] = &["1.0"]; + +/// Parser for graph YAML files. The `base_dir` is used to resolve relative +/// paths passed to [`GraphParser::load_from_file`], and is typically an +/// agent directory. +pub struct GraphParser { + base_dir: PathBuf, +} + +impl GraphParser { + pub fn new(base_dir: impl Into) -> Self { + Self { + base_dir: base_dir.into(), + } + } + + /// Load and validate a graph from a YAML file. Relative paths are + /// resolved against `base_dir`. + pub fn load_from_file(&self, path: impl AsRef) -> Result { + let path = path.as_ref(); + let full_path = if path.is_absolute() { + path.to_path_buf() + } else { + self.base_dir.join(path) + }; + + let contents = read_to_string(&full_path) + .with_context(|| format!("Failed to read graph file at '{}'", full_path.display()))?; + + self.load_from_string(&contents) + .with_context(|| format!("Failed to parse graph file at '{}'", full_path.display())) + } + + /// Load and validate a graph from a YAML string. + pub fn load_from_string(&self, yaml: &str) -> Result { + let mut graph: Graph = serde_yaml::from_str(yaml).map_err(enhance_yaml_error)?; + + validate_schema_version(&graph.version)?; + + for (key, node) in &mut graph.nodes { + if node.id.is_empty() { + node.id = key.clone(); + } else if &node.id != key { + bail!( + "Node ID mismatch: key '{}' does not match node.id '{}'", + key, + node.id + ); + } + } + + validate_structure(&graph)?; + + Ok(graph) + } +} + +fn validate_schema_version(version: &str) -> Result<()> { + if !SUPPORTED_VERSIONS.contains(&version) { + bail!( + "Unsupported graph schema version '{}'. Supported versions: {}", + version, + SUPPORTED_VERSIONS.join(", ") + ); + } + Ok(()) +} + +fn validate_structure(graph: &Graph) -> Result<()> { + if graph.name.is_empty() { + bail!("Graph must have a non-empty 'name' field"); + } + + if graph.nodes.is_empty() { + bail!("Graph '{}' has no nodes defined", graph.name); + } + + if !graph.has_node(&graph.start) { + bail!( + "Start node '{}' not found in graph '{}'. Available nodes: {}", + graph.start, + graph.name, + graph.node_ids().join(", ") + ); + } + + Ok(()) +} + +fn enhance_yaml_error(error: serde_yaml::Error) -> Error { + let msg = error.to_string(); + + let hint = if msg.contains("missing field") { + "\n\nHint: Check that all required fields are present.\n\ + Top-level required fields: `name`, `start`, `nodes`.\n\ + Each node requires `type` plus that type's fields:\n\ + - agent: `agent`, `prompt`\n\ + - script: `script`\n\ + - approval: `question`, `options`, `routes`\n\ + - input: `question`\n\ + - end: (no required fields)" + } else if msg.contains("unknown field") || msg.contains("unknown variant") { + "\n\nHint: Check for typos in field names or `type:` values.\n\ + Valid node types: agent, script, approval, input, end." + } else if msg.contains("invalid type") { + "\n\nHint: Check that field values have the correct type.\n\ + - Strings should be quoted if they contain special characters\n\ + - Numbers should not be quoted\n\ + - Lists use YAML array syntax (- item1)\n\ + - Maps use YAML object syntax (key: value)" + } else { + "" + }; + + anyhow!("YAML parsing error: {}{}", msg, hint) +} + +/// Returns true if the named agent has a `graph.yaml` in its data directory. +pub fn agent_has_graph(agent_name: &str) -> bool { + paths::agent_graph_path(agent_name).exists() +} + +/// Load `graph.yaml` from the named agent's data directory. Returns `Ok(None)` +/// if no graph file exists. +pub fn load_agent_graph(agent_name: &str) -> Result> { + let graph_path = paths::agent_graph_path(agent_name); + if !graph_path.exists() { + return Ok(None); + } + + let parser = GraphParser::new(paths::agent_data_dir(agent_name)); + let graph = parser.load_from_file(&graph_path)?; + Ok(Some(graph)) +} + +#[cfg(test)] +mod tests { + use super::super::types::NodeType; + use super::*; + use std::env; + + fn parser() -> GraphParser { + GraphParser::new(env::current_dir().unwrap()) + } + + #[test] + fn parses_a_simple_graph() { + let yaml = r#" +name: simple_graph +version: "1.0" +start: node1 +nodes: + node1: + id: node1 + type: agent + agent: test_agent + prompt: "Hello world" + next: node2 + node2: + id: node2 + type: end + output: done +"#; + let graph = parser().load_from_string(yaml).unwrap(); + assert_eq!(graph.name, "simple_graph"); + assert_eq!(graph.start, "node1"); + assert_eq!(graph.nodes.len(), 2); + assert_eq!( + graph.nodes.get("node1").unwrap().next.as_deref(), + Some("node2") + ); + } + + #[test] + fn auto_fills_node_ids_from_keys() { + let yaml = r#" +name: auto_id_graph +version: "1.0" +start: node1 +nodes: + node1: + type: agent + agent: test_agent + prompt: Test + next: node2 + node2: + type: end + output: done +"#; + let graph = parser().load_from_string(yaml).unwrap(); + assert_eq!(graph.nodes.get("node1").unwrap().id, "node1"); + assert_eq!(graph.nodes.get("node2").unwrap().id, "node2"); + } + + #[test] + fn rejects_missing_start_node() { + let yaml = r#" +name: bad_graph +version: "1.0" +start: nonexistent +nodes: + node1: + type: end +"#; + let err = parser().load_from_string(yaml).unwrap_err().to_string(); + assert!( + err.contains("Start node 'nonexistent' not found"), + "got: {err}" + ); + } + + #[test] + fn rejects_empty_graph_name() { + let yaml = r#" +name: "" +version: "1.0" +start: node1 +nodes: + node1: + type: end +"#; + let err = parser().load_from_string(yaml).unwrap_err().to_string(); + assert!(err.contains("non-empty 'name'"), "got: {err}"); + } + + #[test] + fn rejects_no_nodes() { + let yaml = r#" +name: empty_graph +version: "1.0" +start: node1 +nodes: {} +"#; + let err = parser().load_from_string(yaml).unwrap_err().to_string(); + assert!(err.contains("no nodes defined"), "got: {err}"); + } + + #[test] + fn rejects_unsupported_version() { + let yaml = r#" +name: future_graph +version: "2.0" +start: node1 +nodes: + node1: + type: end +"#; + let err = parser().load_from_string(yaml).unwrap_err().to_string(); + assert!( + err.contains("Unsupported graph schema version"), + "got: {err}" + ); + } + + #[test] + fn rejects_node_id_mismatch() { + let yaml = r#" +name: mismatch_graph +version: "1.0" +start: node1 +nodes: + node1: + id: different_id + type: end +"#; + let err = parser().load_from_string(yaml).unwrap_err().to_string(); + assert!(err.contains("Node ID mismatch"), "got: {err}"); + } + + #[test] + fn parses_approval_node_with_routes() { + let yaml = r#" +name: approval_graph +version: "1.0" +start: approval1 +nodes: + approval1: + type: approval + question: "Proceed with deployment?" + options: + - "Yes" + - "No" + routes: + "Yes": deploy + "No": cancel + deploy: + type: end + cancel: + type: end +"#; + let graph = parser().load_from_string(yaml).unwrap(); + let approval = graph.nodes.get("approval1").unwrap(); + match &approval.node_type { + NodeType::Approval(a) => { + assert_eq!(a.options.len(), 2); + assert_eq!(a.routes.len(), 2); + assert_eq!(a.routes.get("Yes").map(|s| s.as_str()), Some("deploy")); + } + _ => panic!("expected approval node"), + } + } + + #[test] + fn parses_settings_overrides() { + let yaml = r#" +name: settings_graph +version: "1.0" +start: node1 +settings: + max_loop_iterations: 50 + timeout: 300 + log_state_snapshots: false +nodes: + node1: + type: end +"#; + let graph = parser().load_from_string(yaml).unwrap(); + assert_eq!(graph.settings.max_loop_iterations, 50); + assert_eq!(graph.settings.timeout, Some(300)); + assert!(!graph.settings.log_state_snapshots); + assert!(graph.settings.validate_before_run); + } + + #[test] + fn parses_initial_state() { + let yaml = r#" +name: state_graph +version: "1.0" +start: node1 +initial_state: + user_name: "Alice" + count: 42 + enabled: true +nodes: + node1: + type: end +"#; + let graph = parser().load_from_string(yaml).unwrap(); + assert_eq!(graph.initial_state.len(), 3); + assert_eq!(graph.initial_state.get("user_name").unwrap(), "Alice"); + assert_eq!( + graph.initial_state.get("count").unwrap(), + &serde_json::json!(42) + ); + assert_eq!( + graph.initial_state.get("enabled").unwrap(), + &serde_json::json!(true) + ); + } + + #[test] + fn uses_default_version_when_absent() { + let yaml = r#" +name: no_version +start: node1 +nodes: + node1: + type: end +"#; + let graph = parser().load_from_string(yaml).unwrap(); + assert_eq!(graph.version, super::super::GRAPH_SCHEMA_VERSION); + } + + #[test] + fn rejects_unknown_node_type_with_hint() { + let yaml = r#" +name: bad_type +version: "1.0" +start: node1 +nodes: + node1: + type: nonsense +"#; + let err = parser().load_from_string(yaml).unwrap_err().to_string(); + assert!( + err.contains("Valid node types") || err.contains("unknown variant"), + "got: {err}" + ); + } + + #[test] + fn rejects_malformed_yaml() { + let yaml = "name: bad\n bad: indent\nstart: a"; + let result = parser().load_from_string(yaml); + assert!(result.is_err()); + } + + #[test] + fn missing_required_fields_have_a_hint() { + let yaml = r#" +name: missing_start +version: "1.0" +nodes: + node1: + type: end +"#; + let err = parser().load_from_string(yaml).unwrap_err().to_string(); + assert!(err.contains("Hint"), "got: {err}"); + } + + #[test] + fn load_from_file_reads_disk() { + use std::io::Write; + let dir = env::temp_dir(); + let path = dir.join(format!( + "loki_graph_parser_test_{}.yaml", + std::process::id() + )); + let yaml = r#" +name: disk_graph +version: "1.0" +start: only +nodes: + only: + type: end + output: ok +"#; + { + let mut f = std::fs::File::create(&path).unwrap(); + f.write_all(yaml.as_bytes()).unwrap(); + } + + let graph = GraphParser::new(dir).load_from_file(&path).unwrap(); + assert_eq!(graph.name, "disk_graph"); + + let _ = std::fs::remove_file(&path); + } + + #[test] + fn load_from_file_errors_on_missing_path() { + let err = parser() + .load_from_file("/definitely/not/a/real/path/to_any_graph.yaml") + .unwrap_err() + .to_string(); + assert!(err.contains("Failed to read graph file"), "got: {err}"); + } + + #[test] + fn agent_has_graph_false_for_unknown_agent() { + assert!(!agent_has_graph("__nonexistent_agent_for_test__")); + } + + #[test] + fn load_agent_graph_returns_none_when_absent() { + let result = load_agent_graph("__nonexistent_agent_for_test__").unwrap(); + assert!(result.is_none()); + } +} diff --git a/src/graph/types.rs b/src/graph/types.rs new file mode 100644 index 0000000..e992aa0 --- /dev/null +++ b/src/graph/types.rs @@ -0,0 +1,548 @@ +//! Core data structures for graph-based agent orchestration. + +use anyhow::Result; +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +/// A graph definition loaded from YAML. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Graph { + pub name: String, + + #[serde(default)] + pub description: String, + + #[serde(default = "default_schema_version")] + pub version: String, + + #[serde(default)] + pub settings: GraphSettings, + + #[serde(default)] + pub initial_state: HashMap, + + pub start: String, + + pub nodes: IndexMap, +} + +impl Graph { + pub fn get_node(&self, id: &str) -> Option<&Node> { + self.nodes.get(id) + } + + pub fn has_node(&self, id: &str) -> bool { + self.nodes.contains_key(id) + } + + pub fn node_ids(&self) -> Vec<&str> { + self.nodes.keys().map(|s| s.as_str()).collect() + } +} + +fn default_schema_version() -> String { + super::GRAPH_SCHEMA_VERSION.to_string() +} + +/// Graph-level settings. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct GraphSettings { + #[serde(default = "default_max_loop_iterations")] + pub max_loop_iterations: usize, + + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, + + #[serde(default = "default_true")] + pub log_state_snapshots: bool, + + #[serde(default = "default_true")] + pub validate_before_run: bool, +} + +impl Default for GraphSettings { + fn default() -> Self { + Self { + max_loop_iterations: default_max_loop_iterations(), + timeout: None, + log_state_snapshots: true, + validate_before_run: true, + } + } +} + +fn default_max_loop_iterations() -> usize { + super::DEFAULT_MAX_LOOP_ITERATIONS +} + +fn default_true() -> bool { + true +} + +/// A node in the graph. `node_type` is flattened into the YAML, so a node's +/// variant-specific fields live alongside `id`, `description`. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Node { + /// Unique node identifier. May be omitted in YAML; the parser fills it + /// in from the surrounding `nodes:` map key. + #[serde(default)] + pub id: String, + + #[serde(default)] + pub description: String, + + #[serde(flatten)] + pub node_type: NodeType, + + /// Static next-node routing. Used by agent/input nodes. + /// Approval nodes use their `routes` map instead. + /// Script nodes: this is populated by `_next` in JSON output. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub next: Option, +} + +/// The supported node variants. YAML uses an internal `type` tag in lowercase +/// (e.g. `type: agent`). +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum NodeType { + Agent(AgentNode), + Script(ScriptNode), + Approval(ApprovalNode), + Input(InputNode), + End(EndNode), +} + +/// `agent`-type node: spawn an agent with a templated prompt. Agent tools +/// come from the agent's own `config.yaml`; create agent variants for +/// different tool sets rather than overriding here. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AgentNode { + pub agent: String, + + pub prompt: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_updates: Option>, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +/// `script`-type node: run a Python/TypeScript/Bash script that prints a +/// JSON object on stdout. Keys merge into state; the special `_next` key +/// overrides routing and is not merged. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ScriptNode { + pub script: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_updates: Option>, + + /// Fallback node to route to if the script fails to run or returns empty + #[serde(default, skip_serializing_if = "Option::is_none")] + pub fallback: Option, + + #[serde(default = "default_script_timeout")] + pub timeout: u64, +} + +fn default_script_timeout() -> u64 { + 30 +} + +/// `approval`-type node: prompt the user with `options` and route based on +/// their choice via the `routes` map. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ApprovalNode { + pub question: String, + + pub options: Vec, + + pub routes: HashMap, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_updates: Option>, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub on_timeout: Option, +} + +/// `input`-type node: collect free-form text from the user. Routes via the +/// top-level `next` field; the user's text is exposed to templates as +/// `{{input}}` in `state_updates`. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct InputNode { + pub question: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub default: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub validation: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_updates: Option>, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub on_timeout: Option, +} + +/// `end`-type node: terminate execution; `output` (templated) is returned +/// as the graph's final result. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct EndNode { + #[serde(default)] + pub output: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_updates: Option>, +} + +/// Runtime state for a graph execution: KV store plus visit history. +#[derive(Debug, Clone, Default)] +pub struct GraphState { + data: HashMap, + history: Vec, + loop_counts: HashMap, +} + +impl GraphState { + pub fn new(initial: HashMap) -> Self { + Self { + data: initial, + history: Vec::new(), + loop_counts: HashMap::new(), + } + } + + pub fn get(&self, key: &str) -> Option<&Value> { + self.data.get(key) + } + + pub fn set(&mut self, key: String, value: Value) { + self.data.insert(key, value); + } + + /// Merge a JSON object into state. Existing keys are overwritten. + pub fn merge(&mut self, json_obj: &serde_json::Map) { + for (key, value) in json_obj { + self.data.insert(key.clone(), value.clone()); + } + } + + pub fn data(&self) -> &HashMap { + &self.data + } + + /// Record that a node has been entered. Updates both history and loop + /// counts. + pub fn visit_node(&mut self, node_id: &str) { + self.history.push(node_id.to_string()); + *self.loop_counts.entry(node_id.to_string()).or_insert(0) += 1; + } + + pub fn loop_count(&self, node_id: &str) -> usize { + self.loop_counts.get(node_id).copied().unwrap_or(0) + } + + pub fn history(&self) -> &[String] { + &self.history + } + + pub fn current_node(&self) -> Option<&str> { + self.history.last().map(|s| s.as_str()) + } + + pub fn to_json(&self) -> Result { + serde_json::to_string(&self.data) + .map_err(|e| anyhow::anyhow!("Failed to serialize graph state: {}", e)) + } + + pub fn size_bytes(&self) -> usize { + self.to_json().map(|s| s.len()).unwrap_or(0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn deserializes_a_simple_graph() { + let yaml = r#" +name: test_graph +description: A test graph +version: "1.0" +start: node1 +nodes: + node1: + id: node1 + type: agent + agent: test_agent + prompt: "Hello {{name}}" + state_updates: + result: "{{output}}" + next: node2 + node2: + id: node2 + type: end + output: "{{result}}" +"#; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(graph.name, "test_graph"); + assert_eq!(graph.start, "node1"); + assert_eq!(graph.nodes.len(), 2); + assert!(graph.has_node("node1")); + assert!(graph.has_node("node2")); + assert!(!graph.has_node("missing")); + + let node1 = graph.get_node("node1").unwrap(); + assert!(matches!(node1.node_type, NodeType::Agent(_))); + + let node2 = graph.get_node("node2").unwrap(); + match &node2.node_type { + NodeType::End(end) => assert_eq!(end.output, "{{result}}"), + _ => panic!("expected End variant"), + } + } + + #[test] + fn deserializes_every_node_type() { + let yaml = r#" +name: all_types +start: a +nodes: + a: + id: a + type: agent + agent: helper + prompt: hi + next: s + s: + id: s + type: script + script: scripts/decide.py + next: ap + ap: + id: ap + type: approval + question: ok? + options: [yes, no] + routes: + yes: i + no: e + i: + id: i + type: input + question: name? + state_updates: + name: "{{input}}" + next: e + e: + id: e + type: end + output: done +"#; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + assert!(matches!( + graph.get_node("a").unwrap().node_type, + NodeType::Agent(_) + )); + assert!(matches!( + graph.get_node("s").unwrap().node_type, + NodeType::Script(_) + )); + assert!(matches!( + graph.get_node("ap").unwrap().node_type, + NodeType::Approval(_) + )); + assert!(matches!( + graph.get_node("i").unwrap().node_type, + NodeType::Input(_) + )); + assert!(matches!( + graph.get_node("e").unwrap().node_type, + NodeType::End(_) + )); + } + + #[test] + fn graph_settings_have_sensible_defaults() { + let yaml = "name: g\nstart: x\nnodes:\n x:\n id: x\n type: end\n output: ok\n"; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(graph.version, super::super::GRAPH_SCHEMA_VERSION); + assert_eq!( + graph.settings.max_loop_iterations, + super::super::DEFAULT_MAX_LOOP_ITERATIONS + ); + assert!(graph.settings.log_state_snapshots); + assert!(graph.settings.validate_before_run); + assert!(graph.settings.timeout.is_none()); + assert!(graph.initial_state.is_empty()); + assert_eq!(graph.description, ""); + } + + #[test] + fn input_node_with_all_fields() { + let yaml = r#" +id: get_key +type: input +question: "Enter your API key:" +default: "{{previous_api_key}}" +validation: "len(input) > 0" +state_updates: + api_key: "{{input}}" +next: configure +timeout: 300 +on_timeout: skip +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + let input = match node.node_type { + NodeType::Input(i) => i, + _ => panic!("expected Input variant"), + }; + assert_eq!(input.question, "Enter your API key:"); + assert_eq!(input.default.as_deref(), Some("{{previous_api_key}}")); + assert_eq!(input.validation.as_deref(), Some("len(input) > 0")); + assert_eq!(input.timeout, Some(300)); + assert_eq!(input.on_timeout.as_deref(), Some("skip")); + let updates = input.state_updates.unwrap(); + assert_eq!( + updates.get("api_key").map(|s| s.as_str()), + Some("{{input}}") + ); + assert_eq!(node.next.as_deref(), Some("configure")); + } + + #[test] + fn input_node_with_minimal_fields() { + let yaml = r#" +id: ask +type: input +question: "Describe the feature:" +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + let input = match node.node_type { + NodeType::Input(i) => i, + _ => panic!("expected Input variant"), + }; + assert_eq!(input.question, "Describe the feature:"); + assert!(input.default.is_none()); + assert!(input.validation.is_none()); + assert!(input.state_updates.is_none()); + assert!(input.timeout.is_none()); + assert!(input.on_timeout.is_none()); + assert!(node.next.is_none()); + } + + #[test] + fn script_node_defaults_timeout_to_30() { + let yaml = r#" +id: s +type: script +script: scripts/decide.py +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + let script = match node.node_type { + NodeType::Script(s) => s, + _ => panic!("expected Script variant"), + }; + assert_eq!(script.timeout, 30); + assert!(script.fallback.is_none()); + assert!(script.state_updates.is_none()); + } + + #[test] + fn approval_node_carries_routes() { + let yaml = r#" +id: approve +type: approval +question: "Approve {{filename}}?" +options: [approve, reject, edit] +routes: + approve: apply + reject: end_reject + edit: edit_loop +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + let approval = match node.node_type { + NodeType::Approval(a) => a, + _ => panic!("expected Approval variant"), + }; + assert_eq!(approval.options.len(), 3); + assert_eq!( + approval.routes.get("approve").map(|s| s.as_str()), + Some("apply") + ); + assert_eq!( + approval.routes.get("reject").map(|s| s.as_str()), + Some("end_reject") + ); + } + + #[test] + fn graph_state_basic_operations() { + let mut state = GraphState::new(HashMap::new()); + state.set("key1".to_string(), json!("value1")); + assert_eq!(state.get("key1"), Some(&json!("value1"))); + + state.visit_node("node1"); + state.visit_node("node2"); + state.visit_node("node1"); + + assert_eq!(state.loop_count("node1"), 2); + assert_eq!(state.loop_count("node2"), 1); + assert_eq!(state.loop_count("never"), 0); + assert_eq!(state.history().len(), 3); + assert_eq!(state.current_node(), Some("node1")); + } + + #[test] + fn graph_state_merge_overwrites_existing_keys() { + let mut state = GraphState::new(HashMap::new()); + state.set("existing".to_string(), json!("value")); + state.set("kept".to_string(), json!("untouched")); + + let mut obj = serde_json::Map::new(); + obj.insert("new_key".to_string(), json!("new_value")); + obj.insert("count".to_string(), json!(42)); + obj.insert("existing".to_string(), json!("replaced")); + + state.merge(&obj); + + assert_eq!(state.get("existing"), Some(&json!("replaced"))); + assert_eq!(state.get("kept"), Some(&json!("untouched"))); + assert_eq!(state.get("new_key"), Some(&json!("new_value"))); + assert_eq!(state.get("count"), Some(&json!(42))); + } + + #[test] + fn graph_state_serializes_to_json() { + let mut initial = HashMap::new(); + initial.insert("k".to_string(), json!("v")); + let state = GraphState::new(initial); + let serialized = state.to_json().unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap(); + assert_eq!(parsed.get("k"), Some(&json!("v"))); + assert!(state.size_bytes() > 0); + } + + #[test] + fn graph_state_initial_values_are_seeded() { + let mut initial = HashMap::new(); + initial.insert("user".to_string(), json!("alice")); + let state = GraphState::new(initial); + assert_eq!(state.get("user"), Some(&json!("alice"))); + assert!(state.history().is_empty()); + } +} diff --git a/src/main.rs b/src/main.rs index b453566..0e09e23 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod cli; mod client; mod config; mod function; +mod graph; mod rag; mod render; mod repl;