Baseline project

This commit is contained in:
2025-10-07 10:45:42 -06:00
parent 88288a98b6
commit 650dbd92e0
54 changed files with 18982 additions and 0 deletions
+570
View File
@@ -0,0 +1,570 @@
use super::*;
use crate::{
client::Model,
function::{run_llm_function, Functions},
};
use anyhow::{Context, Result};
use inquire::{validator::Validation, Text};
use std::{fs::read_to_string, path::Path};
use serde::{Deserialize, Serialize};
const DEFAULT_AGENT_NAME: &str = "rag";
pub type AgentVariables = IndexMap<String, String>;
#[derive(Debug, Clone)]
pub struct Agent {
name: String,
config: AgentConfig,
shared_variables: AgentVariables,
session_variables: Option<AgentVariables>,
shared_dynamic_instructions: Option<String>,
session_dynamic_instructions: Option<String>,
functions: Functions,
rag: Option<Arc<Rag>>,
model: Model,
}
impl Agent {
pub async fn init(
config: &GlobalConfig,
name: &str,
abort_signal: AbortSignal,
) -> Result<Self> {
let agent_data_dir = Config::agent_data_dir(name);
let loaders = config.read().document_loaders.clone();
let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME);
let config_path = Config::agent_config_file(name);
let mut agent_config = if config_path.exists() {
AgentConfig::load(&config_path)?
} else {
bail!("Agent config file not found at '{}'", config_path.display())
};
let mut functions = Functions::init_agent(name, &agent_config.global_tools)?;
config.write().functions.clear_mcp_meta_functions();
let mcp_servers =
(!agent_config.mcp_servers.is_empty()).then(|| agent_config.mcp_servers.join(","));
let registry = config
.write()
.mcp_registry
.take()
.expect("MCP registry should be initialized");
let new_mcp_registry =
McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?;
if !new_mcp_registry.is_empty() {
functions.append_mcp_meta_functions(new_mcp_registry.list_servers());
}
config.write().mcp_registry = Some(new_mcp_registry);
agent_config.replace_tools_placeholder(&functions);
agent_config.load_envs(&config.read());
let model = {
let config = config.read();
match agent_config.model_id.as_ref() {
Some(model_id) => Model::retrieve_model(&config, model_id, ModelType::Chat)?,
None => {
if agent_config.temperature.is_none() {
agent_config.temperature = config.temperature;
}
if agent_config.top_p.is_none() {
agent_config.top_p = config.top_p;
}
config.current_model().clone()
}
}
};
let rag = if rag_path.exists() {
Some(Arc::new(Rag::load(config, DEFAULT_AGENT_NAME, &rag_path)?))
} else if !agent_config.documents.is_empty() && !config.read().info_flag {
let mut ans = false;
if *IS_STDOUT_TERMINAL {
ans = Confirm::new("The agent has documents attached, init RAG?")
.with_default(true)
.prompt()?;
}
if ans {
let mut document_paths = vec![];
for path in &agent_config.documents {
if is_url(path) {
document_paths.push(path.to_string());
} else if is_loader_protocol(&loaders, path) {
let (protocol, document_path) = path
.split_once(':')
.with_context(|| "Invalid loader protocol path")?;
let resolved_path = resolve_home_dir(document_path);
let new_path = if Path::new(&resolved_path).is_relative() {
safe_join_path(&agent_data_dir, resolved_path)
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?
} else {
PathBuf::from(&resolved_path)
};
document_paths.push(format!("{}:{}", protocol, new_path.display()));
} else if Path::new(&resolve_home_dir(path)).is_relative() {
let new_path = safe_join_path(&agent_data_dir, path)
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
document_paths.push(new_path.display().to_string())
} else {
document_paths.push(path.to_string())
}
}
let rag =
Rag::init(config, "rag", &rag_path, &document_paths, abort_signal).await?;
Some(Arc::new(rag))
} else {
None
}
} else {
None
};
Ok(Self {
name: name.to_string(),
config: agent_config,
shared_variables: Default::default(),
session_variables: None,
shared_dynamic_instructions: None,
session_dynamic_instructions: None,
functions,
rag,
model,
})
}
pub fn init_agent_variables(
agent_variables: &[AgentVariable],
no_interaction: bool,
) -> Result<AgentVariables> {
let mut output = IndexMap::new();
if agent_variables.is_empty() {
return Ok(output);
}
let mut printed = false;
let mut unset_variables = vec![];
for agent_variable in agent_variables {
let key = agent_variable.name.clone();
if let Some(value) = agent_variable.default.clone() {
output.insert(key, value);
continue;
}
if no_interaction {
continue;
}
if *IS_STDOUT_TERMINAL {
if !printed {
println!("⚙ Init agent variables...");
printed = true;
}
let value = Text::new(&format!(
"{} ({}):",
agent_variable.name, agent_variable.description
))
.with_validator(|input: &str| {
if input.trim().is_empty() {
Ok(Validation::Invalid("This field is required".into()))
} else {
Ok(Validation::Valid)
}
})
.prompt()?;
output.insert(key, value);
} else {
unset_variables.push(agent_variable)
}
}
if !unset_variables.is_empty() {
bail!(
"The following agent variables are required:\n{}",
unset_variables
.iter()
.map(|v| format!(" - {}: {}", v.name, v.description))
.collect::<Vec<_>>()
.join("\n")
)
}
Ok(output)
}
pub fn export(&self) -> Result<String> {
let mut value = json!({});
value["name"] = json!(self.name());
let variables = self.variables();
if !variables.is_empty() {
value["variables"] = serde_json::to_value(variables)?;
}
value["config"] = json!(self.config);
let mut config = self.config.clone();
config.instructions = self.interpolated_instructions();
value["definition"] = json!(config);
value["data_dir"] = Config::agent_data_dir(&self.name)
.display()
.to_string()
.into();
value["config_file"] = Config::agent_config_file(&self.name)
.display()
.to_string()
.into();
let data = serde_yaml::to_string(&value)?;
Ok(data)
}
pub fn banner(&self) -> String {
self.config.banner()
}
pub fn name(&self) -> &str {
&self.name
}
pub fn functions(&self) -> &Functions {
&self.functions
}
pub fn rag(&self) -> Option<Arc<Rag>> {
self.rag.clone()
}
pub fn conversation_starters(&self) -> &[String] {
&self.config.conversation_starters
}
pub fn interpolated_instructions(&self) -> String {
let mut output = self
.session_dynamic_instructions
.clone()
.or_else(|| self.shared_dynamic_instructions.clone())
.unwrap_or_else(|| self.config.instructions.clone());
for (k, v) in self.variables() {
output = output.replace(&format!("{{{{{k}}}}}"), v)
}
interpolate_variables(&mut output);
output
}
pub fn agent_prelude(&self) -> Option<&str> {
self.config.agent_prelude.as_deref()
}
pub fn variables(&self) -> &AgentVariables {
match &self.session_variables {
Some(variables) => variables,
None => &self.shared_variables,
}
}
pub fn variable_envs(&self) -> HashMap<String, String> {
self.variables()
.iter()
.map(|(k, v)| {
(
format!("LLM_AGENT_VAR_{}", normalize_env_name(k)),
v.clone(),
)
})
.collect()
}
pub fn shared_variables(&self) -> &AgentVariables {
&self.shared_variables
}
pub fn set_shared_variables(&mut self, shared_variables: AgentVariables) {
self.shared_variables = shared_variables;
}
pub fn set_session_variables(&mut self, session_variables: AgentVariables) {
self.session_variables = Some(session_variables);
}
pub fn defined_variables(&self) -> &[AgentVariable] {
&self.config.variables
}
pub fn exit_session(&mut self) {
self.session_variables = None;
self.session_dynamic_instructions = None;
}
pub fn is_dynamic_instructions(&self) -> bool {
self.config.dynamic_instructions
}
pub fn update_shared_dynamic_instructions(&mut self, force: bool) -> Result<()> {
if self.is_dynamic_instructions() && (force || self.shared_dynamic_instructions.is_none()) {
self.shared_dynamic_instructions = Some(self.run_instructions_fn()?);
}
Ok(())
}
pub fn update_session_dynamic_instructions(&mut self, value: Option<String>) -> Result<()> {
if self.is_dynamic_instructions() {
let value = match value {
Some(v) => v,
None => self.run_instructions_fn()?,
};
self.session_dynamic_instructions = Some(value);
}
Ok(())
}
fn run_instructions_fn(&self) -> Result<String> {
let value = run_llm_function(
self.name().to_string(),
vec!["_instructions".into(), "{}".into()],
self.variable_envs(),
)?;
match value {
Some(v) => Ok(v),
_ => bail!("No return value from '_instructions' function"),
}
}
}
impl RoleLike for Agent {
fn to_role(&self) -> Role {
let prompt = self.interpolated_instructions();
let mut role = Role::new("", &prompt);
role.sync(self);
role
}
fn model(&self) -> &Model {
&self.model
}
fn temperature(&self) -> Option<f64> {
self.config.temperature
}
fn top_p(&self) -> Option<f64> {
self.config.top_p
}
fn use_tools(&self) -> Option<String> {
self.config.global_tools.clone().join(",").into()
}
fn use_mcp_servers(&self) -> Option<String> {
self.config.mcp_servers.clone().join(",").into()
}
fn set_model(&mut self, model: Model) {
self.config.model_id = Some(model.id());
self.model = model;
}
fn set_temperature(&mut self, value: Option<f64>) {
self.config.temperature = value;
}
fn set_top_p(&mut self, value: Option<f64>) {
self.config.top_p = value;
}
fn set_use_tools(&mut self, value: Option<String>) {
match value {
Some(tools) => {
let tools = tools
.split(',')
.map(|v| v.trim().to_string())
.filter(|v| !v.is_empty())
.collect::<Vec<_>>();
self.config.global_tools = tools;
}
None => {
self.config.global_tools.clear();
}
}
}
fn set_use_mcp_servers(&mut self, value: Option<String>) {
match value {
Some(servers) => {
let servers = servers
.split(',')
.map(|v| v.trim().to_string())
.filter(|v| !v.is_empty())
.collect::<Vec<_>>();
self.config.mcp_servers = servers;
}
None => {
self.config.mcp_servers.clear();
}
}
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentConfig {
pub name: String,
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub agent_prelude: Option<String>,
#[serde(default)]
pub description: String,
#[serde(default)]
pub version: String,
#[serde(default)]
pub mcp_servers: Vec<String>,
#[serde(default)]
pub global_tools: Vec<String>,
#[serde(default)]
pub instructions: String,
#[serde(default)]
pub dynamic_instructions: bool,
#[serde(default)]
pub variables: Vec<AgentVariable>,
#[serde(default)]
pub conversation_starters: Vec<String>,
#[serde(default)]
pub documents: Vec<String>,
}
impl AgentConfig {
pub fn load(path: &Path) -> Result<Self> {
let contents = read_to_string(path)
.with_context(|| format!("Failed to read agent config file at '{}'", path.display()))?;
let agent_config: Self = serde_yaml::from_str(&contents)
.with_context(|| format!("Failed to load agent config at '{}'", path.display()))?;
Ok(agent_config)
}
fn load_envs(&mut self, config: &Config) {
let name = &self.name;
let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}"));
if self.agent_prelude.is_none() {
self.agent_prelude = config.agent_prelude.clone();
}
if let Some(v) = read_env_value::<String>(&with_prefix("model")) {
self.model_id = v;
}
if let Some(v) = read_env_value::<f64>(&with_prefix("temperature")) {
self.temperature = v;
}
if let Some(v) = read_env_value::<f64>(&with_prefix("top_p")) {
self.top_p = v;
}
if let Some(v) = read_env_value::<String>(&with_prefix("agent_prelude")) {
self.agent_prelude = v;
}
if let Ok(v) = env::var(with_prefix("variables")) {
if let Ok(v) = serde_json::from_str(&v) {
self.variables = v;
}
}
}
fn banner(&self) -> String {
let AgentConfig {
name,
description,
version,
conversation_starters,
..
} = self;
let starters = if conversation_starters.is_empty() {
String::new()
} else {
let starters = conversation_starters
.iter()
.map(|v| format!("- {v}"))
.collect::<Vec<_>>()
.join("\n");
format!(
r#"
## Conversation Starters
{starters}"#
)
};
format!(
r#"# {name} {version}
{description}{starters}"#
)
}
fn replace_tools_placeholder(&mut self, functions: &Functions) {
let tools_placeholder: &str = "{{__tools__}}";
if self.instructions.contains(tools_placeholder) {
let tools = functions
.declarations()
.iter()
.enumerate()
.map(|(i, v)| {
let description = match v.description.split_once('\n') {
Some((v, _)) => v,
None => &v.description,
};
format!("{}. {}: {description}", i + 1, v.name)
})
.collect::<Vec<String>>()
.join("\n");
self.instructions = self.instructions.replace(tools_placeholder, &tools);
}
}
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentVariable {
pub name: String,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
#[serde(skip_deserializing, default)]
pub value: String,
}
pub fn list_agents() -> Vec<String> {
let agents_file = Config::config_dir().join("agents.txt");
let contents = match read_to_string(agents_file) {
Ok(v) => v,
Err(_) => return vec![],
};
contents
.split('\n')
.filter_map(|line| {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
None
} else {
Some(line.to_string())
}
})
.collect()
}
pub fn complete_agent_variables(agent_name: &str) -> Vec<(String, Option<String>)> {
let config_path = Config::agent_config_file(agent_name);
if !config_path.exists() {
return vec![];
}
let Ok(config) = AgentConfig::load(&config_path) else {
return vec![];
};
config
.variables
.iter()
.map(|v| {
let description = match &v.default {
Some(default) => format!("{} [default: {default}]", v.description),
None => v.description.clone(),
};
(format!("{}=", v.name), Some(description))
})
.collect()
}
+545
View File
@@ -0,0 +1,545 @@
use super::*;
use crate::client::{
init_client, patch_messages, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
MessageContentPart, MessageContentToolCalls, MessageRole, Model,
};
use crate::function::ToolResult;
use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal};
use anyhow::{bail, Context, Result};
use indexmap::IndexSet;
use std::{collections::HashMap, fs::File, io::Read};
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
const SUMMARY_MAX_WIDTH: usize = 80;
#[derive(Debug, Clone)]
pub struct Input {
config: GlobalConfig,
text: String,
raw: (String, Vec<String>),
patched_text: Option<String>,
last_reply: Option<String>,
continue_output: Option<String>,
regenerate: bool,
medias: Vec<String>,
data_urls: HashMap<String, String>,
tool_calls: Option<MessageContentToolCalls>,
role: Role,
rag_name: Option<String>,
with_session: bool,
with_agent: bool,
}
impl Input {
pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
Self {
config: config.clone(),
text: text.to_string(),
raw: (text.to_string(), vec![]),
patched_text: None,
last_reply: None,
continue_output: None,
regenerate: false,
medias: Default::default(),
data_urls: Default::default(),
tool_calls: None,
role,
rag_name: None,
with_session,
with_agent,
}
}
pub async fn from_files(
config: &GlobalConfig,
raw_text: &str,
paths: Vec<String>,
role: Option<Role>,
) -> Result<Self> {
let loaders = config.read().document_loaders.clone();
let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
resolve_paths(&loaders, paths)?;
let mut last_reply = None;
let (documents, medias, data_urls) = load_documents(
&loaders,
local_paths,
remote_urls,
external_cmds,
protocol_paths,
)
.await
.context("Failed to load files")?;
let mut texts = vec![];
if !raw_text.is_empty() {
texts.push(raw_text.to_string());
};
if with_last_reply {
if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
if !output.is_empty() {
last_reply = Some(output.clone())
} else if let Some(v) = input.last_reply.as_ref() {
last_reply = Some(v.clone());
}
if let Some(v) = last_reply.clone() {
texts.push(format!("\n{v}"));
}
}
if last_reply.is_none() && documents.is_empty() && medias.is_empty() {
bail!("No last reply found");
}
}
let documents_len = documents.len();
for (kind, path, contents) in documents {
if documents_len == 1 && raw_text.is_empty() {
texts.push(format!("\n{contents}"));
} else {
texts.push(format!(
"\n============ {kind}: {path} ============\n{contents}"
));
}
}
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
Ok(Self {
config: config.clone(),
text: texts.join("\n"),
raw: (raw_text.to_string(), raw_paths),
patched_text: None,
last_reply,
continue_output: None,
regenerate: false,
medias,
data_urls,
tool_calls: Default::default(),
role,
rag_name: None,
with_session,
with_agent,
})
}
pub async fn from_files_with_spinner(
config: &GlobalConfig,
raw_text: &str,
paths: Vec<String>,
role: Option<Role>,
abort_signal: AbortSignal,
) -> Result<Self> {
abortable_run_with_spinner(
Input::from_files(config, raw_text, paths, role),
"Loading files",
abort_signal,
)
.await
}
pub fn is_empty(&self) -> bool {
self.text.is_empty() && self.medias.is_empty()
}
pub fn data_urls(&self) -> HashMap<String, String> {
self.data_urls.clone()
}
pub fn tool_calls(&self) -> &Option<MessageContentToolCalls> {
&self.tool_calls
}
pub fn text(&self) -> String {
match self.patched_text.clone() {
Some(text) => text,
None => self.text.clone(),
}
}
pub fn clear_patch(&mut self) {
self.patched_text = None;
}
pub fn set_text(&mut self, text: String) {
self.text = text;
}
pub fn stream(&self) -> bool {
self.config.read().stream && !self.role().model().no_stream()
}
pub fn continue_output(&self) -> Option<&str> {
self.continue_output.as_deref()
}
pub fn set_continue_output(&mut self, output: &str) {
let output = match &self.continue_output {
Some(v) => format!("{v}{output}"),
None => output.to_string(),
};
self.continue_output = Some(output);
}
pub fn regenerate(&self) -> bool {
self.regenerate
}
pub fn set_regenerate(&mut self) {
let role = self.config.read().extract_role();
if role.name() == self.role().name() {
self.role = role;
}
self.regenerate = true;
self.tool_calls = None;
}
pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
if self.text.is_empty() {
return Ok(());
}
let rag = self.config.read().rag.clone();
if let Some(rag) = rag {
let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
self.patched_text = Some(result);
self.rag_name = Some(rag.name().to_string());
}
Ok(())
}
pub fn rag_name(&self) -> Option<&str> {
self.rag_name.as_deref()
}
pub fn merge_tool_results(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
match self.tool_calls.as_mut() {
Some(exist_tool_results) => {
exist_tool_results.merge(tool_results, output);
}
None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)),
}
self
}
pub fn create_client(&self) -> Result<Box<dyn Client>> {
init_client(&self.config, Some(self.role().model().clone()))
}
pub async fn fetch_chat_text(&self) -> Result<String> {
let client = self.create_client()?;
let text = client.chat_completions(self.clone()).await?.text;
let text = strip_think_tag(&text).to_string();
Ok(text)
}
pub fn prepare_completion_data(
&self,
model: &Model,
stream: bool,
) -> Result<ChatCompletionsData> {
let mut messages = self.build_messages()?;
patch_messages(&mut messages, model);
model.guard_max_input_tokens(&messages)?;
let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
let functions = self.config.read().select_functions(self.role());
if let Some(vec) = &functions {
for def in vec {
debug!("Function definition: {:?}", def.name);
}
}
Ok(ChatCompletionsData {
messages,
temperature,
top_p,
functions,
stream,
})
}
pub fn build_messages(&self) -> Result<Vec<Message>> {
let mut messages = if let Some(session) = self.session(&self.config.read().session) {
session.build_messages(self)
} else {
self.role().build_messages(self)
};
if let Some(tool_calls) = &self.tool_calls {
messages.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolCalls(tool_calls.clone()),
))
}
Ok(messages)
}
pub fn echo_messages(&self) -> String {
if let Some(session) = self.session(&self.config.read().session) {
session.echo_messages(self)
} else {
self.role().echo_messages(self)
}
}
pub fn role(&self) -> &Role {
&self.role
}
pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
if self.with_session {
session.as_ref()
} else {
None
}
}
pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
if self.with_session {
session.as_mut()
} else {
None
}
}
pub fn with_agent(&self) -> bool {
self.with_agent
}
pub fn summary(&self) -> String {
let text: String = self
.text
.trim()
.chars()
.map(|c| if c.is_control() { ' ' } else { c })
.collect();
if text.width_cjk() > SUMMARY_MAX_WIDTH {
let mut sum_width = 0;
let mut chars = vec![];
for c in text.chars() {
sum_width += c.width_cjk().unwrap_or(1);
if sum_width > SUMMARY_MAX_WIDTH - 3 {
chars.extend(['.', '.', '.']);
break;
}
chars.push(c);
}
chars.into_iter().collect()
} else {
text
}
}
pub fn raw(&self) -> String {
let (text, files) = &self.raw;
let mut segments = files.to_vec();
if !segments.is_empty() {
segments.insert(0, ".file".into());
}
if !text.is_empty() {
if !segments.is_empty() {
segments.push("--".into());
}
segments.push(text.clone());
}
segments.join(" ")
}
pub fn render(&self) -> String {
let text = self.text();
if self.medias.is_empty() {
return text;
}
let tail_text = if text.is_empty() {
String::new()
} else {
format!(" -- {text}")
};
let files: Vec<String> = self
.medias
.iter()
.cloned()
.map(|url| resolve_data_url(&self.data_urls, url))
.collect();
format!(".file {}{}", files.join(" "), tail_text)
}
pub fn message_content(&self) -> MessageContent {
if self.medias.is_empty() {
MessageContent::Text(self.text())
} else {
let mut list: Vec<MessageContentPart> = self
.medias
.iter()
.cloned()
.map(|url| MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
})
.collect();
if !self.text.is_empty() {
list.insert(0, MessageContentPart::Text { text: self.text() });
}
MessageContent::Array(list)
}
}
}
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
match role {
Some(v) => (v, false, false),
None => (
config.extract_role(),
config.session.is_some(),
config.agent.is_some(),
),
}
}
type ResolvePathsOutput = (
Vec<String>,
Vec<String>,
Vec<String>,
Vec<String>,
Vec<String>,
bool,
);
fn resolve_paths(
loaders: &HashMap<String, String>,
paths: Vec<String>,
) -> Result<ResolvePathsOutput> {
let mut raw_paths = IndexSet::new();
let mut local_paths = IndexSet::new();
let mut remote_urls = IndexSet::new();
let mut external_cmds = IndexSet::new();
let mut protocol_paths = IndexSet::new();
let mut with_last_reply = false;
for path in paths {
if path == "%%" {
with_last_reply = true;
raw_paths.insert(path);
} else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') {
external_cmds.insert(path[1..path.len() - 1].to_string());
raw_paths.insert(path);
} else if is_url(&path) {
if path.strip_suffix("**").is_some() {
bail!("Invalid website '{path}'");
}
remote_urls.insert(path.clone());
raw_paths.insert(path);
} else if is_loader_protocol(loaders, &path) {
protocol_paths.insert(path.clone());
raw_paths.insert(path);
} else {
let resolved_path = resolve_home_dir(&path);
let absolute_path = to_absolute_path(&resolved_path)
.with_context(|| format!("Invalid path '{path}'"))?;
local_paths.insert(resolved_path);
raw_paths.insert(absolute_path);
}
}
Ok((
raw_paths.into_iter().collect(),
local_paths.into_iter().collect(),
remote_urls.into_iter().collect(),
external_cmds.into_iter().collect(),
protocol_paths.into_iter().collect(),
with_last_reply,
))
}
async fn load_documents(
loaders: &HashMap<String, String>,
local_paths: Vec<String>,
remote_urls: Vec<String>,
external_cmds: Vec<String>,
protocol_paths: Vec<String>,
) -> Result<(
Vec<(&'static str, String, String)>,
Vec<String>,
HashMap<String, String>,
)> {
let mut files = vec![];
let mut medias = vec![];
let mut data_urls = HashMap::new();
for cmd in external_cmds {
let output = duct::cmd(&SHELL.cmd, &[&SHELL.arg, &cmd])
.stderr_to_stdout()
.unchecked()
.read()
.unwrap_or_else(|err| err.to_string());
files.push(("CMD", cmd, output));
}
let local_files = expand_glob_paths(&local_paths, true).await?;
for file_path in local_files {
if is_image(&file_path) {
let contents = read_media_to_data_url(&file_path)
.with_context(|| format!("Unable to read media '{file_path}'"))?;
data_urls.insert(sha256(&contents), file_path);
medias.push(contents)
} else {
let document = load_file(loaders, &file_path)
.await
.with_context(|| format!("Unable to read file '{file_path}'"))?;
files.push(("FILE", file_path, document.contents));
}
}
for file_url in remote_urls {
let (contents, extension) = fetch_with_loaders(loaders, &file_url, true)
.await
.with_context(|| format!("Failed to load url '{file_url}'"))?;
if extension == MEDIA_URL_EXTENSION {
data_urls.insert(sha256(&contents), file_url);
medias.push(contents)
} else {
files.push(("URL", file_url, contents));
}
}
for protocol_path in protocol_paths {
let documents = load_protocol_path(loaders, &protocol_path)
.with_context(|| format!("Failed to load from '{protocol_path}'"))?;
files.extend(
documents
.into_iter()
.map(|document| ("FROM", document.path, document.contents)),
);
}
Ok((files, medias, data_urls))
}
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
if data_url.starts_with("data:") {
let hash = sha256(&data_url);
if let Some(path) = data_urls.get(&hash) {
return path.to_string();
}
data_url
} else {
data_url
}
}
fn is_image(path: &str) -> bool {
get_patch_extension(path)
.map(|v| IMAGE_EXTS.contains(&v.as_str()))
.unwrap_or_default()
}
fn read_media_to_data_url(image_path: &str) -> Result<String> {
let extension = get_patch_extension(image_path).unwrap_or_default();
let mime_type = match extension.as_str() {
"png" => "image/png",
"jpg" | "jpeg" => "image/jpeg",
"webp" => "image/webp",
"gif" => "image/gif",
_ => bail!("Unexpected media type"),
};
let mut file = File::open(image_path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
let encoded_image = base64_encode(buffer);
let data_url = format!("data:{mime_type};base64,{encoded_image}");
Ok(data_url)
}
+3034
View File
File diff suppressed because it is too large Load Diff
+416
View File
@@ -0,0 +1,416 @@
use super::*;
use crate::client::{Message, MessageContent, MessageRole, Model};
use anyhow::Result;
use fancy_regex::Regex;
use rust_embed::Embed;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::LazyLock;
pub const SHELL_ROLE: &str = "shell";
pub const EXPLAIN_SHELL_ROLE: &str = "explain-shell";
pub const CODE_ROLE: &str = "code";
pub const CREATE_TITLE_ROLE: &str = "create-title";
pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
#[derive(Embed)]
#[folder = "assets/roles/"]
struct RolesAsset;
static RE_METADATA: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"(?s)-{3,}\s*(.*?)\s*-{3,}\s*(.*)").unwrap());
pub trait RoleLike {
fn to_role(&self) -> Role;
fn model(&self) -> &Model;
fn temperature(&self) -> Option<f64>;
fn top_p(&self) -> Option<f64>;
fn use_tools(&self) -> Option<String>;
fn use_mcp_servers(&self) -> Option<String>;
fn set_model(&mut self, model: Model);
fn set_temperature(&mut self, value: Option<f64>);
fn set_top_p(&mut self, value: Option<f64>);
fn set_use_tools(&mut self, value: Option<String>);
fn set_use_mcp_servers(&mut self, value: Option<String>);
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Role {
name: String,
#[serde(default)]
prompt: String,
#[serde(
rename(serialize = "model", deserialize = "model"),
skip_serializing_if = "Option::is_none"
)]
model_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
use_tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
use_mcp_servers: Option<String>,
#[serde(skip)]
model: Model,
}
impl Role {
pub fn new(name: &str, content: &str) -> Self {
let mut metadata = "";
let mut prompt = content.trim();
if let Ok(Some(caps)) = RE_METADATA.captures(content) {
if let (Some(metadata_value), Some(prompt_value)) = (caps.get(1), caps.get(2)) {
metadata = metadata_value.as_str().trim();
prompt = prompt_value.as_str().trim();
}
}
let mut prompt = prompt.to_string();
interpolate_variables(&mut prompt);
let mut role = Self {
name: name.to_string(),
prompt,
..Default::default()
};
if !metadata.is_empty() {
if let Ok(value) = serde_yaml::from_str::<Value>(metadata) {
if let Some(value) = value.as_object() {
for (key, value) in value {
match key.as_str() {
"model" => role.model_id = value.as_str().map(|v| v.to_string()),
"temperature" => role.temperature = value.as_f64(),
"top_p" => role.top_p = value.as_f64(),
"use_tools" => role.use_tools = value.as_str().map(|v| v.to_string()),
"use_mcp_servers" => {
role.use_mcp_servers = value.as_str().map(|v| v.to_string())
}
_ => (),
}
}
}
}
}
role
}
pub fn builtin(name: &str) -> Result<Self> {
let content = RolesAsset::get(&format!("{name}.md"))
.ok_or_else(|| anyhow!("Unknown role `{name}`"))?;
let content = unsafe { std::str::from_utf8_unchecked(&content.data) };
Ok(Role::new(name, content))
}
pub fn list_builtin_role_names() -> Vec<String> {
RolesAsset::iter()
.filter_map(|v| v.strip_suffix(".md").map(|v| v.to_string()))
.collect()
}
pub fn list_builtin_roles() -> Vec<Self> {
RolesAsset::iter()
.filter_map(|v| Role::builtin(&v).ok())
.collect()
}
pub fn has_args(&self) -> bool {
self.name.contains('#')
}
pub fn export(&self) -> String {
let mut metadata = vec![];
if let Some(model) = self.model_id() {
metadata.push(format!("model: {model}"));
}
if let Some(temperature) = self.temperature() {
metadata.push(format!("temperature: {temperature}"));
}
if let Some(top_p) = self.top_p() {
metadata.push(format!("top_p: {top_p}"));
}
if let Some(use_tools) = self.use_tools() {
metadata.push(format!("use_tools: {use_tools}"));
}
if let Some(use_mcp_servers) = self.use_mcp_servers() {
metadata.push(format!("use_mcp_servers: {use_mcp_servers}"));
}
if metadata.is_empty() {
format!("{}\n", self.prompt)
} else if self.prompt.is_empty() {
format!("---\n{}\n---\n", metadata.join("\n"))
} else {
format!("---\n{}\n---\n\n{}\n", metadata.join("\n"), self.prompt)
}
}
pub fn save(&mut self, role_name: &str, role_path: &Path, is_repl: bool) -> Result<()> {
ensure_parent_exists(role_path)?;
let content = self.export();
std::fs::write(role_path, content).with_context(|| {
format!(
"Failed to write role {} to {}",
self.name,
role_path.display()
)
})?;
if is_repl {
println!("✓ Saved role to '{}'.", role_path.display());
}
if role_name != self.name {
self.name = role_name.to_string();
}
Ok(())
}
pub fn sync<T: RoleLike>(&mut self, role_like: &T) {
let model = role_like.model();
let temperature = role_like.temperature();
let top_p = role_like.top_p();
let use_tools = role_like.use_tools();
let use_mcp_servers = role_like.use_mcp_servers();
self.batch_set(model, temperature, top_p, use_tools, use_mcp_servers);
}
pub fn batch_set(
&mut self,
model: &Model,
temperature: Option<f64>,
top_p: Option<f64>,
use_tools: Option<String>,
use_mcp_servers: Option<String>,
) {
self.set_model(model.clone());
if temperature.is_some() {
self.set_temperature(temperature);
}
if top_p.is_some() {
self.set_top_p(top_p);
}
if use_tools.is_some() {
self.set_use_tools(use_tools);
}
if use_mcp_servers.is_some() {
self.set_use_mcp_servers(use_mcp_servers);
}
}
pub fn is_derived(&self) -> bool {
self.name.is_empty()
}
pub fn name(&self) -> &str {
&self.name
}
pub fn model_id(&self) -> Option<&str> {
self.model_id.as_deref()
}
pub fn prompt(&self) -> &str {
&self.prompt
}
pub fn is_empty_prompt(&self) -> bool {
self.prompt.is_empty()
}
pub fn is_embedded_prompt(&self) -> bool {
self.prompt.contains(INPUT_PLACEHOLDER)
}
pub fn echo_messages(&self, input: &Input) -> String {
let input_markdown = input.render();
if self.is_empty_prompt() {
input_markdown
} else if self.is_embedded_prompt() {
self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown)
} else {
format!("{}\n\n{}", self.prompt, input_markdown)
}
}
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
let mut content = input.message_content();
let mut messages = if self.is_empty_prompt() {
vec![Message::new(MessageRole::User, content)]
} else if self.is_embedded_prompt() {
content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v));
vec![Message::new(MessageRole::User, content)]
} else {
let mut messages = vec![];
let (system, cases) = parse_structure_prompt(&self.prompt);
if !system.is_empty() {
messages.push(Message::new(
MessageRole::System,
MessageContent::Text(system.to_string()),
));
}
if !cases.is_empty() {
messages.extend(cases.into_iter().flat_map(|(i, o)| {
vec![
Message::new(MessageRole::User, MessageContent::Text(i.to_string())),
Message::new(MessageRole::Assistant, MessageContent::Text(o.to_string())),
]
}));
}
messages.push(Message::new(MessageRole::User, content));
messages
};
if let Some(text) = input.continue_output() {
messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(text.into()),
));
}
messages
}
}
impl RoleLike for Role {
fn to_role(&self) -> Role {
self.clone()
}
fn model(&self) -> &Model {
&self.model
}
fn temperature(&self) -> Option<f64> {
self.temperature
}
fn top_p(&self) -> Option<f64> {
self.top_p
}
fn use_tools(&self) -> Option<String> {
self.use_tools.clone()
}
fn use_mcp_servers(&self) -> Option<String> {
self.use_mcp_servers.clone()
}
fn set_model(&mut self, model: Model) {
if !self.model().id().is_empty() {
self.model_id = Some(model.id().to_string());
}
self.model = model;
}
fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value;
}
fn set_top_p(&mut self, value: Option<f64>) {
self.top_p = value;
}
fn set_use_tools(&mut self, value: Option<String>) {
self.use_tools = value;
}
fn set_use_mcp_servers(&mut self, value: Option<String>) {
self.use_mcp_servers = value;
}
}
fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
let mut text = prompt;
let mut search_input = true;
let mut system = None;
let mut parts = vec![];
loop {
let search = if search_input {
"### INPUT:"
} else {
"### OUTPUT:"
};
match text.find(search) {
Some(idx) => {
if system.is_none() {
system = Some(&text[..idx])
} else {
parts.push(&text[..idx])
}
search_input = !search_input;
text = &text[(idx + search.len())..];
}
None => {
if !text.is_empty() {
if system.is_none() {
system = Some(text)
} else {
parts.push(text)
}
}
break;
}
}
}
let parts_len = parts.len();
if parts_len > 0 && parts_len % 2 == 0 {
let cases: Vec<(&str, &str)> = parts
.iter()
.step_by(2)
.zip(parts.iter().skip(1).step_by(2))
.map(|(i, o)| (i.trim(), o.trim()))
.collect();
let system = system.map(|v| v.trim()).unwrap_or_default();
return (system, cases);
}
(prompt, vec![])
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_structure_prompt1() {
let prompt = r#"
System message
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
assert_eq!(
parse_structure_prompt(prompt),
("System message", vec![("Input 1", "Output 1")])
);
}
#[test]
fn test_parse_structure_prompt2() {
let prompt = r#"
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
assert_eq!(
parse_structure_prompt(prompt),
("", vec![("Input 1", "Output 1")])
);
}
#[test]
fn test_parse_structure_prompt3() {
let prompt = r#"
System message
### INPUT:
Input 1
"#;
assert_eq!(parse_structure_prompt(prompt), (prompt, vec![]));
}
}
+659
View File
@@ -0,0 +1,659 @@
use super::input::*;
use super::*;
use crate::client::{Message, MessageContent, MessageRole};
use crate::render::MarkdownRender;
use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use inquire::{validator::Validation, Confirm, Text};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::fs::{read_to_string, write};
use std::path::Path;
use std::sync::LazyLock;
static RE_AUTONAME_PREFIX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\d{8}T\d{6}-").unwrap());
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Session {
#[serde(rename(serialize = "model", deserialize = "model"))]
model_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
use_tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
use_mcp_servers: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
save_session: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
compress_threshold: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
role_name: Option<String>,
#[serde(default, skip_serializing_if = "IndexMap::is_empty")]
agent_variables: AgentVariables,
#[serde(default, skip_serializing_if = "String::is_empty")]
agent_instructions: String,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
compressed_messages: Vec<Message>,
messages: Vec<Message>,
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
data_urls: HashMap<String, String>,
#[serde(skip)]
model: Model,
#[serde(skip)]
role_prompt: String,
#[serde(skip)]
name: String,
#[serde(skip)]
path: Option<String>,
#[serde(skip)]
dirty: bool,
#[serde(skip)]
save_session_this_time: bool,
#[serde(skip)]
compressing: bool,
#[serde(skip)]
autoname: Option<AutoName>,
#[serde(skip)]
tokens: usize,
}
impl Session {
pub fn new(config: &Config, name: &str) -> Self {
let role = config.extract_role();
let mut session = Self {
name: name.to_string(),
save_session: config.save_session,
..Default::default()
};
session.set_role(role);
session.dirty = false;
session
}
pub fn load(config: &Config, name: &str, path: &Path) -> Result<Self> {
let content = read_to_string(path)
.with_context(|| format!("Failed to load session {} at {}", name, path.display()))?;
let mut session: Self =
serde_yaml::from_str(&content).with_context(|| format!("Invalid session {name}"))?;
session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?;
if let Some(autoname) = name.strip_prefix("_/") {
session.name = TEMP_SESSION_NAME.to_string();
session.path = None;
if let Ok(true) = RE_AUTONAME_PREFIX.is_match(autoname) {
session.autoname = Some(AutoName::new(autoname[16..].to_string()));
}
} else {
session.name = name.to_string();
session.path = Some(path.display().to_string());
}
if let Some(role_name) = &session.role_name {
if let Ok(role) = config.retrieve_role(role_name) {
session.role_prompt = role.prompt().to_string();
}
}
session.update_tokens();
Ok(session)
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty() && self.compressed_messages.is_empty()
}
pub fn name(&self) -> &str {
&self.name
}
pub fn role_name(&self) -> Option<&str> {
self.role_name.as_deref()
}
pub fn dirty(&self) -> bool {
self.dirty
}
pub fn save_session(&self) -> Option<bool> {
self.save_session
}
pub fn tokens(&self) -> usize {
self.tokens
}
pub fn update_tokens(&mut self) {
self.tokens = self.model().total_tokens(&self.messages);
}
pub fn has_user_messages(&self) -> bool {
self.messages.iter().any(|v| v.role.is_user())
}
pub fn user_messages_len(&self) -> usize {
self.messages.iter().filter(|v| v.role.is_user()).count()
}
pub fn export(&self) -> Result<String> {
let mut data = json!({
"path": self.path,
"model": self.model().id(),
});
if let Some(temperature) = self.temperature() {
data["temperature"] = temperature.into();
}
if let Some(top_p) = self.top_p() {
data["top_p"] = top_p.into();
}
if let Some(use_tools) = self.use_tools() {
data["use_tools"] = use_tools.into();
}
if let Some(use_mcp_servers) = self.use_mcp_servers() {
data["use_mcp_servers"] = use_mcp_servers.into();
}
if let Some(save_session) = self.save_session() {
data["save_session"] = save_session.into();
}
let (tokens, percent) = self.tokens_usage();
data["total_tokens"] = tokens.into();
if let Some(max_input_tokens) = self.model().max_input_tokens() {
data["max_input_tokens"] = max_input_tokens.into();
}
if percent != 0.0 {
data["total/max"] = format!("{percent}%").into();
}
data["messages"] = json!(self.messages);
let output = serde_yaml::to_string(&data)
.with_context(|| format!("Unable to show info about session '{}'", &self.name))?;
Ok(output)
}
pub fn render(
&self,
render: &mut MarkdownRender,
agent_info: &Option<(String, Vec<String>)>,
) -> Result<String> {
let mut items = vec![];
if let Some(path) = &self.path {
items.push(("path", path.to_string()));
}
if let Some(autoname) = self.autoname() {
items.push(("autoname", autoname.to_string()));
}
items.push(("model", self.model().id()));
if let Some(temperature) = self.temperature() {
items.push(("temperature", temperature.to_string()));
}
if let Some(top_p) = self.top_p() {
items.push(("top_p", top_p.to_string()));
}
if let Some(use_tools) = self.use_tools() {
items.push(("use_tools", use_tools));
}
if let Some(use_mcp_servers) = self.use_mcp_servers() {
items.push(("use_mcp_servers", use_mcp_servers));
}
if let Some(save_session) = self.save_session() {
items.push(("save_session", save_session.to_string()));
}
if let Some(compress_threshold) = self.compress_threshold {
items.push(("compress_threshold", compress_threshold.to_string()));
}
if let Some(max_input_tokens) = self.model().max_input_tokens() {
items.push(("max_input_tokens", max_input_tokens.to_string()));
}
let mut lines: Vec<String> = items
.iter()
.map(|(name, value)| format!("{name:<20}{value}"))
.collect();
lines.push(String::new());
if !self.is_empty() {
let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string());
for message in &self.messages {
match message.role {
MessageRole::System => {
lines.push(
render
.render(&message.content.render_input(resolve_url_fn, agent_info)),
);
}
MessageRole::Assistant => {
if let MessageContent::Text(text) = &message.content {
lines.push(render.render(text));
}
lines.push("".into());
}
MessageRole::User => {
lines.push(format!(
">> {}",
message.content.render_input(resolve_url_fn, agent_info)
));
}
MessageRole::Tool => {
lines.push(message.content.render_input(resolve_url_fn, agent_info));
}
}
}
}
Ok(lines.join("\n"))
}
pub fn tokens_usage(&self) -> (usize, f32) {
let tokens = self.tokens();
let max_input_tokens = self.model().max_input_tokens().unwrap_or_default();
let percent = if max_input_tokens == 0 {
0.0
} else {
let percent = tokens as f32 / max_input_tokens as f32 * 100.0;
(percent * 100.0).round() / 100.0
};
(tokens, percent)
}
pub fn set_role(&mut self, role: Role) {
self.model_id = role.model().id();
self.temperature = role.temperature();
self.top_p = role.top_p();
self.use_tools = role.use_tools();
self.use_mcp_servers = role.use_mcp_servers();
self.model = role.model().clone();
self.role_name = convert_option_string(role.name());
self.role_prompt = role.prompt().to_string();
self.dirty = true;
self.update_tokens();
}
pub fn clear_role(&mut self) {
self.role_name = None;
self.role_prompt.clear();
}
pub fn sync_agent(&mut self, agent: &Agent) {
self.role_name = None;
self.role_prompt = agent.interpolated_instructions();
self.agent_variables = agent.variables().clone();
self.agent_instructions = self.role_prompt.clone();
}
pub fn agent_variables(&self) -> &AgentVariables {
&self.agent_variables
}
pub fn agent_instructions(&self) -> &str {
&self.agent_instructions
}
pub fn set_save_session(&mut self, value: Option<bool>) {
if self.save_session != value {
self.save_session = value;
self.dirty = true;
}
}
pub fn set_save_session_this_time(&mut self) {
self.save_session_this_time = true;
}
pub fn set_compress_threshold(&mut self, value: Option<usize>) {
if self.compress_threshold != value {
self.compress_threshold = value;
self.dirty = true;
}
}
pub fn need_compress(&self, global_compress_threshold: usize) -> bool {
if self.compressing {
return false;
}
let threshold = self.compress_threshold.unwrap_or(global_compress_threshold);
if threshold < 1 {
return false;
}
self.tokens() > threshold
}
pub fn compressing(&self) -> bool {
self.compressing
}
pub fn set_compressing(&mut self, compressing: bool) {
self.compressing = compressing;
}
pub fn compress(&mut self, mut prompt: String) {
if let Some(system_prompt) = self.messages.first().and_then(|v| {
if MessageRole::System == v.role {
let content = v.content.to_text();
if !content.is_empty() {
return Some(content);
}
}
None
}) {
prompt = format!("{system_prompt}\n\n{prompt}",);
}
self.compressed_messages.append(&mut self.messages);
self.messages.push(Message::new(
MessageRole::System,
MessageContent::Text(prompt),
));
self.dirty = true;
self.update_tokens();
}
pub fn need_autoname(&self) -> bool {
self.autoname.as_ref().map(|v| v.need()).unwrap_or_default()
}
pub fn set_autonaming(&mut self, naming: bool) {
if let Some(v) = self.autoname.as_mut() {
v.naming = naming;
}
}
pub fn chat_history_for_autonaming(&self) -> Option<String> {
self.autoname.as_ref().and_then(|v| v.chat_history.clone())
}
pub fn autoname(&self) -> Option<&str> {
self.autoname.as_ref().and_then(|v| v.name.as_deref())
}
pub fn set_autoname(&mut self, value: &str) {
let name = value
.chars()
.map(|v| if v.is_alphanumeric() { v } else { '-' })
.collect();
self.autoname = Some(AutoName::new(name));
}
pub fn exit(&mut self, session_dir: &Path, is_repl: bool) -> Result<()> {
let mut save_session = self.save_session();
if self.save_session_this_time {
save_session = Some(true);
}
if self.dirty && save_session != Some(false) {
let mut session_dir = session_dir.to_path_buf();
let mut session_name = self.name().to_string();
if save_session.is_none() {
if !is_repl {
return Ok(());
}
let ans = Confirm::new("Save session?").with_default(false).prompt()?;
if !ans {
return Ok(());
}
if session_name == TEMP_SESSION_NAME {
session_name = Text::new("Session name:")
.with_validator(|input: &str| {
let input = input.trim();
if input.is_empty() {
Ok(Validation::Invalid("This name is required".into()))
} else if input == TEMP_SESSION_NAME {
Ok(Validation::Invalid("This name is reserved".into()))
} else {
Ok(Validation::Valid)
}
})
.prompt()?;
}
} else if save_session == Some(true) && session_name == TEMP_SESSION_NAME {
session_dir = session_dir.join("_");
ensure_parent_exists(&session_dir).with_context(|| {
format!("Failed to create directory '{}'", session_dir.display())
})?;
let now = chrono::Local::now();
session_name = now.format("%Y%m%dT%H%M%S").to_string();
if let Some(autoname) = self.autoname() {
session_name = format!("{session_name}-{autoname}")
}
}
let session_path = session_dir.join(format!("{session_name}.yaml"));
self.save(&session_name, &session_path, is_repl)?;
}
Ok(())
}
pub fn save(&mut self, session_name: &str, session_path: &Path, is_repl: bool) -> Result<()> {
ensure_parent_exists(session_path)?;
self.path = Some(session_path.display().to_string());
let content = serde_yaml::to_string(&self)
.with_context(|| format!("Failed to serde session '{}'", self.name))?;
write(session_path, content).with_context(|| {
format!(
"Failed to write session '{}' to '{}'",
self.name,
session_path.display()
)
})?;
if is_repl {
println!("✓ Saved the session to '{}'.", session_path.display());
}
if self.name() != session_name {
self.name = session_name.to_string()
}
self.dirty = false;
Ok(())
}
pub fn guard_empty(&self) -> Result<()> {
if !self.is_empty() {
bail!("Cannot perform this operation because the session has messages, please `.empty session` first.");
}
Ok(())
}
pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
if input.continue_output().is_some() {
if let Some(message) = self.messages.last_mut() {
if let MessageContent::Text(text) = &mut message.content {
*text = format!("{text}{output}");
}
}
} else if input.regenerate() {
if let Some(message) = self.messages.last_mut() {
if let MessageContent::Text(text) = &mut message.content {
*text = output.to_string();
}
}
} else {
if self.messages.is_empty() {
if self.name == TEMP_SESSION_NAME && self.save_session == Some(true) {
let raw_input = input.raw();
let chat_history = format!("USER: {raw_input}\nASSISTANT: {output}\n");
self.autoname = Some(AutoName::new_from_chat_history(chat_history));
}
self.messages.extend(input.role().build_messages(input));
} else {
self.messages
.push(Message::new(MessageRole::User, input.message_content()));
}
self.data_urls.extend(input.data_urls());
if let Some(tool_calls) = input.tool_calls() {
self.messages.push(Message::new(
MessageRole::Tool,
MessageContent::ToolCalls(tool_calls.clone()),
))
}
self.messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(output.to_string()),
));
}
self.dirty = true;
self.update_tokens();
Ok(())
}
pub fn clear_messages(&mut self) {
self.messages.clear();
self.compressed_messages.clear();
self.data_urls.clear();
self.autoname = None;
self.dirty = true;
self.update_tokens();
}
pub fn echo_messages(&self, input: &Input) -> String {
let messages = self.build_messages(input);
serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
}
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
let mut messages = self.messages.clone();
if input.continue_output().is_some() {
return messages;
} else if input.regenerate() {
while let Some(last) = messages.last() {
if !last.role.is_user() {
messages.pop();
} else {
break;
}
}
return messages;
}
let mut need_add_msg = true;
let len = messages.len();
if len == 0 {
messages = input.role().build_messages(input);
need_add_msg = false;
} else if len == 1 && self.compressed_messages.len() >= 2 {
if let Some(index) = self
.compressed_messages
.iter()
.rposition(|v| v.role == MessageRole::User)
{
messages.extend(self.compressed_messages[index..].to_vec());
}
}
if need_add_msg {
messages.push(Message::new(MessageRole::User, input.message_content()));
}
messages
}
}
impl RoleLike for Session {
fn to_role(&self) -> Role {
let role_name = self.role_name.as_deref().unwrap_or_default();
let mut role = Role::new(role_name, &self.role_prompt);
role.sync(self);
role
}
fn model(&self) -> &Model {
&self.model
}
fn temperature(&self) -> Option<f64> {
self.temperature
}
fn top_p(&self) -> Option<f64> {
self.top_p
}
fn use_tools(&self) -> Option<String> {
self.use_tools.clone()
}
fn use_mcp_servers(&self) -> Option<String> {
self.use_mcp_servers.clone()
}
fn set_model(&mut self, model: Model) {
if self.model().id() != model.id() {
self.model_id = model.id();
self.model = model;
self.dirty = true;
self.update_tokens();
}
}
fn set_temperature(&mut self, value: Option<f64>) {
if self.temperature != value {
self.temperature = value;
self.dirty = true;
}
}
fn set_top_p(&mut self, value: Option<f64>) {
if self.top_p != value {
self.top_p = value;
self.dirty = true;
}
}
fn set_use_tools(&mut self, value: Option<String>) {
if self.use_tools != value {
self.use_tools = value;
self.dirty = true;
}
}
fn set_use_mcp_servers(&mut self, value: Option<String>) {
if self.use_mcp_servers != value {
self.use_mcp_servers = value;
self.dirty = true;
}
}
}
#[derive(Debug, Clone, Default)]
struct AutoName {
naming: bool,
chat_history: Option<String>,
name: Option<String>,
}
impl AutoName {
pub fn new(name: String) -> Self {
Self {
name: Some(name),
..Default::default()
}
}
pub fn new_from_chat_history(chat_history: String) -> Self {
Self {
chat_history: Some(chat_history),
..Default::default()
}
}
pub fn need(&self) -> bool {
!self.naming && self.chat_history.is_some() && self.name.is_none()
}
}