370 lines
11 KiB
Rust
370 lines
11 KiB
Rust
use super::state::StateManager;
|
|
use super::types::{ApprovalNode, InputNode};
|
|
use crate::config::RequestContext;
|
|
use crate::function::user_interaction::{USER_FUNCTION_PREFIX, handle_user_tool};
|
|
use anyhow::{Context, Result, anyhow, bail};
|
|
use serde_json::{Value, json};
|
|
use std::collections::HashMap;
|
|
|
|
const CHOICE_KEY: &str = "choice";
|
|
const INPUT_KEY: &str = "input";
|
|
|
|
pub struct ApprovalNodeExecutor;
|
|
|
|
impl ApprovalNodeExecutor {
|
|
pub async fn execute(
|
|
node: &ApprovalNode,
|
|
state_manager: &mut StateManager,
|
|
ctx: &mut RequestContext,
|
|
) -> Result<String> {
|
|
let question = state_manager
|
|
.interpolate(&node.question)
|
|
.context("Failed to interpolate approval question")?;
|
|
|
|
let response = handle_user_tool(
|
|
ctx,
|
|
&format!("{USER_FUNCTION_PREFIX}ask"),
|
|
&json!({ "question": question, "options": node.options }),
|
|
)
|
|
.await
|
|
.context("user__ask failed")?;
|
|
|
|
if let Some(err) = response.get("error").and_then(Value::as_str) {
|
|
bail!("Approval interaction failed: {err}");
|
|
}
|
|
|
|
let choice = response
|
|
.get("answer")
|
|
.and_then(Value::as_str)
|
|
.context("Approval response missing 'answer' field")?
|
|
.to_string();
|
|
|
|
apply_state_updates_with_var(&node.state_updates, state_manager, CHOICE_KEY, &choice);
|
|
|
|
resolve_approval_route(node, &choice)
|
|
}
|
|
}
|
|
|
|
pub struct InputNodeExecutor;
|
|
|
|
impl InputNodeExecutor {
|
|
pub async fn execute(
|
|
node: &InputNode,
|
|
node_next: Option<&str>,
|
|
state_manager: &mut StateManager,
|
|
ctx: &mut RequestContext,
|
|
) -> Result<String> {
|
|
let question = build_input_question(node, state_manager)?;
|
|
|
|
let response = handle_user_tool(
|
|
ctx,
|
|
&format!("{USER_FUNCTION_PREFIX}input"),
|
|
&json!({ "question": question }),
|
|
)
|
|
.await
|
|
.context("user__input failed")?;
|
|
|
|
if let Some(err) = response.get("error").and_then(Value::as_str) {
|
|
bail!("Input interaction failed: {err}");
|
|
}
|
|
|
|
let raw = response
|
|
.get("answer")
|
|
.and_then(Value::as_str)
|
|
.context("Input response missing 'answer' field")?
|
|
.to_string();
|
|
|
|
let input_text = if raw.is_empty() {
|
|
node.default
|
|
.as_ref()
|
|
.map(|t| state_manager.interpolate_lenient(t))
|
|
.unwrap_or_default()
|
|
} else {
|
|
raw
|
|
};
|
|
|
|
if let Some(expr) = &node.validation
|
|
&& !validate_length(&input_text, expr)?
|
|
{
|
|
bail!(
|
|
"Input failed validation '{}' (got {} chars)",
|
|
expr,
|
|
input_text.chars().count()
|
|
);
|
|
}
|
|
|
|
apply_state_updates_with_var(&node.state_updates, state_manager, INPUT_KEY, &input_text);
|
|
|
|
node_next
|
|
.map(String::from)
|
|
.ok_or_else(|| anyhow!("Input node has no `next` set"))
|
|
}
|
|
}
|
|
|
|
fn build_input_question(node: &InputNode, state_manager: &StateManager) -> Result<String> {
|
|
let mut question = state_manager
|
|
.interpolate(&node.question)
|
|
.context("Failed to interpolate input question")?;
|
|
|
|
if let Some(default_template) = &node.default {
|
|
let default = state_manager.interpolate_lenient(default_template);
|
|
if !default.is_empty() {
|
|
question = format!("{question} [default: {default}]");
|
|
}
|
|
}
|
|
|
|
Ok(question)
|
|
}
|
|
|
|
fn resolve_approval_route(node: &ApprovalNode, choice: &str) -> Result<String> {
|
|
if let Some(target) = node.routes.get(choice) {
|
|
return Ok(target.clone());
|
|
}
|
|
|
|
Ok(node.on_other.clone())
|
|
}
|
|
|
|
fn apply_state_updates_with_var(
|
|
updates: &Option<HashMap<String, String>>,
|
|
state_manager: &mut StateManager,
|
|
var_name: &str,
|
|
var_value: &str,
|
|
) {
|
|
let Some(updates) = updates else {
|
|
return;
|
|
};
|
|
let prev = state_manager.state().get(var_name).cloned();
|
|
state_manager
|
|
.state_mut()
|
|
.set(var_name.into(), Value::String(var_value.to_string()));
|
|
|
|
for (key, template) in updates {
|
|
let value = state_manager.interpolate_lenient(template);
|
|
state_manager
|
|
.state_mut()
|
|
.set(key.clone(), Value::String(value));
|
|
}
|
|
|
|
match prev {
|
|
Some(v) => state_manager.state_mut().set(var_name.into(), v),
|
|
None => {
|
|
state_manager.state_mut().set(var_name.into(), Value::Null);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Evaluate a `len(input) OP N` expression where OP is one of `>`, `>=`,
|
|
/// `<`, `<=`, `==`. Lengths are byte counts (matches Rust's `str::len`).
|
|
/// Other expressions are rejected at runtime.
|
|
fn validate_length(input: &str, expr: &str) -> Result<bool> {
|
|
let trimmed = expr.trim();
|
|
let after_len = trimmed
|
|
.strip_prefix("len(input)")
|
|
.map(str::trim)
|
|
.ok_or_else(|| {
|
|
anyhow!(
|
|
"Unsupported validation expression '{expr}'; only `len(input) OP N` is supported"
|
|
)
|
|
})?;
|
|
|
|
let (op, rhs_str) = if let Some(rest) = after_len.strip_prefix(">=") {
|
|
(">=", rest)
|
|
} else if let Some(rest) = after_len.strip_prefix("<=") {
|
|
("<=", rest)
|
|
} else if let Some(rest) = after_len.strip_prefix("==") {
|
|
("==", rest)
|
|
} else if let Some(rest) = after_len.strip_prefix('>') {
|
|
(">", rest)
|
|
} else if let Some(rest) = after_len.strip_prefix('<') {
|
|
("<", rest)
|
|
} else {
|
|
bail!("No comparison operator in validation expression '{expr}'");
|
|
};
|
|
|
|
let rhs: usize = rhs_str
|
|
.trim()
|
|
.parse()
|
|
.with_context(|| format!("Invalid right-hand side in validation '{expr}'"))?;
|
|
|
|
let len = input.len();
|
|
Ok(match op {
|
|
">=" => len >= rhs,
|
|
"<=" => len <= rhs,
|
|
"==" => len == rhs,
|
|
">" => len > rhs,
|
|
"<" => len < rhs,
|
|
_ => unreachable!(),
|
|
})
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::super::types::*;
|
|
use super::*;
|
|
use serde_json::json;
|
|
|
|
fn manager_with(pairs: &[(&str, Value)]) -> StateManager {
|
|
let mut map = HashMap::new();
|
|
for (k, v) in pairs {
|
|
map.insert((*k).into(), v.clone());
|
|
}
|
|
|
|
StateManager::new(map)
|
|
}
|
|
|
|
fn approval(options: &[&str], routes: &[(&str, &str)], on_other: &str) -> ApprovalNode {
|
|
let mut r = HashMap::new();
|
|
for (k, v) in routes {
|
|
r.insert((*k).into(), (*v).into());
|
|
}
|
|
|
|
ApprovalNode {
|
|
question: "?".into(),
|
|
options: options.iter().map(|s| (*s).into()).collect(),
|
|
routes: r,
|
|
on_other: on_other.into(),
|
|
state_updates: None,
|
|
}
|
|
}
|
|
|
|
fn input(question: &str) -> InputNode {
|
|
InputNode {
|
|
question: question.into(),
|
|
default: None,
|
|
validation: None,
|
|
state_updates: None,
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn validate_length_supports_all_comparison_operators() {
|
|
assert!(validate_length("hello", "len(input) > 0").unwrap());
|
|
assert!(!validate_length("", "len(input) > 0").unwrap());
|
|
assert!(validate_length("hello", "len(input) >= 5").unwrap());
|
|
assert!(!validate_length("hi", "len(input) >= 5").unwrap());
|
|
assert!(validate_length("hello", "len(input) < 10").unwrap());
|
|
assert!(!validate_length("hello world!", "len(input) < 10").unwrap());
|
|
assert!(validate_length("hi", "len(input) <= 2").unwrap());
|
|
assert!(!validate_length("hello", "len(input) <= 2").unwrap());
|
|
assert!(validate_length("hello", "len(input) == 5").unwrap());
|
|
assert!(!validate_length("hello", "len(input) == 3").unwrap());
|
|
}
|
|
|
|
#[test]
|
|
fn validate_length_handles_whitespace() {
|
|
assert!(validate_length("hi", " len(input) >= 1 ").unwrap());
|
|
}
|
|
|
|
#[test]
|
|
fn validate_length_rejects_unsupported_expressions() {
|
|
assert!(validate_length("x", "matches /[a-z]+/").is_err());
|
|
assert!(validate_length("x", "len(input)").is_err());
|
|
assert!(validate_length("x", "len(input) >").is_err());
|
|
assert!(validate_length("x", "len(input) >= abc").is_err());
|
|
}
|
|
|
|
#[test]
|
|
fn approval_route_lookup_returns_target_on_match() {
|
|
let node = approval(
|
|
&["yes", "no"],
|
|
&[("yes", "deploy"), ("no", "cancel")],
|
|
"clarify",
|
|
);
|
|
|
|
assert_eq!(resolve_approval_route(&node, "yes").unwrap(), "deploy");
|
|
assert_eq!(resolve_approval_route(&node, "no").unwrap(), "cancel");
|
|
}
|
|
|
|
#[test]
|
|
fn approval_route_lookup_falls_back_to_on_other_for_unknown_choice() {
|
|
let node = approval(
|
|
&["yes", "no"],
|
|
&[("yes", "deploy"), ("no", "cancel")],
|
|
"clarify",
|
|
);
|
|
|
|
assert_eq!(resolve_approval_route(&node, "maybe").unwrap(), "clarify");
|
|
assert_eq!(
|
|
resolve_approval_route(&node, "free-form text").unwrap(),
|
|
"clarify"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn state_updates_expose_choice_during_evaluation_only() {
|
|
let mut updates = HashMap::new();
|
|
updates.insert("decision".into(), "{{choice}}".into());
|
|
let mut state = manager_with(&[]);
|
|
|
|
apply_state_updates_with_var(&Some(updates), &mut state, CHOICE_KEY, "approve");
|
|
|
|
assert_eq!(state.state().get("decision"), Some(&json!("approve")));
|
|
assert_eq!(state.state().get(CHOICE_KEY), Some(&Value::Null));
|
|
}
|
|
|
|
#[test]
|
|
fn state_updates_preserve_pre_existing_var_value() {
|
|
let mut updates = HashMap::new();
|
|
updates.insert("decision".into(), "{{choice}}".into());
|
|
let mut state = manager_with(&[("choice", json!("preserved"))]);
|
|
|
|
apply_state_updates_with_var(&Some(updates), &mut state, CHOICE_KEY, "approve");
|
|
|
|
assert_eq!(state.state().get("decision"), Some(&json!("approve")));
|
|
assert_eq!(state.state().get(CHOICE_KEY), Some(&json!("preserved")));
|
|
}
|
|
|
|
#[test]
|
|
fn state_updates_for_input_use_input_key() {
|
|
let mut updates = HashMap::new();
|
|
updates.insert("api_key".into(), "{{input}}".into());
|
|
let mut state = manager_with(&[]);
|
|
|
|
apply_state_updates_with_var(&Some(updates), &mut state, INPUT_KEY, "sk-12345");
|
|
|
|
assert_eq!(state.state().get("api_key"), Some(&json!("sk-12345")));
|
|
assert_eq!(state.state().get(INPUT_KEY), Some(&Value::Null));
|
|
}
|
|
|
|
#[test]
|
|
fn input_question_appends_default_when_present() {
|
|
let state = manager_with(&[("name", json!("alice"))]);
|
|
let mut node = input("Hi, what's your name?");
|
|
node.default = Some("{{name}}".into());
|
|
|
|
let q = build_input_question(&node, &state).unwrap();
|
|
|
|
assert_eq!(q, "Hi, what's your name? [default: alice]");
|
|
}
|
|
|
|
#[test]
|
|
fn input_question_omits_default_when_blank_after_interpolation() {
|
|
let state = manager_with(&[]);
|
|
let mut node = input("Enter value:");
|
|
node.default = Some("{{missing}}".into());
|
|
|
|
let q = build_input_question(&node, &state).unwrap();
|
|
|
|
assert_eq!(q, "Enter value:");
|
|
}
|
|
|
|
#[test]
|
|
fn input_question_uses_no_default_when_field_absent() {
|
|
let state = manager_with(&[]);
|
|
let node = input("Enter value:");
|
|
|
|
let q = build_input_question(&node, &state).unwrap();
|
|
|
|
assert_eq!(q, "Enter value:");
|
|
}
|
|
|
|
#[test]
|
|
fn no_state_updates_means_var_never_appears_in_state() {
|
|
let mut state = manager_with(&[]);
|
|
|
|
apply_state_updates_with_var(&None, &mut state, CHOICE_KEY, "approve");
|
|
|
|
assert!(state.state().get(CHOICE_KEY).is_none());
|
|
}
|
|
}
|