use super::input::*; use super::*; use crate::client::{Message, MessageContent, MessageRole}; use crate::render::MarkdownRender; use anyhow::{Context, Result, bail}; use fancy_regex::Regex; use inquire::{Confirm, Text, validator::Validation}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; use std::fs::{read_to_string, write}; use std::path::Path; use std::sync::LazyLock; static RE_AUTONAME_PREFIX: LazyLock = LazyLock::new(|| Regex::new(r"\d{8}T\d{6}-").unwrap()); #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct Session { #[serde(rename(serialize = "model", deserialize = "model"))] model_id: String, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] enabled_tools: Option, #[serde(skip_serializing_if = "Option::is_none")] enabled_mcp_servers: Option, #[serde(skip_serializing_if = "Option::is_none")] save_session: Option, #[serde(skip_serializing_if = "Option::is_none")] compression_threshold: Option, #[serde(skip_serializing_if = "Option::is_none")] role_name: Option, #[serde(default, skip_serializing_if = "IndexMap::is_empty")] agent_variables: AgentVariables, #[serde(default, skip_serializing_if = "String::is_empty")] agent_instructions: String, #[serde(default, skip_serializing_if = "Vec::is_empty")] compressed_messages: Vec, messages: Vec, #[serde(default, skip_serializing_if = "HashMap::is_empty")] data_urls: HashMap, #[serde(skip)] model: Model, #[serde(skip)] role_prompt: String, #[serde(skip)] name: String, #[serde(skip)] path: Option, #[serde(skip)] dirty: bool, #[serde(skip)] save_session_this_time: bool, #[serde(skip)] compressing: bool, #[serde(skip)] autoname: Option, #[serde(skip)] tokens: usize, } impl Session { pub fn new(config: &Config, name: &str) -> Self { let role = config.extract_role(); let mut session = Self { name: name.to_string(), save_session: config.save_session, ..Default::default() }; session.set_role(role); session.dirty = false; session } pub fn load(config: &Config, name: &str, path: &Path) -> Result { let content = read_to_string(path) .with_context(|| format!("Failed to load session {} at {}", name, path.display()))?; let mut session: Self = serde_yaml::from_str(&content).with_context(|| format!("Invalid session {name}"))?; session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?; if let Some(autoname) = name.strip_prefix("_/") { session.name = TEMP_SESSION_NAME.to_string(); session.path = None; if let Ok(true) = RE_AUTONAME_PREFIX.is_match(autoname) { session.autoname = Some(AutoName::new(autoname[16..].to_string())); } } else { session.name = name.to_string(); session.path = Some(path.display().to_string()); } if let Some(role_name) = &session.role_name && let Ok(role) = config.retrieve_role(role_name) { session.role_prompt = role.prompt().to_string(); } session.update_tokens(); Ok(session) } pub fn is_empty(&self) -> bool { self.messages.is_empty() && self.compressed_messages.is_empty() } pub fn name(&self) -> &str { &self.name } pub fn role_name(&self) -> Option<&str> { self.role_name.as_deref() } pub fn dirty(&self) -> bool { self.dirty } pub fn save_session(&self) -> Option { self.save_session } pub fn tokens(&self) -> usize { self.tokens } pub fn update_tokens(&mut self) { self.tokens = self.model().total_tokens(&self.messages); } pub fn has_user_messages(&self) -> bool { self.messages.iter().any(|v| v.role.is_user()) } pub fn user_messages_len(&self) -> usize { self.messages.iter().filter(|v| v.role.is_user()).count() } pub fn export(&self) -> Result { let mut data = json!({ "path": self.path, "model": self.model().id(), }); if let Some(temperature) = self.temperature() { data["temperature"] = temperature.into(); } if let Some(top_p) = self.top_p() { data["top_p"] = top_p.into(); } if let Some(enabled_tools) = self.enabled_tools() { data["enabled_tools"] = enabled_tools.into(); } if let Some(enabled_mcp_servers) = self.enabled_mcp_servers() { data["enabled_mcp_servers"] = enabled_mcp_servers.into(); } if let Some(save_session) = self.save_session() { data["save_session"] = save_session.into(); } let (tokens, percent) = self.tokens_usage(); data["total_tokens"] = tokens.into(); if let Some(max_input_tokens) = self.model().max_input_tokens() { data["max_input_tokens"] = max_input_tokens.into(); } if percent != 0.0 { data["total/max"] = format!("{percent}%").into(); } data["messages"] = json!(self.messages); let output = serde_yaml::to_string(&data) .with_context(|| format!("Unable to show info about session '{}'", &self.name))?; Ok(output) } pub fn render( &self, render: &mut MarkdownRender, agent_info: &Option<(String, Vec)>, ) -> Result { let mut items = vec![]; if let Some(path) = &self.path { items.push(("path", path.to_string())); } if let Some(autoname) = self.autoname() { items.push(("autoname", autoname.to_string())); } items.push(("model", self.model().id())); if let Some(temperature) = self.temperature() { items.push(("temperature", temperature.to_string())); } if let Some(top_p) = self.top_p() { items.push(("top_p", top_p.to_string())); } if let Some(enabled_tools) = self.enabled_tools() { items.push(("enabled_tools", enabled_tools)); } if let Some(enabled_mcp_servers) = self.enabled_mcp_servers() { items.push(("enabled_mcp_servers", enabled_mcp_servers)); } if let Some(save_session) = self.save_session() { items.push(("save_session", save_session.to_string())); } if let Some(compression_threshold) = self.compression_threshold { items.push(("compression_threshold", compression_threshold.to_string())); } if let Some(max_input_tokens) = self.model().max_input_tokens() { items.push(("max_input_tokens", max_input_tokens.to_string())); } let mut lines: Vec = items .iter() .map(|(name, value)| format!("{name:<20}{value}")) .collect(); lines.push(String::new()); if !self.is_empty() { let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string()); for message in &self.messages { match message.role { MessageRole::System => { lines.push( render .render(&message.content.render_input(resolve_url_fn, agent_info)), ); } MessageRole::Assistant => { if let MessageContent::Text(text) = &message.content { lines.push(render.render(text)); } lines.push("".into()); } MessageRole::User => { lines.push(format!( ">> {}", message.content.render_input(resolve_url_fn, agent_info) )); } MessageRole::Tool => { lines.push(message.content.render_input(resolve_url_fn, agent_info)); } } } } Ok(lines.join("\n")) } pub fn tokens_usage(&self) -> (usize, f32) { let tokens = self.tokens(); let max_input_tokens = self.model().max_input_tokens().unwrap_or_default(); let percent = if max_input_tokens == 0 { 0.0 } else { let percent = tokens as f32 / max_input_tokens as f32 * 100.0; (percent * 100.0).round() / 100.0 }; (tokens, percent) } pub fn set_role(&mut self, role: Role) { self.model_id = role.model().id(); self.temperature = role.temperature(); self.top_p = role.top_p(); self.enabled_tools = role.enabled_tools(); self.enabled_mcp_servers = role.enabled_mcp_servers(); self.model = role.model().clone(); self.role_name = convert_option_string(role.name()); self.role_prompt = role.prompt().to_string(); self.dirty = true; self.update_tokens(); } pub fn clear_role(&mut self) { self.role_name = None; self.role_prompt.clear(); } pub fn sync_agent(&mut self, agent: &Agent) { self.role_name = None; self.role_prompt = agent.interpolated_instructions(); self.agent_variables = agent.variables().clone(); self.agent_instructions = self.role_prompt.clone(); if let Some(threshold) = agent.compression_threshold() { self.set_compression_threshold(Some(threshold)); } } pub fn agent_variables(&self) -> &AgentVariables { &self.agent_variables } pub fn agent_instructions(&self) -> &str { &self.agent_instructions } pub fn set_save_session(&mut self, value: Option) { if self.save_session != value { self.save_session = value; self.dirty = true; } } pub fn set_save_session_this_time(&mut self) { self.save_session_this_time = true; } pub fn set_compression_threshold(&mut self, value: Option) { if self.compression_threshold != value { self.compression_threshold = value; self.dirty = true; } } pub fn needs_compression(&self, global_compression_threshold: usize) -> bool { if self.compressing { return false; } let threshold = self .compression_threshold .unwrap_or(global_compression_threshold); if threshold < 1 { return false; } self.tokens() > threshold } pub fn compressing(&self) -> bool { self.compressing } pub fn set_compressing(&mut self, compressing: bool) { self.compressing = compressing; } pub fn compress(&mut self, mut prompt: String) { if let Some(system_prompt) = self.messages.first().and_then(|v| { if MessageRole::System == v.role { let content = v.content.to_text(); if !content.is_empty() { return Some(content); } } None }) { prompt = format!("{system_prompt}\n\n{prompt}",); } self.compressed_messages.append(&mut self.messages); self.messages.push(Message::new( MessageRole::System, MessageContent::Text(prompt), )); self.dirty = true; self.update_tokens(); } pub fn need_autoname(&self) -> bool { self.autoname.as_ref().map(|v| v.need()).unwrap_or_default() } pub fn set_autonaming(&mut self, naming: bool) { if let Some(v) = self.autoname.as_mut() { v.naming = naming; } } pub fn chat_history_for_autonaming(&self) -> Option { self.autoname.as_ref().and_then(|v| v.chat_history.clone()) } pub fn autoname(&self) -> Option<&str> { self.autoname.as_ref().and_then(|v| v.name.as_deref()) } pub fn set_autoname(&mut self, value: &str) { let name = value .chars() .map(|v| if v.is_alphanumeric() { v } else { '-' }) .collect(); self.autoname = Some(AutoName::new(name)); } pub fn exit(&mut self, session_dir: &Path, is_repl: bool) -> Result<()> { let mut save_session = self.save_session(); if self.save_session_this_time { save_session = Some(true); } if self.dirty && save_session != Some(false) { let mut session_dir = session_dir.to_path_buf(); let mut session_name = self.name().to_string(); if save_session.is_none() { if !is_repl { return Ok(()); } let ans = Confirm::new("Save session?").with_default(false).prompt()?; if !ans { return Ok(()); } if session_name == TEMP_SESSION_NAME { session_name = Text::new("Session name:") .with_validator(|input: &str| { let input = input.trim(); if input.is_empty() { Ok(Validation::Invalid("This name is required".into())) } else if input == TEMP_SESSION_NAME { Ok(Validation::Invalid("This name is reserved".into())) } else { Ok(Validation::Valid) } }) .prompt()?; } } else if save_session == Some(true) && session_name == TEMP_SESSION_NAME { session_dir = session_dir.join("_"); ensure_parent_exists(&session_dir).with_context(|| { format!("Failed to create directory '{}'", session_dir.display()) })?; let now = chrono::Local::now(); session_name = now.format("%Y%m%dT%H%M%S").to_string(); if let Some(autoname) = self.autoname() { session_name = format!("{session_name}-{autoname}") } } let session_path = session_dir.join(format!("{session_name}.yaml")); self.save(&session_name, &session_path, is_repl)?; } Ok(()) } pub fn save(&mut self, session_name: &str, session_path: &Path, is_repl: bool) -> Result<()> { ensure_parent_exists(session_path)?; self.path = Some(session_path.display().to_string()); let content = serde_yaml::to_string(&self) .with_context(|| format!("Failed to serde session '{}'", self.name))?; write(session_path, content).with_context(|| { format!( "Failed to write session '{}' to '{}'", self.name, session_path.display() ) })?; if is_repl { println!("✓ Saved the session to '{}'.", session_path.display()); } if self.name() != session_name { self.name = session_name.to_string() } self.dirty = false; Ok(()) } pub fn guard_empty(&self) -> Result<()> { if !self.is_empty() { bail!( "Cannot perform this operation because the session has messages, please `.empty session` first." ); } Ok(()) } pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> { if input.continue_output().is_some() { if let Some(message) = self.messages.last_mut() && let MessageContent::Text(text) = &mut message.content { *text = format!("{text}{output}"); } } else if input.regenerate() { if let Some(message) = self.messages.last_mut() && let MessageContent::Text(text) = &mut message.content { *text = output.to_string(); } } else { if self.messages.is_empty() { if self.name == TEMP_SESSION_NAME && self.save_session == Some(true) { let raw_input = input.raw(); let chat_history = format!("USER: {raw_input}\nASSISTANT: {output}\n"); self.autoname = Some(AutoName::new_from_chat_history(chat_history)); } self.messages.extend(input.role().build_messages(input)); } else { self.messages .push(Message::new(MessageRole::User, input.message_content())); } self.data_urls.extend(input.data_urls()); if let Some(tool_calls) = input.tool_calls() { self.messages.push(Message::new( MessageRole::Tool, MessageContent::ToolCalls(tool_calls.clone()), )) } self.messages.push(Message::new( MessageRole::Assistant, MessageContent::Text(output.to_string()), )); } self.dirty = true; self.update_tokens(); Ok(()) } pub fn clear_messages(&mut self) { self.messages.clear(); self.compressed_messages.clear(); self.data_urls.clear(); self.autoname = None; self.dirty = true; self.update_tokens(); } pub fn echo_messages(&self, input: &Input) -> String { let messages = self.build_messages(input); serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into()) } pub fn build_messages(&self, input: &Input) -> Vec { let mut messages = self.messages.clone(); if input.continue_output().is_some() { return messages; } else if input.regenerate() { while let Some(last) = messages.last() { if !last.role.is_user() { messages.pop(); } else { break; } } return messages; } let mut need_add_msg = true; let len = messages.len(); if len == 0 { messages = input.role().build_messages(input); need_add_msg = false; } else if len == 1 && self.compressed_messages.len() >= 2 && let Some(index) = self .compressed_messages .iter() .rposition(|v| v.role == MessageRole::User) { messages.extend(self.compressed_messages[index..].to_vec()); } if need_add_msg { messages.push(Message::new(MessageRole::User, input.message_content())); } messages } } impl RoleLike for Session { fn to_role(&self) -> Role { let role_name = self.role_name.as_deref().unwrap_or_default(); let mut role = Role::new(role_name, &self.role_prompt); role.sync(self); role } fn model(&self) -> &Model { &self.model } fn temperature(&self) -> Option { self.temperature } fn top_p(&self) -> Option { self.top_p } fn enabled_tools(&self) -> Option { self.enabled_tools.clone() } fn enabled_mcp_servers(&self) -> Option { self.enabled_mcp_servers.clone() } fn set_model(&mut self, model: Model) { if self.model().id() != model.id() { self.model_id = model.id(); self.model = model; self.dirty = true; self.update_tokens(); } } fn set_temperature(&mut self, value: Option) { if self.temperature != value { self.temperature = value; self.dirty = true; } } fn set_top_p(&mut self, value: Option) { if self.top_p != value { self.top_p = value; self.dirty = true; } } fn set_enabled_tools(&mut self, value: Option) { if self.enabled_tools != value { self.enabled_tools = value; self.dirty = true; } } fn set_enabled_mcp_servers(&mut self, value: Option) { if self.enabled_mcp_servers != value { self.enabled_mcp_servers = value; self.dirty = true; } } } #[derive(Debug, Clone, Default)] struct AutoName { naming: bool, chat_history: Option, name: Option, } impl AutoName { pub fn new(name: String) -> Self { Self { name: Some(name), ..Default::default() } } pub fn new_from_chat_history(chat_history: String) -> Self { Self { chat_history: Some(chat_history), ..Default::default() } } pub fn need(&self) -> bool { !self.naming && self.chat_history.is_some() && self.name.is_none() } }