style: Cleaned up all graph agent code
This commit is contained in:
+197
-175
@@ -1,5 +1,3 @@
|
||||
//! YAML parsing for graph definitions.
|
||||
|
||||
use super::types::Graph;
|
||||
use crate::config::paths;
|
||||
use anyhow::{Context, Error, Result, anyhow, bail};
|
||||
@@ -8,9 +6,6 @@ use std::path::{Path, PathBuf};
|
||||
|
||||
const SUPPORTED_VERSIONS: &[&str] = &["1.0"];
|
||||
|
||||
/// Parser for graph YAML files. The `base_dir` is used to resolve relative
|
||||
/// paths passed to [`GraphParser::load_from_file`], and is typically an
|
||||
/// agent directory.
|
||||
pub struct GraphParser {
|
||||
base_dir: PathBuf,
|
||||
}
|
||||
@@ -22,8 +17,6 @@ impl GraphParser {
|
||||
}
|
||||
}
|
||||
|
||||
/// Load and validate a graph from a YAML file. Relative paths are
|
||||
/// resolved against `base_dir`.
|
||||
pub fn load_from_file(&self, path: impl AsRef<Path>) -> Result<Graph> {
|
||||
let path = path.as_ref();
|
||||
let full_path = if path.is_absolute() {
|
||||
@@ -39,7 +32,6 @@ impl GraphParser {
|
||||
.with_context(|| format!("Failed to parse graph file at '{}'", full_path.display()))
|
||||
}
|
||||
|
||||
/// Load and validate a graph from a YAML string.
|
||||
pub fn load_from_string(&self, yaml: &str) -> Result<Graph> {
|
||||
let mut graph: Graph = serde_yaml::from_str(yaml).map_err(enhance_yaml_error)?;
|
||||
|
||||
@@ -71,6 +63,7 @@ fn validate_schema_version(version: &str) -> Result<()> {
|
||||
SUPPORTED_VERSIONS.join(", ")
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -125,16 +118,19 @@ fn enhance_yaml_error(error: serde_yaml::Error) -> Error {
|
||||
anyhow!("YAML parsing error: {}{}", msg, hint)
|
||||
}
|
||||
|
||||
/// Returns true if the named agent has a `graph.yaml` in its data directory.
|
||||
pub fn agent_has_graph(agent_name: &str) -> bool {
|
||||
paths::agent_graph_file(agent_name).exists()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::GRAPH_SCHEMA_VERSION;
|
||||
use super::super::types::NodeType;
|
||||
use super::*;
|
||||
use std::env;
|
||||
use indoc::formatdoc;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::{env, fs, process};
|
||||
|
||||
fn parser() -> GraphParser {
|
||||
GraphParser::new(env::current_dir().unwrap())
|
||||
@@ -142,23 +138,25 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parses_a_simple_graph() {
|
||||
let yaml = r#"
|
||||
name: simple_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
id: node1
|
||||
type: agent
|
||||
agent: test_agent
|
||||
prompt: "Hello world"
|
||||
next: node2
|
||||
node2:
|
||||
id: node2
|
||||
type: end
|
||||
output: done
|
||||
"#;
|
||||
let graph = parser().load_from_string(yaml).unwrap();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: simple_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
id: node1
|
||||
type: agent
|
||||
agent: test_agent
|
||||
prompt: "Hello world"
|
||||
next: node2
|
||||
node2:
|
||||
id: node2
|
||||
type: end
|
||||
output: done
|
||||
"#};
|
||||
|
||||
let graph = parser().load_from_string(&yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.name, "simple_graph");
|
||||
assert_eq!(graph.start, "node1");
|
||||
assert_eq!(graph.nodes.len(), 2);
|
||||
@@ -170,36 +168,40 @@ nodes:
|
||||
|
||||
#[test]
|
||||
fn auto_fills_node_ids_from_keys() {
|
||||
let yaml = r#"
|
||||
name: auto_id_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: agent
|
||||
agent: test_agent
|
||||
prompt: Test
|
||||
next: node2
|
||||
node2:
|
||||
type: end
|
||||
output: done
|
||||
"#;
|
||||
let graph = parser().load_from_string(yaml).unwrap();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: auto_id_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: agent
|
||||
agent: test_agent
|
||||
prompt: Test
|
||||
next: node2
|
||||
node2:
|
||||
type: end
|
||||
output: done
|
||||
"#};
|
||||
|
||||
let graph = parser().load_from_string(&yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.nodes.get("node1").unwrap().id, "node1");
|
||||
assert_eq!(graph.nodes.get("node2").unwrap().id, "node2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_missing_start_node() {
|
||||
let yaml = r#"
|
||||
name: bad_graph
|
||||
version: "1.0"
|
||||
start: nonexistent
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#;
|
||||
let err = parser().load_from_string(yaml).unwrap_err().to_string();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: bad_graph
|
||||
version: "1.0"
|
||||
start: nonexistent
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let err = parser().load_from_string(&yaml).unwrap_err().to_string();
|
||||
|
||||
assert!(
|
||||
err.contains("Start node 'nonexistent' not found"),
|
||||
"got: {err}"
|
||||
@@ -208,41 +210,47 @@ nodes:
|
||||
|
||||
#[test]
|
||||
fn rejects_empty_graph_name() {
|
||||
let yaml = r#"
|
||||
name: ""
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#;
|
||||
let err = parser().load_from_string(yaml).unwrap_err().to_string();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: ""
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let err = parser().load_from_string(&yaml).unwrap_err().to_string();
|
||||
|
||||
assert!(err.contains("non-empty 'name'"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_no_nodes() {
|
||||
let yaml = r#"
|
||||
name: empty_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes: {}
|
||||
"#;
|
||||
let err = parser().load_from_string(yaml).unwrap_err().to_string();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: empty_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes: {}
|
||||
"#, "{}"};
|
||||
|
||||
let err = parser().load_from_string(&yaml).unwrap_err().to_string();
|
||||
|
||||
assert!(err.contains("no nodes defined"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_unsupported_version() {
|
||||
let yaml = r#"
|
||||
name: future_graph
|
||||
version: "2.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#;
|
||||
let err = parser().load_from_string(yaml).unwrap_err().to_string();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: future_graph
|
||||
version: "2.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let err = parser().load_from_string(&yaml).unwrap_err().to_string();
|
||||
|
||||
assert!(
|
||||
err.contains("Unsupported graph schema version"),
|
||||
"got: {err}"
|
||||
@@ -251,42 +259,46 @@ nodes:
|
||||
|
||||
#[test]
|
||||
fn rejects_node_id_mismatch() {
|
||||
let yaml = r#"
|
||||
name: mismatch_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
id: different_id
|
||||
type: end
|
||||
"#;
|
||||
let err = parser().load_from_string(yaml).unwrap_err().to_string();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: mismatch_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
id: different_id
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let err = parser().load_from_string(&yaml).unwrap_err().to_string();
|
||||
|
||||
assert!(err.contains("Node ID mismatch"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_approval_node_with_routes() {
|
||||
let yaml = r#"
|
||||
name: approval_graph
|
||||
version: "1.0"
|
||||
start: approval1
|
||||
nodes:
|
||||
approval1:
|
||||
type: approval
|
||||
question: "Proceed with deployment?"
|
||||
options:
|
||||
- "Yes"
|
||||
- "No"
|
||||
routes:
|
||||
"Yes": deploy
|
||||
"No": cancel
|
||||
on_other: cancel
|
||||
deploy:
|
||||
type: end
|
||||
cancel:
|
||||
type: end
|
||||
"#;
|
||||
let graph = parser().load_from_string(yaml).unwrap();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: approval_graph
|
||||
version: "1.0"
|
||||
start: approval1
|
||||
nodes:
|
||||
approval1:
|
||||
type: approval
|
||||
question: "Proceed with deployment?"
|
||||
options:
|
||||
- "Yes"
|
||||
- "No"
|
||||
routes:
|
||||
"Yes": deploy
|
||||
"No": cancel
|
||||
on_other: cancel
|
||||
deploy:
|
||||
type: end
|
||||
cancel:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let graph = parser().load_from_string(&yaml).unwrap();
|
||||
|
||||
let approval = graph.nodes.get("approval1").unwrap();
|
||||
match &approval.node_type {
|
||||
NodeType::Approval(a) => {
|
||||
@@ -300,19 +312,21 @@ nodes:
|
||||
|
||||
#[test]
|
||||
fn parses_settings_overrides() {
|
||||
let yaml = r#"
|
||||
name: settings_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
settings:
|
||||
max_loop_iterations: 50
|
||||
timeout: 300
|
||||
log_state_snapshots: false
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#;
|
||||
let graph = parser().load_from_string(yaml).unwrap();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: settings_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
settings:
|
||||
max_loop_iterations: 50
|
||||
timeout: 300
|
||||
log_state_snapshots: false
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let graph = parser().load_from_string(&yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.settings.max_loop_iterations, 50);
|
||||
assert_eq!(graph.settings.timeout, Some(300));
|
||||
assert!(!graph.settings.log_state_snapshots);
|
||||
@@ -321,19 +335,21 @@ nodes:
|
||||
|
||||
#[test]
|
||||
fn parses_initial_state() {
|
||||
let yaml = r#"
|
||||
name: state_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
initial_state:
|
||||
user_name: "Alice"
|
||||
count: 42
|
||||
enabled: true
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#;
|
||||
let graph = parser().load_from_string(yaml).unwrap();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: state_graph
|
||||
version: "1.0"
|
||||
start: node1
|
||||
initial_state:
|
||||
user_name: "Alice"
|
||||
count: 42
|
||||
enabled: true
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let graph = parser().load_from_string(&yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.initial_state.len(), 3);
|
||||
assert_eq!(graph.initial_state.get("user_name").unwrap(), "Alice");
|
||||
assert_eq!(
|
||||
@@ -348,28 +364,32 @@ nodes:
|
||||
|
||||
#[test]
|
||||
fn uses_default_version_when_absent() {
|
||||
let yaml = r#"
|
||||
name: no_version
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#;
|
||||
let graph = parser().load_from_string(yaml).unwrap();
|
||||
assert_eq!(graph.version, super::super::GRAPH_SCHEMA_VERSION);
|
||||
let yaml = formatdoc! {r#"
|
||||
name: no_version
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let graph = parser().load_from_string(&yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.version, GRAPH_SCHEMA_VERSION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_unknown_node_type_with_hint() {
|
||||
let yaml = r#"
|
||||
name: bad_type
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: nonsense
|
||||
"#;
|
||||
let err = parser().load_from_string(yaml).unwrap_err().to_string();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: bad_type
|
||||
version: "1.0"
|
||||
start: node1
|
||||
nodes:
|
||||
node1:
|
||||
type: nonsense
|
||||
"#};
|
||||
|
||||
let err = parser().load_from_string(&yaml).unwrap_err().to_string();
|
||||
|
||||
assert!(
|
||||
err.contains("Valid node types") || err.contains("unknown variant"),
|
||||
"got: {err}"
|
||||
@@ -379,49 +399,50 @@ nodes:
|
||||
#[test]
|
||||
fn rejects_malformed_yaml() {
|
||||
let yaml = "name: bad\n bad: indent\nstart: a";
|
||||
|
||||
let result = parser().load_from_string(yaml);
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_required_fields_have_a_hint() {
|
||||
let yaml = r#"
|
||||
name: missing_start
|
||||
version: "1.0"
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#;
|
||||
let err = parser().load_from_string(yaml).unwrap_err().to_string();
|
||||
let yaml = formatdoc! {r#"
|
||||
name: missing_start
|
||||
version: "1.0"
|
||||
nodes:
|
||||
node1:
|
||||
type: end
|
||||
"#};
|
||||
|
||||
let err = parser().load_from_string(&yaml).unwrap_err().to_string();
|
||||
|
||||
assert!(err.contains("Hint"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_from_file_reads_disk() {
|
||||
use std::io::Write;
|
||||
let dir = env::temp_dir();
|
||||
let path = dir.join(format!(
|
||||
"loki_graph_parser_test_{}.yaml",
|
||||
std::process::id()
|
||||
));
|
||||
let yaml = r#"
|
||||
name: disk_graph
|
||||
version: "1.0"
|
||||
start: only
|
||||
nodes:
|
||||
only:
|
||||
type: end
|
||||
output: ok
|
||||
"#;
|
||||
let path = dir.join(format!("loki_graph_parser_test_{}.yaml", process::id()));
|
||||
let yaml = formatdoc! {r#"
|
||||
name: disk_graph
|
||||
version: "1.0"
|
||||
start: only
|
||||
nodes:
|
||||
only:
|
||||
type: end
|
||||
output: ok
|
||||
"#};
|
||||
{
|
||||
let mut f = std::fs::File::create(&path).unwrap();
|
||||
let mut f = File::create(&path).unwrap();
|
||||
f.write_all(yaml.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
let graph = GraphParser::new(dir).load_from_file(&path).unwrap();
|
||||
|
||||
assert_eq!(graph.name, "disk_graph");
|
||||
|
||||
let _ = std::fs::remove_file(&path);
|
||||
let _ = fs::remove_file(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -430,6 +451,7 @@ nodes:
|
||||
.load_from_file("/definitely/not/a/real/path/to_any_graph.yaml")
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("Failed to read graph file"), "got: {err}");
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user