Baseline project
This commit is contained in:
@@ -0,0 +1,570 @@
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
client::Model,
|
||||
function::{run_llm_function, Functions},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use inquire::{validator::Validation, Text};
|
||||
use std::{fs::read_to_string, path::Path};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const DEFAULT_AGENT_NAME: &str = "rag";
|
||||
|
||||
pub type AgentVariables = IndexMap<String, String>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Agent {
|
||||
name: String,
|
||||
config: AgentConfig,
|
||||
shared_variables: AgentVariables,
|
||||
session_variables: Option<AgentVariables>,
|
||||
shared_dynamic_instructions: Option<String>,
|
||||
session_dynamic_instructions: Option<String>,
|
||||
functions: Functions,
|
||||
rag: Option<Arc<Rag>>,
|
||||
model: Model,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub async fn init(
|
||||
config: &GlobalConfig,
|
||||
name: &str,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
let agent_data_dir = Config::agent_data_dir(name);
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME);
|
||||
let config_path = Config::agent_config_file(name);
|
||||
let mut agent_config = if config_path.exists() {
|
||||
AgentConfig::load(&config_path)?
|
||||
} else {
|
||||
bail!("Agent config file not found at '{}'", config_path.display())
|
||||
};
|
||||
let mut functions = Functions::init_agent(name, &agent_config.global_tools)?;
|
||||
|
||||
config.write().functions.clear_mcp_meta_functions();
|
||||
let mcp_servers =
|
||||
(!agent_config.mcp_servers.is_empty()).then(|| agent_config.mcp_servers.join(","));
|
||||
let registry = config
|
||||
.write()
|
||||
.mcp_registry
|
||||
.take()
|
||||
.expect("MCP registry should be initialized");
|
||||
let new_mcp_registry =
|
||||
McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?;
|
||||
|
||||
if !new_mcp_registry.is_empty() {
|
||||
functions.append_mcp_meta_functions(new_mcp_registry.list_servers());
|
||||
}
|
||||
|
||||
config.write().mcp_registry = Some(new_mcp_registry);
|
||||
agent_config.replace_tools_placeholder(&functions);
|
||||
|
||||
agent_config.load_envs(&config.read());
|
||||
|
||||
let model = {
|
||||
let config = config.read();
|
||||
match agent_config.model_id.as_ref() {
|
||||
Some(model_id) => Model::retrieve_model(&config, model_id, ModelType::Chat)?,
|
||||
None => {
|
||||
if agent_config.temperature.is_none() {
|
||||
agent_config.temperature = config.temperature;
|
||||
}
|
||||
if agent_config.top_p.is_none() {
|
||||
agent_config.top_p = config.top_p;
|
||||
}
|
||||
config.current_model().clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let rag = if rag_path.exists() {
|
||||
Some(Arc::new(Rag::load(config, DEFAULT_AGENT_NAME, &rag_path)?))
|
||||
} else if !agent_config.documents.is_empty() && !config.read().info_flag {
|
||||
let mut ans = false;
|
||||
if *IS_STDOUT_TERMINAL {
|
||||
ans = Confirm::new("The agent has documents attached, init RAG?")
|
||||
.with_default(true)
|
||||
.prompt()?;
|
||||
}
|
||||
if ans {
|
||||
let mut document_paths = vec![];
|
||||
for path in &agent_config.documents {
|
||||
if is_url(path) {
|
||||
document_paths.push(path.to_string());
|
||||
} else if is_loader_protocol(&loaders, path) {
|
||||
let (protocol, document_path) = path
|
||||
.split_once(':')
|
||||
.with_context(|| "Invalid loader protocol path")?;
|
||||
let resolved_path = resolve_home_dir(document_path);
|
||||
let new_path = if Path::new(&resolved_path).is_relative() {
|
||||
safe_join_path(&agent_data_dir, resolved_path)
|
||||
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?
|
||||
} else {
|
||||
PathBuf::from(&resolved_path)
|
||||
};
|
||||
document_paths.push(format!("{}:{}", protocol, new_path.display()));
|
||||
} else if Path::new(&resolve_home_dir(path)).is_relative() {
|
||||
let new_path = safe_join_path(&agent_data_dir, path)
|
||||
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
|
||||
document_paths.push(new_path.display().to_string())
|
||||
} else {
|
||||
document_paths.push(path.to_string())
|
||||
}
|
||||
}
|
||||
let rag =
|
||||
Rag::init(config, "rag", &rag_path, &document_paths, abort_signal).await?;
|
||||
Some(Arc::new(rag))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
name: name.to_string(),
|
||||
config: agent_config,
|
||||
shared_variables: Default::default(),
|
||||
session_variables: None,
|
||||
shared_dynamic_instructions: None,
|
||||
session_dynamic_instructions: None,
|
||||
functions,
|
||||
rag,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn init_agent_variables(
|
||||
agent_variables: &[AgentVariable],
|
||||
no_interaction: bool,
|
||||
) -> Result<AgentVariables> {
|
||||
let mut output = IndexMap::new();
|
||||
if agent_variables.is_empty() {
|
||||
return Ok(output);
|
||||
}
|
||||
let mut printed = false;
|
||||
let mut unset_variables = vec![];
|
||||
for agent_variable in agent_variables {
|
||||
let key = agent_variable.name.clone();
|
||||
if let Some(value) = agent_variable.default.clone() {
|
||||
output.insert(key, value);
|
||||
continue;
|
||||
}
|
||||
if no_interaction {
|
||||
continue;
|
||||
}
|
||||
if *IS_STDOUT_TERMINAL {
|
||||
if !printed {
|
||||
println!("⚙ Init agent variables...");
|
||||
printed = true;
|
||||
}
|
||||
let value = Text::new(&format!(
|
||||
"{} ({}):",
|
||||
agent_variable.name, agent_variable.description
|
||||
))
|
||||
.with_validator(|input: &str| {
|
||||
if input.trim().is_empty() {
|
||||
Ok(Validation::Invalid("This field is required".into()))
|
||||
} else {
|
||||
Ok(Validation::Valid)
|
||||
}
|
||||
})
|
||||
.prompt()?;
|
||||
output.insert(key, value);
|
||||
} else {
|
||||
unset_variables.push(agent_variable)
|
||||
}
|
||||
}
|
||||
if !unset_variables.is_empty() {
|
||||
bail!(
|
||||
"The following agent variables are required:\n{}",
|
||||
unset_variables
|
||||
.iter()
|
||||
.map(|v| format!(" - {}: {}", v.name, v.description))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
)
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn export(&self) -> Result<String> {
|
||||
let mut value = json!({});
|
||||
value["name"] = json!(self.name());
|
||||
let variables = self.variables();
|
||||
if !variables.is_empty() {
|
||||
value["variables"] = serde_json::to_value(variables)?;
|
||||
}
|
||||
value["config"] = json!(self.config);
|
||||
let mut config = self.config.clone();
|
||||
config.instructions = self.interpolated_instructions();
|
||||
value["definition"] = json!(config);
|
||||
value["data_dir"] = Config::agent_data_dir(&self.name)
|
||||
.display()
|
||||
.to_string()
|
||||
.into();
|
||||
value["config_file"] = Config::agent_config_file(&self.name)
|
||||
.display()
|
||||
.to_string()
|
||||
.into();
|
||||
let data = serde_yaml::to_string(&value)?;
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
pub fn banner(&self) -> String {
|
||||
self.config.banner()
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn functions(&self) -> &Functions {
|
||||
&self.functions
|
||||
}
|
||||
|
||||
pub fn rag(&self) -> Option<Arc<Rag>> {
|
||||
self.rag.clone()
|
||||
}
|
||||
|
||||
pub fn conversation_starters(&self) -> &[String] {
|
||||
&self.config.conversation_starters
|
||||
}
|
||||
|
||||
pub fn interpolated_instructions(&self) -> String {
|
||||
let mut output = self
|
||||
.session_dynamic_instructions
|
||||
.clone()
|
||||
.or_else(|| self.shared_dynamic_instructions.clone())
|
||||
.unwrap_or_else(|| self.config.instructions.clone());
|
||||
for (k, v) in self.variables() {
|
||||
output = output.replace(&format!("{{{{{k}}}}}"), v)
|
||||
}
|
||||
interpolate_variables(&mut output);
|
||||
output
|
||||
}
|
||||
|
||||
pub fn agent_prelude(&self) -> Option<&str> {
|
||||
self.config.agent_prelude.as_deref()
|
||||
}
|
||||
|
||||
pub fn variables(&self) -> &AgentVariables {
|
||||
match &self.session_variables {
|
||||
Some(variables) => variables,
|
||||
None => &self.shared_variables,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn variable_envs(&self) -> HashMap<String, String> {
|
||||
self.variables()
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
(
|
||||
format!("LLM_AGENT_VAR_{}", normalize_env_name(k)),
|
||||
v.clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn shared_variables(&self) -> &AgentVariables {
|
||||
&self.shared_variables
|
||||
}
|
||||
|
||||
pub fn set_shared_variables(&mut self, shared_variables: AgentVariables) {
|
||||
self.shared_variables = shared_variables;
|
||||
}
|
||||
|
||||
pub fn set_session_variables(&mut self, session_variables: AgentVariables) {
|
||||
self.session_variables = Some(session_variables);
|
||||
}
|
||||
|
||||
pub fn defined_variables(&self) -> &[AgentVariable] {
|
||||
&self.config.variables
|
||||
}
|
||||
|
||||
pub fn exit_session(&mut self) {
|
||||
self.session_variables = None;
|
||||
self.session_dynamic_instructions = None;
|
||||
}
|
||||
|
||||
pub fn is_dynamic_instructions(&self) -> bool {
|
||||
self.config.dynamic_instructions
|
||||
}
|
||||
|
||||
pub fn update_shared_dynamic_instructions(&mut self, force: bool) -> Result<()> {
|
||||
if self.is_dynamic_instructions() && (force || self.shared_dynamic_instructions.is_none()) {
|
||||
self.shared_dynamic_instructions = Some(self.run_instructions_fn()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_session_dynamic_instructions(&mut self, value: Option<String>) -> Result<()> {
|
||||
if self.is_dynamic_instructions() {
|
||||
let value = match value {
|
||||
Some(v) => v,
|
||||
None => self.run_instructions_fn()?,
|
||||
};
|
||||
self.session_dynamic_instructions = Some(value);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_instructions_fn(&self) -> Result<String> {
|
||||
let value = run_llm_function(
|
||||
self.name().to_string(),
|
||||
vec!["_instructions".into(), "{}".into()],
|
||||
self.variable_envs(),
|
||||
)?;
|
||||
match value {
|
||||
Some(v) => Ok(v),
|
||||
_ => bail!("No return value from '_instructions' function"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RoleLike for Agent {
|
||||
fn to_role(&self) -> Role {
|
||||
let prompt = self.interpolated_instructions();
|
||||
let mut role = Role::new("", &prompt);
|
||||
role.sync(self);
|
||||
role
|
||||
}
|
||||
|
||||
fn model(&self) -> &Model {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn temperature(&self) -> Option<f64> {
|
||||
self.config.temperature
|
||||
}
|
||||
|
||||
fn top_p(&self) -> Option<f64> {
|
||||
self.config.top_p
|
||||
}
|
||||
|
||||
fn use_tools(&self) -> Option<String> {
|
||||
self.config.global_tools.clone().join(",").into()
|
||||
}
|
||||
|
||||
fn use_mcp_servers(&self) -> Option<String> {
|
||||
self.config.mcp_servers.clone().join(",").into()
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: Model) {
|
||||
self.config.model_id = Some(model.id());
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
fn set_temperature(&mut self, value: Option<f64>) {
|
||||
self.config.temperature = value;
|
||||
}
|
||||
|
||||
fn set_top_p(&mut self, value: Option<f64>) {
|
||||
self.config.top_p = value;
|
||||
}
|
||||
|
||||
fn set_use_tools(&mut self, value: Option<String>) {
|
||||
match value {
|
||||
Some(tools) => {
|
||||
let tools = tools
|
||||
.split(',')
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.collect::<Vec<_>>();
|
||||
self.config.global_tools = tools;
|
||||
}
|
||||
None => {
|
||||
self.config.global_tools.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_use_mcp_servers(&mut self, value: Option<String>) {
|
||||
match value {
|
||||
Some(servers) => {
|
||||
let servers = servers
|
||||
.split(',')
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.collect::<Vec<_>>();
|
||||
self.config.mcp_servers = servers;
|
||||
}
|
||||
None => {
|
||||
self.config.mcp_servers.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct AgentConfig {
|
||||
pub name: String,
|
||||
#[serde(rename(serialize = "model", deserialize = "model"))]
|
||||
pub model_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub agent_prelude: Option<String>,
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
#[serde(default)]
|
||||
pub version: String,
|
||||
#[serde(default)]
|
||||
pub mcp_servers: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub global_tools: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub instructions: String,
|
||||
#[serde(default)]
|
||||
pub dynamic_instructions: bool,
|
||||
#[serde(default)]
|
||||
pub variables: Vec<AgentVariable>,
|
||||
#[serde(default)]
|
||||
pub conversation_starters: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub documents: Vec<String>,
|
||||
}
|
||||
|
||||
impl AgentConfig {
|
||||
pub fn load(path: &Path) -> Result<Self> {
|
||||
let contents = read_to_string(path)
|
||||
.with_context(|| format!("Failed to read agent config file at '{}'", path.display()))?;
|
||||
let agent_config: Self = serde_yaml::from_str(&contents)
|
||||
.with_context(|| format!("Failed to load agent config at '{}'", path.display()))?;
|
||||
|
||||
Ok(agent_config)
|
||||
}
|
||||
|
||||
fn load_envs(&mut self, config: &Config) {
|
||||
let name = &self.name;
|
||||
let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}"));
|
||||
|
||||
if self.agent_prelude.is_none() {
|
||||
self.agent_prelude = config.agent_prelude.clone();
|
||||
}
|
||||
|
||||
if let Some(v) = read_env_value::<String>(&with_prefix("model")) {
|
||||
self.model_id = v;
|
||||
}
|
||||
if let Some(v) = read_env_value::<f64>(&with_prefix("temperature")) {
|
||||
self.temperature = v;
|
||||
}
|
||||
if let Some(v) = read_env_value::<f64>(&with_prefix("top_p")) {
|
||||
self.top_p = v;
|
||||
}
|
||||
if let Some(v) = read_env_value::<String>(&with_prefix("agent_prelude")) {
|
||||
self.agent_prelude = v;
|
||||
}
|
||||
if let Ok(v) = env::var(with_prefix("variables")) {
|
||||
if let Ok(v) = serde_json::from_str(&v) {
|
||||
self.variables = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn banner(&self) -> String {
|
||||
let AgentConfig {
|
||||
name,
|
||||
description,
|
||||
version,
|
||||
conversation_starters,
|
||||
..
|
||||
} = self;
|
||||
let starters = if conversation_starters.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let starters = conversation_starters
|
||||
.iter()
|
||||
.map(|v| format!("- {v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
format!(
|
||||
r#"
|
||||
|
||||
## Conversation Starters
|
||||
{starters}"#
|
||||
)
|
||||
};
|
||||
format!(
|
||||
r#"# {name} {version}
|
||||
{description}{starters}"#
|
||||
)
|
||||
}
|
||||
|
||||
fn replace_tools_placeholder(&mut self, functions: &Functions) {
|
||||
let tools_placeholder: &str = "{{__tools__}}";
|
||||
if self.instructions.contains(tools_placeholder) {
|
||||
let tools = functions
|
||||
.declarations()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| {
|
||||
let description = match v.description.split_once('\n') {
|
||||
Some((v, _)) => v,
|
||||
None => &v.description,
|
||||
};
|
||||
format!("{}. {}: {description}", i + 1, v.name)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
self.instructions = self.instructions.replace(tools_placeholder, &tools);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct AgentVariable {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
#[serde(skip_deserializing, default)]
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
pub fn list_agents() -> Vec<String> {
|
||||
let agents_file = Config::config_dir().join("agents.txt");
|
||||
let contents = match read_to_string(agents_file) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return vec![],
|
||||
};
|
||||
contents
|
||||
.split('\n')
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
None
|
||||
} else {
|
||||
Some(line.to_string())
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn complete_agent_variables(agent_name: &str) -> Vec<(String, Option<String>)> {
|
||||
let config_path = Config::agent_config_file(agent_name);
|
||||
if !config_path.exists() {
|
||||
return vec![];
|
||||
}
|
||||
let Ok(config) = AgentConfig::load(&config_path) else {
|
||||
return vec![];
|
||||
};
|
||||
config
|
||||
.variables
|
||||
.iter()
|
||||
.map(|v| {
|
||||
let description = match &v.default {
|
||||
Some(default) => format!("{} [default: {default}]", v.description),
|
||||
None => v.description.clone(),
|
||||
};
|
||||
(format!("{}=", v.name), Some(description))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -0,0 +1,545 @@
|
||||
use super::*;
|
||||
|
||||
use crate::client::{
|
||||
init_client, patch_messages, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
|
||||
MessageContentPart, MessageContentToolCalls, MessageRole, Model,
|
||||
};
|
||||
use crate::function::ToolResult;
|
||||
use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use indexmap::IndexSet;
|
||||
use std::{collections::HashMap, fs::File, io::Read};
|
||||
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
|
||||
|
||||
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
|
||||
const SUMMARY_MAX_WIDTH: usize = 80;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Input {
|
||||
config: GlobalConfig,
|
||||
text: String,
|
||||
raw: (String, Vec<String>),
|
||||
patched_text: Option<String>,
|
||||
last_reply: Option<String>,
|
||||
continue_output: Option<String>,
|
||||
regenerate: bool,
|
||||
medias: Vec<String>,
|
||||
data_urls: HashMap<String, String>,
|
||||
tool_calls: Option<MessageContentToolCalls>,
|
||||
role: Role,
|
||||
rag_name: Option<String>,
|
||||
with_session: bool,
|
||||
with_agent: bool,
|
||||
}
|
||||
|
||||
impl Input {
|
||||
pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
Self {
|
||||
config: config.clone(),
|
||||
text: text.to_string(),
|
||||
raw: (text.to_string(), vec![]),
|
||||
patched_text: None,
|
||||
last_reply: None,
|
||||
continue_output: None,
|
||||
regenerate: false,
|
||||
medias: Default::default(),
|
||||
data_urls: Default::default(),
|
||||
tool_calls: None,
|
||||
role,
|
||||
rag_name: None,
|
||||
with_session,
|
||||
with_agent,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_files(
|
||||
config: &GlobalConfig,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
) -> Result<Self> {
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
|
||||
resolve_paths(&loaders, paths)?;
|
||||
let mut last_reply = None;
|
||||
let (documents, medias, data_urls) = load_documents(
|
||||
&loaders,
|
||||
local_paths,
|
||||
remote_urls,
|
||||
external_cmds,
|
||||
protocol_paths,
|
||||
)
|
||||
.await
|
||||
.context("Failed to load files")?;
|
||||
let mut texts = vec![];
|
||||
if !raw_text.is_empty() {
|
||||
texts.push(raw_text.to_string());
|
||||
};
|
||||
if with_last_reply {
|
||||
if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
|
||||
if !output.is_empty() {
|
||||
last_reply = Some(output.clone())
|
||||
} else if let Some(v) = input.last_reply.as_ref() {
|
||||
last_reply = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = last_reply.clone() {
|
||||
texts.push(format!("\n{v}"));
|
||||
}
|
||||
}
|
||||
if last_reply.is_none() && documents.is_empty() && medias.is_empty() {
|
||||
bail!("No last reply found");
|
||||
}
|
||||
}
|
||||
let documents_len = documents.len();
|
||||
for (kind, path, contents) in documents {
|
||||
if documents_len == 1 && raw_text.is_empty() {
|
||||
texts.push(format!("\n{contents}"));
|
||||
} else {
|
||||
texts.push(format!(
|
||||
"\n============ {kind}: {path} ============\n{contents}"
|
||||
));
|
||||
}
|
||||
}
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
Ok(Self {
|
||||
config: config.clone(),
|
||||
text: texts.join("\n"),
|
||||
raw: (raw_text.to_string(), raw_paths),
|
||||
patched_text: None,
|
||||
last_reply,
|
||||
continue_output: None,
|
||||
regenerate: false,
|
||||
medias,
|
||||
data_urls,
|
||||
tool_calls: Default::default(),
|
||||
role,
|
||||
rag_name: None,
|
||||
with_session,
|
||||
with_agent,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn from_files_with_spinner(
|
||||
config: &GlobalConfig,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
abortable_run_with_spinner(
|
||||
Input::from_files(config, raw_text, paths, role),
|
||||
"Loading files",
|
||||
abort_signal,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.text.is_empty() && self.medias.is_empty()
|
||||
}
|
||||
|
||||
pub fn data_urls(&self) -> HashMap<String, String> {
|
||||
self.data_urls.clone()
|
||||
}
|
||||
|
||||
pub fn tool_calls(&self) -> &Option<MessageContentToolCalls> {
|
||||
&self.tool_calls
|
||||
}
|
||||
|
||||
pub fn text(&self) -> String {
|
||||
match self.patched_text.clone() {
|
||||
Some(text) => text,
|
||||
None => self.text.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_patch(&mut self) {
|
||||
self.patched_text = None;
|
||||
}
|
||||
|
||||
pub fn set_text(&mut self, text: String) {
|
||||
self.text = text;
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> bool {
|
||||
self.config.read().stream && !self.role().model().no_stream()
|
||||
}
|
||||
|
||||
pub fn continue_output(&self) -> Option<&str> {
|
||||
self.continue_output.as_deref()
|
||||
}
|
||||
|
||||
pub fn set_continue_output(&mut self, output: &str) {
|
||||
let output = match &self.continue_output {
|
||||
Some(v) => format!("{v}{output}"),
|
||||
None => output.to_string(),
|
||||
};
|
||||
self.continue_output = Some(output);
|
||||
}
|
||||
|
||||
pub fn regenerate(&self) -> bool {
|
||||
self.regenerate
|
||||
}
|
||||
|
||||
pub fn set_regenerate(&mut self) {
|
||||
let role = self.config.read().extract_role();
|
||||
if role.name() == self.role().name() {
|
||||
self.role = role;
|
||||
}
|
||||
self.regenerate = true;
|
||||
self.tool_calls = None;
|
||||
}
|
||||
|
||||
pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
|
||||
if self.text.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let rag = self.config.read().rag.clone();
|
||||
if let Some(rag) = rag {
|
||||
let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
|
||||
self.patched_text = Some(result);
|
||||
self.rag_name = Some(rag.name().to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn rag_name(&self) -> Option<&str> {
|
||||
self.rag_name.as_deref()
|
||||
}
|
||||
|
||||
pub fn merge_tool_results(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
|
||||
match self.tool_calls.as_mut() {
|
||||
Some(exist_tool_results) => {
|
||||
exist_tool_results.merge(tool_results, output);
|
||||
}
|
||||
None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)),
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn create_client(&self) -> Result<Box<dyn Client>> {
|
||||
init_client(&self.config, Some(self.role().model().clone()))
|
||||
}
|
||||
|
||||
pub async fn fetch_chat_text(&self) -> Result<String> {
|
||||
let client = self.create_client()?;
|
||||
let text = client.chat_completions(self.clone()).await?.text;
|
||||
let text = strip_think_tag(&text).to_string();
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
pub fn prepare_completion_data(
|
||||
&self,
|
||||
model: &Model,
|
||||
stream: bool,
|
||||
) -> Result<ChatCompletionsData> {
|
||||
let mut messages = self.build_messages()?;
|
||||
patch_messages(&mut messages, model);
|
||||
model.guard_max_input_tokens(&messages)?;
|
||||
let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
|
||||
let functions = self.config.read().select_functions(self.role());
|
||||
if let Some(vec) = &functions {
|
||||
for def in vec {
|
||||
debug!("Function definition: {:?}", def.name);
|
||||
}
|
||||
}
|
||||
Ok(ChatCompletionsData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
functions,
|
||||
stream,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn build_messages(&self) -> Result<Vec<Message>> {
|
||||
let mut messages = if let Some(session) = self.session(&self.config.read().session) {
|
||||
session.build_messages(self)
|
||||
} else {
|
||||
self.role().build_messages(self)
|
||||
};
|
||||
if let Some(tool_calls) = &self.tool_calls {
|
||||
messages.push(Message::new(
|
||||
MessageRole::Assistant,
|
||||
MessageContent::ToolCalls(tool_calls.clone()),
|
||||
))
|
||||
}
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self) -> String {
|
||||
if let Some(session) = self.session(&self.config.read().session) {
|
||||
session.echo_messages(self)
|
||||
} else {
|
||||
self.role().echo_messages(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn role(&self) -> &Role {
|
||||
&self.role
|
||||
}
|
||||
|
||||
pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
|
||||
if self.with_session {
|
||||
session.as_ref()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
|
||||
if self.with_session {
|
||||
session.as_mut()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_agent(&self) -> bool {
|
||||
self.with_agent
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> String {
|
||||
let text: String = self
|
||||
.text
|
||||
.trim()
|
||||
.chars()
|
||||
.map(|c| if c.is_control() { ' ' } else { c })
|
||||
.collect();
|
||||
if text.width_cjk() > SUMMARY_MAX_WIDTH {
|
||||
let mut sum_width = 0;
|
||||
let mut chars = vec![];
|
||||
for c in text.chars() {
|
||||
sum_width += c.width_cjk().unwrap_or(1);
|
||||
if sum_width > SUMMARY_MAX_WIDTH - 3 {
|
||||
chars.extend(['.', '.', '.']);
|
||||
break;
|
||||
}
|
||||
chars.push(c);
|
||||
}
|
||||
chars.into_iter().collect()
|
||||
} else {
|
||||
text
|
||||
}
|
||||
}
|
||||
|
||||
pub fn raw(&self) -> String {
|
||||
let (text, files) = &self.raw;
|
||||
let mut segments = files.to_vec();
|
||||
if !segments.is_empty() {
|
||||
segments.insert(0, ".file".into());
|
||||
}
|
||||
if !text.is_empty() {
|
||||
if !segments.is_empty() {
|
||||
segments.push("--".into());
|
||||
}
|
||||
segments.push(text.clone());
|
||||
}
|
||||
segments.join(" ")
|
||||
}
|
||||
|
||||
pub fn render(&self) -> String {
|
||||
let text = self.text();
|
||||
if self.medias.is_empty() {
|
||||
return text;
|
||||
}
|
||||
let tail_text = if text.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" -- {text}")
|
||||
};
|
||||
let files: Vec<String> = self
|
||||
.medias
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|url| resolve_data_url(&self.data_urls, url))
|
||||
.collect();
|
||||
format!(".file {}{}", files.join(" "), tail_text)
|
||||
}
|
||||
|
||||
pub fn message_content(&self) -> MessageContent {
|
||||
if self.medias.is_empty() {
|
||||
MessageContent::Text(self.text())
|
||||
} else {
|
||||
let mut list: Vec<MessageContentPart> = self
|
||||
.medias
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|url| MessageContentPart::ImageUrl {
|
||||
image_url: ImageUrl { url },
|
||||
})
|
||||
.collect();
|
||||
if !self.text.is_empty() {
|
||||
list.insert(0, MessageContentPart::Text { text: self.text() });
|
||||
}
|
||||
MessageContent::Array(list)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
|
||||
match role {
|
||||
Some(v) => (v, false, false),
|
||||
None => (
|
||||
config.extract_role(),
|
||||
config.session.is_some(),
|
||||
config.agent.is_some(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
type ResolvePathsOutput = (
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
bool,
|
||||
);
|
||||
|
||||
fn resolve_paths(
|
||||
loaders: &HashMap<String, String>,
|
||||
paths: Vec<String>,
|
||||
) -> Result<ResolvePathsOutput> {
|
||||
let mut raw_paths = IndexSet::new();
|
||||
let mut local_paths = IndexSet::new();
|
||||
let mut remote_urls = IndexSet::new();
|
||||
let mut external_cmds = IndexSet::new();
|
||||
let mut protocol_paths = IndexSet::new();
|
||||
let mut with_last_reply = false;
|
||||
for path in paths {
|
||||
if path == "%%" {
|
||||
with_last_reply = true;
|
||||
raw_paths.insert(path);
|
||||
} else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') {
|
||||
external_cmds.insert(path[1..path.len() - 1].to_string());
|
||||
raw_paths.insert(path);
|
||||
} else if is_url(&path) {
|
||||
if path.strip_suffix("**").is_some() {
|
||||
bail!("Invalid website '{path}'");
|
||||
}
|
||||
remote_urls.insert(path.clone());
|
||||
raw_paths.insert(path);
|
||||
} else if is_loader_protocol(loaders, &path) {
|
||||
protocol_paths.insert(path.clone());
|
||||
raw_paths.insert(path);
|
||||
} else {
|
||||
let resolved_path = resolve_home_dir(&path);
|
||||
let absolute_path = to_absolute_path(&resolved_path)
|
||||
.with_context(|| format!("Invalid path '{path}'"))?;
|
||||
local_paths.insert(resolved_path);
|
||||
raw_paths.insert(absolute_path);
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
raw_paths.into_iter().collect(),
|
||||
local_paths.into_iter().collect(),
|
||||
remote_urls.into_iter().collect(),
|
||||
external_cmds.into_iter().collect(),
|
||||
protocol_paths.into_iter().collect(),
|
||||
with_last_reply,
|
||||
))
|
||||
}
|
||||
|
||||
async fn load_documents(
|
||||
loaders: &HashMap<String, String>,
|
||||
local_paths: Vec<String>,
|
||||
remote_urls: Vec<String>,
|
||||
external_cmds: Vec<String>,
|
||||
protocol_paths: Vec<String>,
|
||||
) -> Result<(
|
||||
Vec<(&'static str, String, String)>,
|
||||
Vec<String>,
|
||||
HashMap<String, String>,
|
||||
)> {
|
||||
let mut files = vec![];
|
||||
let mut medias = vec![];
|
||||
let mut data_urls = HashMap::new();
|
||||
|
||||
for cmd in external_cmds {
|
||||
let output = duct::cmd(&SHELL.cmd, &[&SHELL.arg, &cmd])
|
||||
.stderr_to_stdout()
|
||||
.unchecked()
|
||||
.read()
|
||||
.unwrap_or_else(|err| err.to_string());
|
||||
files.push(("CMD", cmd, output));
|
||||
}
|
||||
|
||||
let local_files = expand_glob_paths(&local_paths, true).await?;
|
||||
for file_path in local_files {
|
||||
if is_image(&file_path) {
|
||||
let contents = read_media_to_data_url(&file_path)
|
||||
.with_context(|| format!("Unable to read media '{file_path}'"))?;
|
||||
data_urls.insert(sha256(&contents), file_path);
|
||||
medias.push(contents)
|
||||
} else {
|
||||
let document = load_file(loaders, &file_path)
|
||||
.await
|
||||
.with_context(|| format!("Unable to read file '{file_path}'"))?;
|
||||
files.push(("FILE", file_path, document.contents));
|
||||
}
|
||||
}
|
||||
|
||||
for file_url in remote_urls {
|
||||
let (contents, extension) = fetch_with_loaders(loaders, &file_url, true)
|
||||
.await
|
||||
.with_context(|| format!("Failed to load url '{file_url}'"))?;
|
||||
if extension == MEDIA_URL_EXTENSION {
|
||||
data_urls.insert(sha256(&contents), file_url);
|
||||
medias.push(contents)
|
||||
} else {
|
||||
files.push(("URL", file_url, contents));
|
||||
}
|
||||
}
|
||||
|
||||
for protocol_path in protocol_paths {
|
||||
let documents = load_protocol_path(loaders, &protocol_path)
|
||||
.with_context(|| format!("Failed to load from '{protocol_path}'"))?;
|
||||
files.extend(
|
||||
documents
|
||||
.into_iter()
|
||||
.map(|document| ("FROM", document.path, document.contents)),
|
||||
);
|
||||
}
|
||||
|
||||
Ok((files, medias, data_urls))
|
||||
}
|
||||
|
||||
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
|
||||
if data_url.starts_with("data:") {
|
||||
let hash = sha256(&data_url);
|
||||
if let Some(path) = data_urls.get(&hash) {
|
||||
return path.to_string();
|
||||
}
|
||||
data_url
|
||||
} else {
|
||||
data_url
|
||||
}
|
||||
}
|
||||
|
||||
fn is_image(path: &str) -> bool {
|
||||
get_patch_extension(path)
|
||||
.map(|v| IMAGE_EXTS.contains(&v.as_str()))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn read_media_to_data_url(image_path: &str) -> Result<String> {
|
||||
let extension = get_patch_extension(image_path).unwrap_or_default();
|
||||
let mime_type = match extension.as_str() {
|
||||
"png" => "image/png",
|
||||
"jpg" | "jpeg" => "image/jpeg",
|
||||
"webp" => "image/webp",
|
||||
"gif" => "image/gif",
|
||||
_ => bail!("Unexpected media type"),
|
||||
};
|
||||
let mut file = File::open(image_path)?;
|
||||
let mut buffer = Vec::new();
|
||||
file.read_to_end(&mut buffer)?;
|
||||
|
||||
let encoded_image = base64_encode(buffer);
|
||||
let data_url = format!("data:{mime_type};base64,{encoded_image}");
|
||||
|
||||
Ok(data_url)
|
||||
}
|
||||
+3034
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,416 @@
|
||||
use super::*;
|
||||
|
||||
use crate::client::{Message, MessageContent, MessageRole, Model};
|
||||
|
||||
use anyhow::Result;
|
||||
use fancy_regex::Regex;
|
||||
use rust_embed::Embed;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
pub const SHELL_ROLE: &str = "shell";
|
||||
pub const EXPLAIN_SHELL_ROLE: &str = "explain-shell";
|
||||
pub const CODE_ROLE: &str = "code";
|
||||
pub const CREATE_TITLE_ROLE: &str = "create-title";
|
||||
|
||||
pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "assets/roles/"]
|
||||
struct RolesAsset;
|
||||
|
||||
static RE_METADATA: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"(?s)-{3,}\s*(.*?)\s*-{3,}\s*(.*)").unwrap());
|
||||
|
||||
pub trait RoleLike {
|
||||
fn to_role(&self) -> Role;
|
||||
fn model(&self) -> &Model;
|
||||
fn temperature(&self) -> Option<f64>;
|
||||
fn top_p(&self) -> Option<f64>;
|
||||
fn use_tools(&self) -> Option<String>;
|
||||
fn use_mcp_servers(&self) -> Option<String>;
|
||||
fn set_model(&mut self, model: Model);
|
||||
fn set_temperature(&mut self, value: Option<f64>);
|
||||
fn set_top_p(&mut self, value: Option<f64>);
|
||||
fn set_use_tools(&mut self, value: Option<String>);
|
||||
fn set_use_mcp_servers(&mut self, value: Option<String>);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct Role {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
prompt: String,
|
||||
#[serde(
|
||||
rename(serialize = "model", deserialize = "model"),
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
model_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
use_tools: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
use_mcp_servers: Option<String>,
|
||||
|
||||
#[serde(skip)]
|
||||
model: Model,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn new(name: &str, content: &str) -> Self {
|
||||
let mut metadata = "";
|
||||
let mut prompt = content.trim();
|
||||
if let Ok(Some(caps)) = RE_METADATA.captures(content) {
|
||||
if let (Some(metadata_value), Some(prompt_value)) = (caps.get(1), caps.get(2)) {
|
||||
metadata = metadata_value.as_str().trim();
|
||||
prompt = prompt_value.as_str().trim();
|
||||
}
|
||||
}
|
||||
let mut prompt = prompt.to_string();
|
||||
interpolate_variables(&mut prompt);
|
||||
let mut role = Self {
|
||||
name: name.to_string(),
|
||||
prompt,
|
||||
..Default::default()
|
||||
};
|
||||
if !metadata.is_empty() {
|
||||
if let Ok(value) = serde_yaml::from_str::<Value>(metadata) {
|
||||
if let Some(value) = value.as_object() {
|
||||
for (key, value) in value {
|
||||
match key.as_str() {
|
||||
"model" => role.model_id = value.as_str().map(|v| v.to_string()),
|
||||
"temperature" => role.temperature = value.as_f64(),
|
||||
"top_p" => role.top_p = value.as_f64(),
|
||||
"use_tools" => role.use_tools = value.as_str().map(|v| v.to_string()),
|
||||
"use_mcp_servers" => {
|
||||
role.use_mcp_servers = value.as_str().map(|v| v.to_string())
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
role
|
||||
}
|
||||
|
||||
pub fn builtin(name: &str) -> Result<Self> {
|
||||
let content = RolesAsset::get(&format!("{name}.md"))
|
||||
.ok_or_else(|| anyhow!("Unknown role `{name}`"))?;
|
||||
let content = unsafe { std::str::from_utf8_unchecked(&content.data) };
|
||||
Ok(Role::new(name, content))
|
||||
}
|
||||
|
||||
pub fn list_builtin_role_names() -> Vec<String> {
|
||||
RolesAsset::iter()
|
||||
.filter_map(|v| v.strip_suffix(".md").map(|v| v.to_string()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn list_builtin_roles() -> Vec<Self> {
|
||||
RolesAsset::iter()
|
||||
.filter_map(|v| Role::builtin(&v).ok())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn has_args(&self) -> bool {
|
||||
self.name.contains('#')
|
||||
}
|
||||
|
||||
pub fn export(&self) -> String {
|
||||
let mut metadata = vec![];
|
||||
if let Some(model) = self.model_id() {
|
||||
metadata.push(format!("model: {model}"));
|
||||
}
|
||||
if let Some(temperature) = self.temperature() {
|
||||
metadata.push(format!("temperature: {temperature}"));
|
||||
}
|
||||
if let Some(top_p) = self.top_p() {
|
||||
metadata.push(format!("top_p: {top_p}"));
|
||||
}
|
||||
if let Some(use_tools) = self.use_tools() {
|
||||
metadata.push(format!("use_tools: {use_tools}"));
|
||||
}
|
||||
if let Some(use_mcp_servers) = self.use_mcp_servers() {
|
||||
metadata.push(format!("use_mcp_servers: {use_mcp_servers}"));
|
||||
}
|
||||
if metadata.is_empty() {
|
||||
format!("{}\n", self.prompt)
|
||||
} else if self.prompt.is_empty() {
|
||||
format!("---\n{}\n---\n", metadata.join("\n"))
|
||||
} else {
|
||||
format!("---\n{}\n---\n\n{}\n", metadata.join("\n"), self.prompt)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save(&mut self, role_name: &str, role_path: &Path, is_repl: bool) -> Result<()> {
|
||||
ensure_parent_exists(role_path)?;
|
||||
|
||||
let content = self.export();
|
||||
std::fs::write(role_path, content).with_context(|| {
|
||||
format!(
|
||||
"Failed to write role {} to {}",
|
||||
self.name,
|
||||
role_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if is_repl {
|
||||
println!("✓ Saved role to '{}'.", role_path.display());
|
||||
}
|
||||
|
||||
if role_name != self.name {
|
||||
self.name = role_name.to_string();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn sync<T: RoleLike>(&mut self, role_like: &T) {
|
||||
let model = role_like.model();
|
||||
let temperature = role_like.temperature();
|
||||
let top_p = role_like.top_p();
|
||||
let use_tools = role_like.use_tools();
|
||||
let use_mcp_servers = role_like.use_mcp_servers();
|
||||
self.batch_set(model, temperature, top_p, use_tools, use_mcp_servers);
|
||||
}
|
||||
|
||||
pub fn batch_set(
|
||||
&mut self,
|
||||
model: &Model,
|
||||
temperature: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
use_tools: Option<String>,
|
||||
use_mcp_servers: Option<String>,
|
||||
) {
|
||||
self.set_model(model.clone());
|
||||
if temperature.is_some() {
|
||||
self.set_temperature(temperature);
|
||||
}
|
||||
if top_p.is_some() {
|
||||
self.set_top_p(top_p);
|
||||
}
|
||||
if use_tools.is_some() {
|
||||
self.set_use_tools(use_tools);
|
||||
}
|
||||
if use_mcp_servers.is_some() {
|
||||
self.set_use_mcp_servers(use_mcp_servers);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_derived(&self) -> bool {
|
||||
self.name.is_empty()
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn model_id(&self) -> Option<&str> {
|
||||
self.model_id.as_deref()
|
||||
}
|
||||
|
||||
pub fn prompt(&self) -> &str {
|
||||
&self.prompt
|
||||
}
|
||||
|
||||
pub fn is_empty_prompt(&self) -> bool {
|
||||
self.prompt.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_embedded_prompt(&self) -> bool {
|
||||
self.prompt.contains(INPUT_PLACEHOLDER)
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self, input: &Input) -> String {
|
||||
let input_markdown = input.render();
|
||||
if self.is_empty_prompt() {
|
||||
input_markdown
|
||||
} else if self.is_embedded_prompt() {
|
||||
self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown)
|
||||
} else {
|
||||
format!("{}\n\n{}", self.prompt, input_markdown)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
|
||||
let mut content = input.message_content();
|
||||
let mut messages = if self.is_empty_prompt() {
|
||||
vec![Message::new(MessageRole::User, content)]
|
||||
} else if self.is_embedded_prompt() {
|
||||
content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v));
|
||||
vec![Message::new(MessageRole::User, content)]
|
||||
} else {
|
||||
let mut messages = vec![];
|
||||
let (system, cases) = parse_structure_prompt(&self.prompt);
|
||||
if !system.is_empty() {
|
||||
messages.push(Message::new(
|
||||
MessageRole::System,
|
||||
MessageContent::Text(system.to_string()),
|
||||
));
|
||||
}
|
||||
if !cases.is_empty() {
|
||||
messages.extend(cases.into_iter().flat_map(|(i, o)| {
|
||||
vec![
|
||||
Message::new(MessageRole::User, MessageContent::Text(i.to_string())),
|
||||
Message::new(MessageRole::Assistant, MessageContent::Text(o.to_string())),
|
||||
]
|
||||
}));
|
||||
}
|
||||
messages.push(Message::new(MessageRole::User, content));
|
||||
messages
|
||||
};
|
||||
if let Some(text) = input.continue_output() {
|
||||
messages.push(Message::new(
|
||||
MessageRole::Assistant,
|
||||
MessageContent::Text(text.into()),
|
||||
));
|
||||
}
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
impl RoleLike for Role {
|
||||
fn to_role(&self) -> Role {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn model(&self) -> &Model {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn temperature(&self) -> Option<f64> {
|
||||
self.temperature
|
||||
}
|
||||
|
||||
fn top_p(&self) -> Option<f64> {
|
||||
self.top_p
|
||||
}
|
||||
|
||||
fn use_tools(&self) -> Option<String> {
|
||||
self.use_tools.clone()
|
||||
}
|
||||
|
||||
fn use_mcp_servers(&self) -> Option<String> {
|
||||
self.use_mcp_servers.clone()
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: Model) {
|
||||
if !self.model().id().is_empty() {
|
||||
self.model_id = Some(model.id().to_string());
|
||||
}
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
fn set_temperature(&mut self, value: Option<f64>) {
|
||||
self.temperature = value;
|
||||
}
|
||||
|
||||
fn set_top_p(&mut self, value: Option<f64>) {
|
||||
self.top_p = value;
|
||||
}
|
||||
|
||||
fn set_use_tools(&mut self, value: Option<String>) {
|
||||
self.use_tools = value;
|
||||
}
|
||||
|
||||
fn set_use_mcp_servers(&mut self, value: Option<String>) {
|
||||
self.use_mcp_servers = value;
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
|
||||
let mut text = prompt;
|
||||
let mut search_input = true;
|
||||
let mut system = None;
|
||||
let mut parts = vec![];
|
||||
loop {
|
||||
let search = if search_input {
|
||||
"### INPUT:"
|
||||
} else {
|
||||
"### OUTPUT:"
|
||||
};
|
||||
match text.find(search) {
|
||||
Some(idx) => {
|
||||
if system.is_none() {
|
||||
system = Some(&text[..idx])
|
||||
} else {
|
||||
parts.push(&text[..idx])
|
||||
}
|
||||
search_input = !search_input;
|
||||
text = &text[(idx + search.len())..];
|
||||
}
|
||||
None => {
|
||||
if !text.is_empty() {
|
||||
if system.is_none() {
|
||||
system = Some(text)
|
||||
} else {
|
||||
parts.push(text)
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let parts_len = parts.len();
|
||||
if parts_len > 0 && parts_len % 2 == 0 {
|
||||
let cases: Vec<(&str, &str)> = parts
|
||||
.iter()
|
||||
.step_by(2)
|
||||
.zip(parts.iter().skip(1).step_by(2))
|
||||
.map(|(i, o)| (i.trim(), o.trim()))
|
||||
.collect();
|
||||
let system = system.map(|v| v.trim()).unwrap_or_default();
|
||||
return (system, cases);
|
||||
}
|
||||
|
||||
(prompt, vec![])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_structure_prompt1() {
|
||||
let prompt = r#"
|
||||
System message
|
||||
### INPUT:
|
||||
Input 1
|
||||
### OUTPUT:
|
||||
Output 1
|
||||
"#;
|
||||
assert_eq!(
|
||||
parse_structure_prompt(prompt),
|
||||
("System message", vec![("Input 1", "Output 1")])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_structure_prompt2() {
|
||||
let prompt = r#"
|
||||
### INPUT:
|
||||
Input 1
|
||||
### OUTPUT:
|
||||
Output 1
|
||||
"#;
|
||||
assert_eq!(
|
||||
parse_structure_prompt(prompt),
|
||||
("", vec![("Input 1", "Output 1")])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_structure_prompt3() {
|
||||
let prompt = r#"
|
||||
System message
|
||||
### INPUT:
|
||||
Input 1
|
||||
"#;
|
||||
assert_eq!(parse_structure_prompt(prompt), (prompt, vec![]));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,659 @@
|
||||
use super::input::*;
|
||||
use super::*;
|
||||
|
||||
use crate::client::{Message, MessageContent, MessageRole};
|
||||
use crate::render::MarkdownRender;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use fancy_regex::Regex;
|
||||
use inquire::{validator::Validation, Confirm, Text};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::{read_to_string, write};
|
||||
use std::path::Path;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
static RE_AUTONAME_PREFIX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\d{8}T\d{6}-").unwrap());
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct Session {
|
||||
#[serde(rename(serialize = "model", deserialize = "model"))]
|
||||
model_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
use_tools: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
use_mcp_servers: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
save_session: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
compress_threshold: Option<usize>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
role_name: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "IndexMap::is_empty")]
|
||||
agent_variables: AgentVariables,
|
||||
#[serde(default, skip_serializing_if = "String::is_empty")]
|
||||
agent_instructions: String,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
compressed_messages: Vec<Message>,
|
||||
messages: Vec<Message>,
|
||||
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
|
||||
data_urls: HashMap<String, String>,
|
||||
|
||||
#[serde(skip)]
|
||||
model: Model,
|
||||
#[serde(skip)]
|
||||
role_prompt: String,
|
||||
#[serde(skip)]
|
||||
name: String,
|
||||
#[serde(skip)]
|
||||
path: Option<String>,
|
||||
#[serde(skip)]
|
||||
dirty: bool,
|
||||
#[serde(skip)]
|
||||
save_session_this_time: bool,
|
||||
#[serde(skip)]
|
||||
compressing: bool,
|
||||
#[serde(skip)]
|
||||
autoname: Option<AutoName>,
|
||||
#[serde(skip)]
|
||||
tokens: usize,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(config: &Config, name: &str) -> Self {
|
||||
let role = config.extract_role();
|
||||
let mut session = Self {
|
||||
name: name.to_string(),
|
||||
save_session: config.save_session,
|
||||
..Default::default()
|
||||
};
|
||||
session.set_role(role);
|
||||
session.dirty = false;
|
||||
session
|
||||
}
|
||||
|
||||
pub fn load(config: &Config, name: &str, path: &Path) -> Result<Self> {
|
||||
let content = read_to_string(path)
|
||||
.with_context(|| format!("Failed to load session {} at {}", name, path.display()))?;
|
||||
let mut session: Self =
|
||||
serde_yaml::from_str(&content).with_context(|| format!("Invalid session {name}"))?;
|
||||
|
||||
session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?;
|
||||
|
||||
if let Some(autoname) = name.strip_prefix("_/") {
|
||||
session.name = TEMP_SESSION_NAME.to_string();
|
||||
session.path = None;
|
||||
if let Ok(true) = RE_AUTONAME_PREFIX.is_match(autoname) {
|
||||
session.autoname = Some(AutoName::new(autoname[16..].to_string()));
|
||||
}
|
||||
} else {
|
||||
session.name = name.to_string();
|
||||
session.path = Some(path.display().to_string());
|
||||
}
|
||||
|
||||
if let Some(role_name) = &session.role_name {
|
||||
if let Ok(role) = config.retrieve_role(role_name) {
|
||||
session.role_prompt = role.prompt().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
session.update_tokens();
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.messages.is_empty() && self.compressed_messages.is_empty()
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn role_name(&self) -> Option<&str> {
|
||||
self.role_name.as_deref()
|
||||
}
|
||||
|
||||
pub fn dirty(&self) -> bool {
|
||||
self.dirty
|
||||
}
|
||||
|
||||
pub fn save_session(&self) -> Option<bool> {
|
||||
self.save_session
|
||||
}
|
||||
|
||||
pub fn tokens(&self) -> usize {
|
||||
self.tokens
|
||||
}
|
||||
|
||||
pub fn update_tokens(&mut self) {
|
||||
self.tokens = self.model().total_tokens(&self.messages);
|
||||
}
|
||||
|
||||
pub fn has_user_messages(&self) -> bool {
|
||||
self.messages.iter().any(|v| v.role.is_user())
|
||||
}
|
||||
|
||||
pub fn user_messages_len(&self) -> usize {
|
||||
self.messages.iter().filter(|v| v.role.is_user()).count()
|
||||
}
|
||||
|
||||
pub fn export(&self) -> Result<String> {
|
||||
let mut data = json!({
|
||||
"path": self.path,
|
||||
"model": self.model().id(),
|
||||
});
|
||||
if let Some(temperature) = self.temperature() {
|
||||
data["temperature"] = temperature.into();
|
||||
}
|
||||
if let Some(top_p) = self.top_p() {
|
||||
data["top_p"] = top_p.into();
|
||||
}
|
||||
if let Some(use_tools) = self.use_tools() {
|
||||
data["use_tools"] = use_tools.into();
|
||||
}
|
||||
if let Some(use_mcp_servers) = self.use_mcp_servers() {
|
||||
data["use_mcp_servers"] = use_mcp_servers.into();
|
||||
}
|
||||
if let Some(save_session) = self.save_session() {
|
||||
data["save_session"] = save_session.into();
|
||||
}
|
||||
let (tokens, percent) = self.tokens_usage();
|
||||
data["total_tokens"] = tokens.into();
|
||||
if let Some(max_input_tokens) = self.model().max_input_tokens() {
|
||||
data["max_input_tokens"] = max_input_tokens.into();
|
||||
}
|
||||
if percent != 0.0 {
|
||||
data["total/max"] = format!("{percent}%").into();
|
||||
}
|
||||
data["messages"] = json!(self.messages);
|
||||
|
||||
let output = serde_yaml::to_string(&data)
|
||||
.with_context(|| format!("Unable to show info about session '{}'", &self.name))?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn render(
|
||||
&self,
|
||||
render: &mut MarkdownRender,
|
||||
agent_info: &Option<(String, Vec<String>)>,
|
||||
) -> Result<String> {
|
||||
let mut items = vec![];
|
||||
|
||||
if let Some(path) = &self.path {
|
||||
items.push(("path", path.to_string()));
|
||||
}
|
||||
|
||||
if let Some(autoname) = self.autoname() {
|
||||
items.push(("autoname", autoname.to_string()));
|
||||
}
|
||||
|
||||
items.push(("model", self.model().id()));
|
||||
|
||||
if let Some(temperature) = self.temperature() {
|
||||
items.push(("temperature", temperature.to_string()));
|
||||
}
|
||||
if let Some(top_p) = self.top_p() {
|
||||
items.push(("top_p", top_p.to_string()));
|
||||
}
|
||||
|
||||
if let Some(use_tools) = self.use_tools() {
|
||||
items.push(("use_tools", use_tools));
|
||||
}
|
||||
|
||||
if let Some(use_mcp_servers) = self.use_mcp_servers() {
|
||||
items.push(("use_mcp_servers", use_mcp_servers));
|
||||
}
|
||||
|
||||
if let Some(save_session) = self.save_session() {
|
||||
items.push(("save_session", save_session.to_string()));
|
||||
}
|
||||
|
||||
if let Some(compress_threshold) = self.compress_threshold {
|
||||
items.push(("compress_threshold", compress_threshold.to_string()));
|
||||
}
|
||||
|
||||
if let Some(max_input_tokens) = self.model().max_input_tokens() {
|
||||
items.push(("max_input_tokens", max_input_tokens.to_string()));
|
||||
}
|
||||
|
||||
let mut lines: Vec<String> = items
|
||||
.iter()
|
||||
.map(|(name, value)| format!("{name:<20}{value}"))
|
||||
.collect();
|
||||
|
||||
lines.push(String::new());
|
||||
|
||||
if !self.is_empty() {
|
||||
let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string());
|
||||
|
||||
for message in &self.messages {
|
||||
match message.role {
|
||||
MessageRole::System => {
|
||||
lines.push(
|
||||
render
|
||||
.render(&message.content.render_input(resolve_url_fn, agent_info)),
|
||||
);
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
if let MessageContent::Text(text) = &message.content {
|
||||
lines.push(render.render(text));
|
||||
}
|
||||
lines.push("".into());
|
||||
}
|
||||
MessageRole::User => {
|
||||
lines.push(format!(
|
||||
">> {}",
|
||||
message.content.render_input(resolve_url_fn, agent_info)
|
||||
));
|
||||
}
|
||||
MessageRole::Tool => {
|
||||
lines.push(message.content.render_input(resolve_url_fn, agent_info));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
|
||||
pub fn tokens_usage(&self) -> (usize, f32) {
|
||||
let tokens = self.tokens();
|
||||
let max_input_tokens = self.model().max_input_tokens().unwrap_or_default();
|
||||
let percent = if max_input_tokens == 0 {
|
||||
0.0
|
||||
} else {
|
||||
let percent = tokens as f32 / max_input_tokens as f32 * 100.0;
|
||||
(percent * 100.0).round() / 100.0
|
||||
};
|
||||
(tokens, percent)
|
||||
}
|
||||
|
||||
pub fn set_role(&mut self, role: Role) {
|
||||
self.model_id = role.model().id();
|
||||
self.temperature = role.temperature();
|
||||
self.top_p = role.top_p();
|
||||
self.use_tools = role.use_tools();
|
||||
self.use_mcp_servers = role.use_mcp_servers();
|
||||
self.model = role.model().clone();
|
||||
self.role_name = convert_option_string(role.name());
|
||||
self.role_prompt = role.prompt().to_string();
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
}
|
||||
|
||||
pub fn clear_role(&mut self) {
|
||||
self.role_name = None;
|
||||
self.role_prompt.clear();
|
||||
}
|
||||
|
||||
pub fn sync_agent(&mut self, agent: &Agent) {
|
||||
self.role_name = None;
|
||||
self.role_prompt = agent.interpolated_instructions();
|
||||
self.agent_variables = agent.variables().clone();
|
||||
self.agent_instructions = self.role_prompt.clone();
|
||||
}
|
||||
|
||||
pub fn agent_variables(&self) -> &AgentVariables {
|
||||
&self.agent_variables
|
||||
}
|
||||
|
||||
pub fn agent_instructions(&self) -> &str {
|
||||
&self.agent_instructions
|
||||
}
|
||||
|
||||
pub fn set_save_session(&mut self, value: Option<bool>) {
|
||||
if self.save_session != value {
|
||||
self.save_session = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_save_session_this_time(&mut self) {
|
||||
self.save_session_this_time = true;
|
||||
}
|
||||
|
||||
pub fn set_compress_threshold(&mut self, value: Option<usize>) {
|
||||
if self.compress_threshold != value {
|
||||
self.compress_threshold = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn need_compress(&self, global_compress_threshold: usize) -> bool {
|
||||
if self.compressing {
|
||||
return false;
|
||||
}
|
||||
let threshold = self.compress_threshold.unwrap_or(global_compress_threshold);
|
||||
if threshold < 1 {
|
||||
return false;
|
||||
}
|
||||
self.tokens() > threshold
|
||||
}
|
||||
|
||||
pub fn compressing(&self) -> bool {
|
||||
self.compressing
|
||||
}
|
||||
|
||||
pub fn set_compressing(&mut self, compressing: bool) {
|
||||
self.compressing = compressing;
|
||||
}
|
||||
|
||||
pub fn compress(&mut self, mut prompt: String) {
|
||||
if let Some(system_prompt) = self.messages.first().and_then(|v| {
|
||||
if MessageRole::System == v.role {
|
||||
let content = v.content.to_text();
|
||||
if !content.is_empty() {
|
||||
return Some(content);
|
||||
}
|
||||
}
|
||||
None
|
||||
}) {
|
||||
prompt = format!("{system_prompt}\n\n{prompt}",);
|
||||
}
|
||||
self.compressed_messages.append(&mut self.messages);
|
||||
self.messages.push(Message::new(
|
||||
MessageRole::System,
|
||||
MessageContent::Text(prompt),
|
||||
));
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
}
|
||||
|
||||
pub fn need_autoname(&self) -> bool {
|
||||
self.autoname.as_ref().map(|v| v.need()).unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn set_autonaming(&mut self, naming: bool) {
|
||||
if let Some(v) = self.autoname.as_mut() {
|
||||
v.naming = naming;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chat_history_for_autonaming(&self) -> Option<String> {
|
||||
self.autoname.as_ref().and_then(|v| v.chat_history.clone())
|
||||
}
|
||||
|
||||
pub fn autoname(&self) -> Option<&str> {
|
||||
self.autoname.as_ref().and_then(|v| v.name.as_deref())
|
||||
}
|
||||
|
||||
pub fn set_autoname(&mut self, value: &str) {
|
||||
let name = value
|
||||
.chars()
|
||||
.map(|v| if v.is_alphanumeric() { v } else { '-' })
|
||||
.collect();
|
||||
self.autoname = Some(AutoName::new(name));
|
||||
}
|
||||
|
||||
pub fn exit(&mut self, session_dir: &Path, is_repl: bool) -> Result<()> {
|
||||
let mut save_session = self.save_session();
|
||||
if self.save_session_this_time {
|
||||
save_session = Some(true);
|
||||
}
|
||||
if self.dirty && save_session != Some(false) {
|
||||
let mut session_dir = session_dir.to_path_buf();
|
||||
let mut session_name = self.name().to_string();
|
||||
if save_session.is_none() {
|
||||
if !is_repl {
|
||||
return Ok(());
|
||||
}
|
||||
let ans = Confirm::new("Save session?").with_default(false).prompt()?;
|
||||
if !ans {
|
||||
return Ok(());
|
||||
}
|
||||
if session_name == TEMP_SESSION_NAME {
|
||||
session_name = Text::new("Session name:")
|
||||
.with_validator(|input: &str| {
|
||||
let input = input.trim();
|
||||
if input.is_empty() {
|
||||
Ok(Validation::Invalid("This name is required".into()))
|
||||
} else if input == TEMP_SESSION_NAME {
|
||||
Ok(Validation::Invalid("This name is reserved".into()))
|
||||
} else {
|
||||
Ok(Validation::Valid)
|
||||
}
|
||||
})
|
||||
.prompt()?;
|
||||
}
|
||||
} else if save_session == Some(true) && session_name == TEMP_SESSION_NAME {
|
||||
session_dir = session_dir.join("_");
|
||||
ensure_parent_exists(&session_dir).with_context(|| {
|
||||
format!("Failed to create directory '{}'", session_dir.display())
|
||||
})?;
|
||||
|
||||
let now = chrono::Local::now();
|
||||
session_name = now.format("%Y%m%dT%H%M%S").to_string();
|
||||
if let Some(autoname) = self.autoname() {
|
||||
session_name = format!("{session_name}-{autoname}")
|
||||
}
|
||||
}
|
||||
let session_path = session_dir.join(format!("{session_name}.yaml"));
|
||||
self.save(&session_name, &session_path, is_repl)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn save(&mut self, session_name: &str, session_path: &Path, is_repl: bool) -> Result<()> {
|
||||
ensure_parent_exists(session_path)?;
|
||||
|
||||
self.path = Some(session_path.display().to_string());
|
||||
|
||||
let content = serde_yaml::to_string(&self)
|
||||
.with_context(|| format!("Failed to serde session '{}'", self.name))?;
|
||||
write(session_path, content).with_context(|| {
|
||||
format!(
|
||||
"Failed to write session '{}' to '{}'",
|
||||
self.name,
|
||||
session_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if is_repl {
|
||||
println!("✓ Saved the session to '{}'.", session_path.display());
|
||||
}
|
||||
|
||||
if self.name() != session_name {
|
||||
self.name = session_name.to_string()
|
||||
}
|
||||
|
||||
self.dirty = false;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn guard_empty(&self) -> Result<()> {
|
||||
if !self.is_empty() {
|
||||
bail!("Cannot perform this operation because the session has messages, please `.empty session` first.");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
|
||||
if input.continue_output().is_some() {
|
||||
if let Some(message) = self.messages.last_mut() {
|
||||
if let MessageContent::Text(text) = &mut message.content {
|
||||
*text = format!("{text}{output}");
|
||||
}
|
||||
}
|
||||
} else if input.regenerate() {
|
||||
if let Some(message) = self.messages.last_mut() {
|
||||
if let MessageContent::Text(text) = &mut message.content {
|
||||
*text = output.to_string();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if self.messages.is_empty() {
|
||||
if self.name == TEMP_SESSION_NAME && self.save_session == Some(true) {
|
||||
let raw_input = input.raw();
|
||||
let chat_history = format!("USER: {raw_input}\nASSISTANT: {output}\n");
|
||||
self.autoname = Some(AutoName::new_from_chat_history(chat_history));
|
||||
}
|
||||
self.messages.extend(input.role().build_messages(input));
|
||||
} else {
|
||||
self.messages
|
||||
.push(Message::new(MessageRole::User, input.message_content()));
|
||||
}
|
||||
self.data_urls.extend(input.data_urls());
|
||||
if let Some(tool_calls) = input.tool_calls() {
|
||||
self.messages.push(Message::new(
|
||||
MessageRole::Tool,
|
||||
MessageContent::ToolCalls(tool_calls.clone()),
|
||||
))
|
||||
}
|
||||
self.messages.push(Message::new(
|
||||
MessageRole::Assistant,
|
||||
MessageContent::Text(output.to_string()),
|
||||
));
|
||||
}
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn clear_messages(&mut self) {
|
||||
self.messages.clear();
|
||||
self.compressed_messages.clear();
|
||||
self.data_urls.clear();
|
||||
self.autoname = None;
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self, input: &Input) -> String {
|
||||
let messages = self.build_messages(input);
|
||||
serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
|
||||
}
|
||||
|
||||
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
|
||||
let mut messages = self.messages.clone();
|
||||
if input.continue_output().is_some() {
|
||||
return messages;
|
||||
} else if input.regenerate() {
|
||||
while let Some(last) = messages.last() {
|
||||
if !last.role.is_user() {
|
||||
messages.pop();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
let mut need_add_msg = true;
|
||||
let len = messages.len();
|
||||
if len == 0 {
|
||||
messages = input.role().build_messages(input);
|
||||
need_add_msg = false;
|
||||
} else if len == 1 && self.compressed_messages.len() >= 2 {
|
||||
if let Some(index) = self
|
||||
.compressed_messages
|
||||
.iter()
|
||||
.rposition(|v| v.role == MessageRole::User)
|
||||
{
|
||||
messages.extend(self.compressed_messages[index..].to_vec());
|
||||
}
|
||||
}
|
||||
if need_add_msg {
|
||||
messages.push(Message::new(MessageRole::User, input.message_content()));
|
||||
}
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
impl RoleLike for Session {
|
||||
fn to_role(&self) -> Role {
|
||||
let role_name = self.role_name.as_deref().unwrap_or_default();
|
||||
let mut role = Role::new(role_name, &self.role_prompt);
|
||||
role.sync(self);
|
||||
role
|
||||
}
|
||||
|
||||
fn model(&self) -> &Model {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn temperature(&self) -> Option<f64> {
|
||||
self.temperature
|
||||
}
|
||||
|
||||
fn top_p(&self) -> Option<f64> {
|
||||
self.top_p
|
||||
}
|
||||
|
||||
fn use_tools(&self) -> Option<String> {
|
||||
self.use_tools.clone()
|
||||
}
|
||||
|
||||
fn use_mcp_servers(&self) -> Option<String> {
|
||||
self.use_mcp_servers.clone()
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: Model) {
|
||||
if self.model().id() != model.id() {
|
||||
self.model_id = model.id();
|
||||
self.model = model;
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
}
|
||||
}
|
||||
|
||||
fn set_temperature(&mut self, value: Option<f64>) {
|
||||
if self.temperature != value {
|
||||
self.temperature = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_top_p(&mut self, value: Option<f64>) {
|
||||
if self.top_p != value {
|
||||
self.top_p = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_use_tools(&mut self, value: Option<String>) {
|
||||
if self.use_tools != value {
|
||||
self.use_tools = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_use_mcp_servers(&mut self, value: Option<String>) {
|
||||
if self.use_mcp_servers != value {
|
||||
self.use_mcp_servers = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct AutoName {
|
||||
naming: bool,
|
||||
chat_history: Option<String>,
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
impl AutoName {
|
||||
pub fn new(name: String) -> Self {
|
||||
Self {
|
||||
name: Some(name),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
pub fn new_from_chat_history(chat_history: String) -> Self {
|
||||
Self {
|
||||
chat_history: Some(chat_history),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
pub fn need(&self) -> bool {
|
||||
!self.naming && self.chat_history.is_some() && self.name.is_none()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user