Files
coyote/src/supervisor/mod.rs
T

255 lines
6.9 KiB
Rust

pub mod escalation;
pub mod mailbox;
pub mod taskqueue;
use crate::utils::AbortSignal;
use fmt::{Debug, Formatter};
use mailbox::Inbox;
use taskqueue::TaskQueue;
use anyhow::{Result, bail};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use tokio::task::JoinHandle;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentExitStatus {
Completed,
Failed(String),
}
pub struct AgentResult {
pub id: String,
pub agent_name: String,
pub output: String,
pub exit_status: AgentExitStatus,
}
pub struct AgentHandle {
pub id: String,
pub agent_name: String,
pub depth: usize,
pub inbox: Arc<Inbox>,
pub abort_signal: AbortSignal,
pub join_handle: JoinHandle<Result<AgentResult>>,
}
pub struct Supervisor {
handles: HashMap<String, AgentHandle>,
task_queue: TaskQueue,
max_concurrent: usize,
max_depth: usize,
}
impl Supervisor {
pub fn new(max_concurrent: usize, max_depth: usize) -> Self {
Self {
handles: HashMap::new(),
task_queue: TaskQueue::new(),
max_concurrent,
max_depth,
}
}
pub fn active_count(&self) -> usize {
self.handles.len()
}
pub fn max_concurrent(&self) -> usize {
self.max_concurrent
}
pub fn max_depth(&self) -> usize {
self.max_depth
}
pub fn task_queue(&self) -> &TaskQueue {
&self.task_queue
}
pub fn task_queue_mut(&mut self) -> &mut TaskQueue {
&mut self.task_queue
}
pub fn register(&mut self, handle: AgentHandle) -> Result<()> {
if self.handles.len() >= self.max_concurrent {
bail!(
"Cannot spawn agent: at capacity ({}/{})",
self.handles.len(),
self.max_concurrent
);
}
if handle.depth > self.max_depth {
bail!(
"Cannot spawn agent: max depth exceeded ({}/{})",
handle.depth,
self.max_depth
);
}
self.handles.insert(handle.id.clone(), handle);
Ok(())
}
pub fn is_finished(&self, id: &str) -> Option<bool> {
self.handles.get(id).map(|h| h.join_handle.is_finished())
}
pub fn take(&mut self, id: &str) -> Option<AgentHandle> {
self.handles.remove(id)
}
pub fn inbox(&self, id: &str) -> Option<&Arc<Inbox>> {
self.handles.get(id).map(|h| &h.inbox)
}
pub fn list_agents(&self) -> Vec<(&str, &str)> {
self.handles
.values()
.map(|h| (h.id.as_str(), h.agent_name.as_str()))
.collect()
}
pub fn cancel_all(&self) {
for handle in self.handles.values() {
handle.abort_signal.set_ctrlc();
}
}
}
impl Debug for Supervisor {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Supervisor")
.field("active_agents", &self.handles.len())
.field("max_concurrent", &self.max_concurrent)
.field("max_depth", &self.max_depth)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::utils::create_abort_signal;
fn make_handle(id: &str, agent_name: &str, depth: usize) -> AgentHandle {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let join_handle = rt.spawn(async {
Ok(AgentResult {
id: "done".into(),
agent_name: "test".into(),
output: "result".into(),
exit_status: AgentExitStatus::Completed,
})
});
AgentHandle {
id: id.to_string(),
agent_name: agent_name.to_string(),
depth,
inbox: Arc::new(Inbox::new()),
abort_signal: create_abort_signal(),
join_handle,
}
}
#[test]
fn supervisor_new_empty() {
let sup = Supervisor::new(4, 3);
assert_eq!(sup.active_count(), 0);
assert_eq!(sup.max_concurrent(), 4);
assert_eq!(sup.max_depth(), 3);
}
#[test]
fn supervisor_register_increments_count() {
let mut sup = Supervisor::new(4, 3);
sup.register(make_handle("a1", "explore", 1)).unwrap();
assert_eq!(sup.active_count(), 1);
}
#[test]
fn supervisor_register_rejects_at_capacity() {
let mut sup = Supervisor::new(1, 3);
sup.register(make_handle("a1", "explore", 1)).unwrap();
let result = sup.register(make_handle("a2", "coder", 1));
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("at capacity"));
}
#[test]
fn supervisor_register_rejects_exceeding_depth() {
let mut sup = Supervisor::new(4, 2);
let result = sup.register(make_handle("a1", "explore", 3));
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("max depth"));
}
#[test]
fn supervisor_register_allows_at_max_depth() {
let mut sup = Supervisor::new(4, 2);
sup.register(make_handle("a1", "explore", 2)).unwrap();
assert_eq!(sup.active_count(), 1);
}
#[test]
fn supervisor_take_removes_handle() {
let mut sup = Supervisor::new(4, 3);
sup.register(make_handle("a1", "explore", 1)).unwrap();
let taken = sup.take("a1");
assert!(taken.is_some());
assert_eq!(sup.active_count(), 0);
}
#[test]
fn supervisor_take_nonexistent_returns_none() {
let mut sup = Supervisor::new(4, 3);
assert!(sup.take("missing").is_none());
}
#[test]
fn supervisor_list_agents() {
let mut sup = Supervisor::new(4, 3);
sup.register(make_handle("a1", "explore", 1)).unwrap();
sup.register(make_handle("a2", "coder", 1)).unwrap();
let list = sup.list_agents();
assert_eq!(list.len(), 2);
let ids: Vec<&str> = list.iter().map(|(id, _)| *id).collect();
assert!(ids.contains(&"a1"));
assert!(ids.contains(&"a2"));
}
#[test]
fn supervisor_inbox_returns_handle_inbox() {
let mut sup = Supervisor::new(4, 3);
sup.register(make_handle("a1", "explore", 1)).unwrap();
assert!(sup.inbox("a1").is_some());
assert!(sup.inbox("missing").is_none());
}
#[test]
fn supervisor_task_queue_accessible() {
let mut sup = Supervisor::new(4, 3);
let id = sup
.task_queue_mut()
.create("task".into(), "desc".into(), None, None);
assert!(!id.is_empty());
assert_eq!(sup.task_queue().list().len(), 1);
}
#[test]
fn agent_exit_status_equality() {
assert_eq!(AgentExitStatus::Completed, AgentExitStatus::Completed);
assert_ne!(
AgentExitStatus::Completed,
AgentExitStatus::Failed("err".into())
);
assert_eq!(
AgentExitStatus::Failed("x".into()),
AgentExitStatus::Failed("x".into())
);
}
}