feat: Implemented a built-in task management system to help smaller LLMs complete larger multistep tasks and minimize context drift
This commit is contained in:
+108
-1
@@ -1,3 +1,4 @@
|
||||
use super::todo::TodoList;
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
@@ -14,6 +15,18 @@ use serde::{Deserialize, Serialize};
|
||||
use std::{ffi::OsStr, path::Path};
|
||||
|
||||
const DEFAULT_AGENT_NAME: &str = "rag";
|
||||
/// Task-tracking guidance appended to an agent's instructions by
/// `interpolated_instructions` when both `auto_continue` and
/// `inject_todo_instructions` are enabled in the agent config.
///
/// The text begins with a newline so it starts on its own line after the
/// instructions it is appended to.
const DEFAULT_TODO_INSTRUCTIONS: &str = concat!(
    "\n## Task Tracking\n",
    "You have built-in task tracking tools. Use them to track your progress:\n",
    "- `todo__init`: Initialize a todo list with a goal. Call this at the start of every multi-step task.\n",
    "- `todo__add`: Add individual tasks. Add all planned steps before starting work.\n",
    "- `todo__done`: Mark a task done by id. Call this immediately after completing each step.\n",
    "- `todo__list`: Show the current todo list.\n",
    "\n",
    "RULES:\n",
    "- Always create a todo list before starting work.\n",
    "- Mark each task done as soon as you finish it — do not batch.\n",
    "- If you stop with incomplete tasks, the system will automatically prompt you to continue.",
);
|
||||
|
||||
pub type AgentVariables = IndexMap<String, String>;
|
||||
|
||||
@@ -33,6 +46,9 @@ pub struct Agent {
|
||||
rag: Option<Arc<Rag>>,
|
||||
model: Model,
|
||||
vault: GlobalVault,
|
||||
todo_list: TodoList,
|
||||
continuation_count: usize,
|
||||
last_continuation_response: Option<String>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
@@ -188,6 +204,10 @@ impl Agent {
|
||||
None
|
||||
};
|
||||
|
||||
if agent_config.auto_continue {
|
||||
functions.append_todo_functions();
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
name: name.to_string(),
|
||||
config: agent_config,
|
||||
@@ -199,6 +219,9 @@ impl Agent {
|
||||
rag,
|
||||
model,
|
||||
vault: Arc::clone(&config.read().vault),
|
||||
todo_list: TodoList::default(),
|
||||
continuation_count: 0,
|
||||
last_continuation_response: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -309,11 +332,16 @@ impl Agent {
|
||||
}
|
||||
|
||||
pub fn interpolated_instructions(&self) -> String {
|
||||
let output = self
|
||||
let mut output = self
|
||||
.session_dynamic_instructions
|
||||
.clone()
|
||||
.or_else(|| self.shared_dynamic_instructions.clone())
|
||||
.unwrap_or_else(|| self.config.instructions.clone());
|
||||
|
||||
if self.config.auto_continue && self.config.inject_todo_instructions {
|
||||
output.push_str(DEFAULT_TODO_INSTRUCTIONS);
|
||||
}
|
||||
|
||||
self.interpolate_text(&output)
|
||||
}
|
||||
|
||||
@@ -376,6 +404,67 @@ impl Agent {
|
||||
self.session_dynamic_instructions = None;
|
||||
}
|
||||
|
||||
pub fn auto_continue_enabled(&self) -> bool {
|
||||
self.config.auto_continue
|
||||
}
|
||||
|
||||
pub fn max_auto_continues(&self) -> usize {
|
||||
self.config.max_auto_continues
|
||||
}
|
||||
|
||||
pub fn continuation_count(&self) -> usize {
|
||||
self.continuation_count
|
||||
}
|
||||
|
||||
pub fn increment_continuation(&mut self) {
|
||||
self.continuation_count += 1;
|
||||
}
|
||||
|
||||
pub fn reset_continuation(&mut self) {
|
||||
self.continuation_count = 0;
|
||||
self.last_continuation_response = None;
|
||||
}
|
||||
|
||||
pub fn is_stale_response(&self, response: &str) -> bool {
|
||||
self.last_continuation_response
|
||||
.as_ref()
|
||||
.is_some_and(|last| last == response)
|
||||
}
|
||||
|
||||
pub fn set_last_continuation_response(&mut self, response: String) {
|
||||
self.last_continuation_response = Some(response);
|
||||
}
|
||||
|
||||
/// Borrows this agent's todo list.
pub fn todo_list(&self) -> &TodoList {
    &self.todo_list
}
|
||||
|
||||
pub fn init_todo_list(&mut self, goal: &str) {
|
||||
self.todo_list = TodoList::new(goal);
|
||||
}
|
||||
|
||||
pub fn add_todo(&mut self, task: &str) -> usize {
|
||||
self.todo_list.add(task)
|
||||
}
|
||||
|
||||
pub fn mark_todo_done(&mut self, id: usize) -> bool {
|
||||
self.todo_list.mark_done(id)
|
||||
}
|
||||
|
||||
pub fn continuation_prompt(&self) -> String {
|
||||
self.config.continuation_prompt.clone().unwrap_or_else(|| {
|
||||
"[SYSTEM REMINDER - TODO CONTINUATION]\n\
|
||||
You have incomplete tasks in your todo list. \
|
||||
Continue with the next pending item. \
|
||||
Call tools immediately. Do not explain what you will do."
|
||||
.to_string()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn compression_threshold(&self) -> Option<usize> {
|
||||
self.config.compression_threshold
|
||||
}
|
||||
|
||||
pub fn is_dynamic_instructions(&self) -> bool {
|
||||
self.config.dynamic_instructions
|
||||
}
|
||||
@@ -498,6 +587,14 @@ pub struct AgentConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub agent_session: Option<String>,
|
||||
#[serde(default)]
|
||||
pub auto_continue: bool,
|
||||
#[serde(default = "default_max_auto_continues")]
|
||||
pub max_auto_continues: usize,
|
||||
#[serde(default = "default_true")]
|
||||
pub inject_todo_instructions: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub compression_threshold: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
#[serde(default)]
|
||||
pub version: String,
|
||||
@@ -505,6 +602,8 @@ pub struct AgentConfig {
|
||||
pub mcp_servers: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub global_tools: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub continuation_prompt: Option<String>,
|
||||
#[serde(default)]
|
||||
pub instructions: String,
|
||||
#[serde(default)]
|
||||
@@ -517,6 +616,14 @@ pub struct AgentConfig {
|
||||
pub documents: Vec<String>,
|
||||
}
|
||||
|
||||
/// Serde default for `AgentConfig::max_auto_continues`, used when the field
/// is absent from the deserialized config.
fn default_max_auto_continues() -> usize {
    const DEFAULT_LIMIT: usize = 10;
    DEFAULT_LIMIT
}
|
||||
|
||||
/// Serde default helper that yields `true`; referenced by
/// `AgentConfig::inject_todo_instructions` so the flag is on unless the
/// config sets it explicitly.
fn default_true() -> bool {
    true
}
|
||||
|
||||
impl AgentConfig {
|
||||
pub fn load(path: &Path) -> Result<Self> {
|
||||
let contents = read_to_string(path)
|
||||
|
||||
+12
-1
@@ -3,6 +3,7 @@ mod input;
|
||||
mod macros;
|
||||
mod role;
|
||||
mod session;
|
||||
pub(crate) mod todo;
|
||||
|
||||
pub use self::agent::{Agent, AgentVariables, complete_agent_variables, list_agents};
|
||||
pub use self::input::Input;
|
||||
@@ -1573,8 +1574,18 @@ impl Config {
|
||||
.summary_context_prompt
|
||||
.clone()
|
||||
.unwrap_or_else(|| SUMMARY_CONTEXT_PROMPT.into());
|
||||
|
||||
let todo_prefix = config
|
||||
.read()
|
||||
.agent
|
||||
.as_ref()
|
||||
.map(|agent| agent.todo_list())
|
||||
.filter(|todos| !todos.is_empty())
|
||||
.map(|todos| format!("[ACTIVE TODO LIST]\n{}\n\n", todos.render_for_model()))
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(session) = config.write().session.as_mut() {
|
||||
session.compress(format!("{summary_context_prompt}{summary}"));
|
||||
session.compress(format!("{todo_prefix}{summary_context_prompt}{summary}"));
|
||||
}
|
||||
config.write().discontinuous_last_message();
|
||||
Ok(())
|
||||
|
||||
@@ -299,6 +299,9 @@ impl Session {
|
||||
self.role_prompt = agent.interpolated_instructions();
|
||||
self.agent_variables = agent.variables().clone();
|
||||
self.agent_instructions = self.role_prompt.clone();
|
||||
if let Some(threshold) = agent.compression_threshold() {
|
||||
self.set_compression_threshold(Some(threshold));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn agent_variables(&self) -> &AgentVariables {
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TodoStatus {
|
||||
Pending,
|
||||
Done,
|
||||
}
|
||||
|
||||
impl TodoStatus {
|
||||
fn icon(&self) -> &'static str {
|
||||
match self {
|
||||
TodoStatus::Pending => "○",
|
||||
TodoStatus::Done => "✓",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TodoItem {
|
||||
pub id: usize,
|
||||
#[serde(alias = "description")]
|
||||
pub desc: String,
|
||||
pub done: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct TodoList {
|
||||
#[serde(default)]
|
||||
pub goal: String,
|
||||
#[serde(default)]
|
||||
pub todos: Vec<TodoItem>,
|
||||
}
|
||||
|
||||
impl TodoList {
|
||||
pub fn new(goal: &str) -> Self {
|
||||
Self {
|
||||
goal: goal.to_string(),
|
||||
todos: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&mut self, task: &str) -> usize {
|
||||
let id = self.todos.iter().map(|t| t.id).max().unwrap_or(0) + 1;
|
||||
self.todos.push(TodoItem {
|
||||
id,
|
||||
desc: task.to_string(),
|
||||
done: false,
|
||||
});
|
||||
id
|
||||
}
|
||||
|
||||
pub fn mark_done(&mut self, id: usize) -> bool {
|
||||
if let Some(item) = self.todos.iter_mut().find(|t| t.id == id) {
|
||||
item.done = true;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_incomplete(&self) -> bool {
|
||||
self.todos.iter().any(|item| !item.done)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.todos.is_empty()
|
||||
}
|
||||
|
||||
pub fn render_for_model(&self) -> String {
|
||||
let mut lines = Vec::new();
|
||||
if !self.goal.is_empty() {
|
||||
lines.push(format!("Goal: {}", self.goal));
|
||||
}
|
||||
lines.push(format!(
|
||||
"Progress: {}/{} completed",
|
||||
self.completed_count(),
|
||||
self.todos.len()
|
||||
));
|
||||
for item in &self.todos {
|
||||
let status = if item.done {
|
||||
TodoStatus::Done
|
||||
} else {
|
||||
TodoStatus::Pending
|
||||
};
|
||||
lines.push(format!(" {} {}. {}", status.icon(), item.id, item.desc));
|
||||
}
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
pub fn incomplete_count(&self) -> usize {
|
||||
self.todos.iter().filter(|item| !item.done).count()
|
||||
}
|
||||
|
||||
pub fn completed_count(&self) -> usize {
|
||||
self.todos.iter().filter(|item| item.done).count()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a list for `goal` and adds `tasks` in order.
    fn list_with(goal: &str, tasks: &[&str]) -> TodoList {
        let mut todos = TodoList::new(goal);
        for task in tasks {
            todos.add(task);
        }
        todos
    }

    /// Ids are handed out sequentially starting at 1.
    #[test]
    fn test_new_and_add() {
        let mut todos = TodoList::new("Map Labs");
        let first = todos.add("Discover");
        let second = todos.add("Map columns");

        assert_eq!(first, 1);
        assert_eq!(second, 2);
        assert_eq!(todos.todos.len(), 2);
        assert!(todos.has_incomplete());
    }

    /// Marking a known id succeeds, an unknown id fails, and counts follow.
    #[test]
    fn test_mark_done() {
        let mut todos = list_with("Test", &["Task 1", "Task 2"]);

        assert!(todos.mark_done(1));
        assert!(!todos.mark_done(99));
        assert_eq!(todos.completed_count(), 1);
        assert_eq!(todos.incomplete_count(), 1);
    }

    /// A default list is empty and has nothing left to do.
    #[test]
    fn test_empty_list() {
        let todos = TodoList::default();

        assert!(!todos.has_incomplete());
        assert!(todos.is_empty());
    }

    /// Once every item is done, nothing is reported as incomplete.
    #[test]
    fn test_all_done() {
        let mut todos = list_with("Test", &["Done task"]);
        todos.mark_done(1);

        assert!(!todos.has_incomplete());
    }

    /// The rendered text carries the goal, the progress line and each item.
    #[test]
    fn test_render_for_model() {
        let mut todos = list_with("Map Labs", &["Discover", "Map"]);
        todos.mark_done(1);

        let text = todos.render_for_model();
        for expected in [
            "Goal: Map Labs",
            "Progress: 1/2 completed",
            "✓ 1. Discover",
            "○ 2. Map",
        ] {
            assert!(text.contains(expected));
        }
    }

    /// Goal, item count and done flags survive a JSON round trip.
    #[test]
    fn test_serialization_roundtrip() {
        let mut todos = list_with("Roundtrip", &["Step 1", "Step 2"]);
        todos.mark_done(1);

        let encoded = serde_json::to_string(&todos).unwrap();
        let decoded: TodoList = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.goal, "Roundtrip");
        assert_eq!(decoded.todos.len(), 2);
        assert!(decoded.todos[0].done);
        assert!(!decoded.todos[1].done);
    }
}
|
||||
Reference in New Issue
Block a user