546 lines
16 KiB
Rust
546 lines
16 KiB
Rust
use super::*;
|
|
|
|
use crate::client::{
|
|
init_client, patch_messages, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
|
|
MessageContentPart, MessageContentToolCalls, MessageRole, Model,
|
|
};
|
|
use crate::function::ToolResult;
|
|
use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal};
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
use indexmap::IndexSet;
|
|
use std::{collections::HashMap, fs::File, io::Read};
|
|
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
|
|
|
|
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
|
|
const SUMMARY_MAX_WIDTH: usize = 80;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Input {
|
|
config: GlobalConfig,
|
|
text: String,
|
|
raw: (String, Vec<String>),
|
|
patched_text: Option<String>,
|
|
last_reply: Option<String>,
|
|
continue_output: Option<String>,
|
|
regenerate: bool,
|
|
medias: Vec<String>,
|
|
data_urls: HashMap<String, String>,
|
|
tool_calls: Option<MessageContentToolCalls>,
|
|
role: Role,
|
|
rag_name: Option<String>,
|
|
with_session: bool,
|
|
with_agent: bool,
|
|
}
|
|
|
|
impl Input {
|
|
pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
|
|
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
|
Self {
|
|
config: config.clone(),
|
|
text: text.to_string(),
|
|
raw: (text.to_string(), vec![]),
|
|
patched_text: None,
|
|
last_reply: None,
|
|
continue_output: None,
|
|
regenerate: false,
|
|
medias: Default::default(),
|
|
data_urls: Default::default(),
|
|
tool_calls: None,
|
|
role,
|
|
rag_name: None,
|
|
with_session,
|
|
with_agent,
|
|
}
|
|
}
|
|
|
|
pub async fn from_files(
|
|
config: &GlobalConfig,
|
|
raw_text: &str,
|
|
paths: Vec<String>,
|
|
role: Option<Role>,
|
|
) -> Result<Self> {
|
|
let loaders = config.read().document_loaders.clone();
|
|
let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
|
|
resolve_paths(&loaders, paths)?;
|
|
let mut last_reply = None;
|
|
let (documents, medias, data_urls) = load_documents(
|
|
&loaders,
|
|
local_paths,
|
|
remote_urls,
|
|
external_cmds,
|
|
protocol_paths,
|
|
)
|
|
.await
|
|
.context("Failed to load files")?;
|
|
let mut texts = vec![];
|
|
if !raw_text.is_empty() {
|
|
texts.push(raw_text.to_string());
|
|
};
|
|
if with_last_reply {
|
|
if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
|
|
if !output.is_empty() {
|
|
last_reply = Some(output.clone())
|
|
} else if let Some(v) = input.last_reply.as_ref() {
|
|
last_reply = Some(v.clone());
|
|
}
|
|
if let Some(v) = last_reply.clone() {
|
|
texts.push(format!("\n{v}"));
|
|
}
|
|
}
|
|
if last_reply.is_none() && documents.is_empty() && medias.is_empty() {
|
|
bail!("No last reply found");
|
|
}
|
|
}
|
|
let documents_len = documents.len();
|
|
for (kind, path, contents) in documents {
|
|
if documents_len == 1 && raw_text.is_empty() {
|
|
texts.push(format!("\n{contents}"));
|
|
} else {
|
|
texts.push(format!(
|
|
"\n============ {kind}: {path} ============\n{contents}"
|
|
));
|
|
}
|
|
}
|
|
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
|
Ok(Self {
|
|
config: config.clone(),
|
|
text: texts.join("\n"),
|
|
raw: (raw_text.to_string(), raw_paths),
|
|
patched_text: None,
|
|
last_reply,
|
|
continue_output: None,
|
|
regenerate: false,
|
|
medias,
|
|
data_urls,
|
|
tool_calls: Default::default(),
|
|
role,
|
|
rag_name: None,
|
|
with_session,
|
|
with_agent,
|
|
})
|
|
}
|
|
|
|
pub async fn from_files_with_spinner(
|
|
config: &GlobalConfig,
|
|
raw_text: &str,
|
|
paths: Vec<String>,
|
|
role: Option<Role>,
|
|
abort_signal: AbortSignal,
|
|
) -> Result<Self> {
|
|
abortable_run_with_spinner(
|
|
Input::from_files(config, raw_text, paths, role),
|
|
"Loading files",
|
|
abort_signal,
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.text.is_empty() && self.medias.is_empty()
|
|
}
|
|
|
|
pub fn data_urls(&self) -> HashMap<String, String> {
|
|
self.data_urls.clone()
|
|
}
|
|
|
|
pub fn tool_calls(&self) -> &Option<MessageContentToolCalls> {
|
|
&self.tool_calls
|
|
}
|
|
|
|
pub fn text(&self) -> String {
|
|
match self.patched_text.clone() {
|
|
Some(text) => text,
|
|
None => self.text.clone(),
|
|
}
|
|
}
|
|
|
|
pub fn clear_patch(&mut self) {
|
|
self.patched_text = None;
|
|
}
|
|
|
|
pub fn set_text(&mut self, text: String) {
|
|
self.text = text;
|
|
}
|
|
|
|
pub fn stream(&self) -> bool {
|
|
self.config.read().stream && !self.role().model().no_stream()
|
|
}
|
|
|
|
pub fn continue_output(&self) -> Option<&str> {
|
|
self.continue_output.as_deref()
|
|
}
|
|
|
|
pub fn set_continue_output(&mut self, output: &str) {
|
|
let output = match &self.continue_output {
|
|
Some(v) => format!("{v}{output}"),
|
|
None => output.to_string(),
|
|
};
|
|
self.continue_output = Some(output);
|
|
}
|
|
|
|
pub fn regenerate(&self) -> bool {
|
|
self.regenerate
|
|
}
|
|
|
|
pub fn set_regenerate(&mut self) {
|
|
let role = self.config.read().extract_role();
|
|
if role.name() == self.role().name() {
|
|
self.role = role;
|
|
}
|
|
self.regenerate = true;
|
|
self.tool_calls = None;
|
|
}
|
|
|
|
pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
|
|
if self.text.is_empty() {
|
|
return Ok(());
|
|
}
|
|
let rag = self.config.read().rag.clone();
|
|
if let Some(rag) = rag {
|
|
let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
|
|
self.patched_text = Some(result);
|
|
self.rag_name = Some(rag.name().to_string());
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub fn rag_name(&self) -> Option<&str> {
|
|
self.rag_name.as_deref()
|
|
}
|
|
|
|
pub fn merge_tool_results(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
|
|
match self.tool_calls.as_mut() {
|
|
Some(exist_tool_results) => {
|
|
exist_tool_results.merge(tool_results, output);
|
|
}
|
|
None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)),
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn create_client(&self) -> Result<Box<dyn Client>> {
|
|
init_client(&self.config, Some(self.role().model().clone()))
|
|
}
|
|
|
|
pub async fn fetch_chat_text(&self) -> Result<String> {
|
|
let client = self.create_client()?;
|
|
let text = client.chat_completions(self.clone()).await?.text;
|
|
let text = strip_think_tag(&text).to_string();
|
|
Ok(text)
|
|
}
|
|
|
|
pub fn prepare_completion_data(
|
|
&self,
|
|
model: &Model,
|
|
stream: bool,
|
|
) -> Result<ChatCompletionsData> {
|
|
let mut messages = self.build_messages()?;
|
|
patch_messages(&mut messages, model);
|
|
model.guard_max_input_tokens(&messages)?;
|
|
let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
|
|
let functions = self.config.read().select_functions(self.role());
|
|
if let Some(vec) = &functions {
|
|
for def in vec {
|
|
debug!("Function definition: {:?}", def.name);
|
|
}
|
|
}
|
|
Ok(ChatCompletionsData {
|
|
messages,
|
|
temperature,
|
|
top_p,
|
|
functions,
|
|
stream,
|
|
})
|
|
}
|
|
|
|
pub fn build_messages(&self) -> Result<Vec<Message>> {
|
|
let mut messages = if let Some(session) = self.session(&self.config.read().session) {
|
|
session.build_messages(self)
|
|
} else {
|
|
self.role().build_messages(self)
|
|
};
|
|
if let Some(tool_calls) = &self.tool_calls {
|
|
messages.push(Message::new(
|
|
MessageRole::Assistant,
|
|
MessageContent::ToolCalls(tool_calls.clone()),
|
|
))
|
|
}
|
|
Ok(messages)
|
|
}
|
|
|
|
pub fn echo_messages(&self) -> String {
|
|
if let Some(session) = self.session(&self.config.read().session) {
|
|
session.echo_messages(self)
|
|
} else {
|
|
self.role().echo_messages(self)
|
|
}
|
|
}
|
|
|
|
pub fn role(&self) -> &Role {
|
|
&self.role
|
|
}
|
|
|
|
pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
|
|
if self.with_session {
|
|
session.as_ref()
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
|
|
if self.with_session {
|
|
session.as_mut()
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn with_agent(&self) -> bool {
|
|
self.with_agent
|
|
}
|
|
|
|
pub fn summary(&self) -> String {
|
|
let text: String = self
|
|
.text
|
|
.trim()
|
|
.chars()
|
|
.map(|c| if c.is_control() { ' ' } else { c })
|
|
.collect();
|
|
if text.width_cjk() > SUMMARY_MAX_WIDTH {
|
|
let mut sum_width = 0;
|
|
let mut chars = vec![];
|
|
for c in text.chars() {
|
|
sum_width += c.width_cjk().unwrap_or(1);
|
|
if sum_width > SUMMARY_MAX_WIDTH - 3 {
|
|
chars.extend(['.', '.', '.']);
|
|
break;
|
|
}
|
|
chars.push(c);
|
|
}
|
|
chars.into_iter().collect()
|
|
} else {
|
|
text
|
|
}
|
|
}
|
|
|
|
pub fn raw(&self) -> String {
|
|
let (text, files) = &self.raw;
|
|
let mut segments = files.to_vec();
|
|
if !segments.is_empty() {
|
|
segments.insert(0, ".file".into());
|
|
}
|
|
if !text.is_empty() {
|
|
if !segments.is_empty() {
|
|
segments.push("--".into());
|
|
}
|
|
segments.push(text.clone());
|
|
}
|
|
segments.join(" ")
|
|
}
|
|
|
|
pub fn render(&self) -> String {
|
|
let text = self.text();
|
|
if self.medias.is_empty() {
|
|
return text;
|
|
}
|
|
let tail_text = if text.is_empty() {
|
|
String::new()
|
|
} else {
|
|
format!(" -- {text}")
|
|
};
|
|
let files: Vec<String> = self
|
|
.medias
|
|
.iter()
|
|
.cloned()
|
|
.map(|url| resolve_data_url(&self.data_urls, url))
|
|
.collect();
|
|
format!(".file {}{}", files.join(" "), tail_text)
|
|
}
|
|
|
|
pub fn message_content(&self) -> MessageContent {
|
|
if self.medias.is_empty() {
|
|
MessageContent::Text(self.text())
|
|
} else {
|
|
let mut list: Vec<MessageContentPart> = self
|
|
.medias
|
|
.iter()
|
|
.cloned()
|
|
.map(|url| MessageContentPart::ImageUrl {
|
|
image_url: ImageUrl { url },
|
|
})
|
|
.collect();
|
|
if !self.text.is_empty() {
|
|
list.insert(0, MessageContentPart::Text { text: self.text() });
|
|
}
|
|
MessageContent::Array(list)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
|
|
match role {
|
|
Some(v) => (v, false, false),
|
|
None => (
|
|
config.extract_role(),
|
|
config.session.is_some(),
|
|
config.agent.is_some(),
|
|
),
|
|
}
|
|
}
|
|
|
|
type ResolvePathsOutput = (
|
|
Vec<String>,
|
|
Vec<String>,
|
|
Vec<String>,
|
|
Vec<String>,
|
|
Vec<String>,
|
|
bool,
|
|
);
|
|
|
|
fn resolve_paths(
|
|
loaders: &HashMap<String, String>,
|
|
paths: Vec<String>,
|
|
) -> Result<ResolvePathsOutput> {
|
|
let mut raw_paths = IndexSet::new();
|
|
let mut local_paths = IndexSet::new();
|
|
let mut remote_urls = IndexSet::new();
|
|
let mut external_cmds = IndexSet::new();
|
|
let mut protocol_paths = IndexSet::new();
|
|
let mut with_last_reply = false;
|
|
for path in paths {
|
|
if path == "%%" {
|
|
with_last_reply = true;
|
|
raw_paths.insert(path);
|
|
} else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') {
|
|
external_cmds.insert(path[1..path.len() - 1].to_string());
|
|
raw_paths.insert(path);
|
|
} else if is_url(&path) {
|
|
if path.strip_suffix("**").is_some() {
|
|
bail!("Invalid website '{path}'");
|
|
}
|
|
remote_urls.insert(path.clone());
|
|
raw_paths.insert(path);
|
|
} else if is_loader_protocol(loaders, &path) {
|
|
protocol_paths.insert(path.clone());
|
|
raw_paths.insert(path);
|
|
} else {
|
|
let resolved_path = resolve_home_dir(&path);
|
|
let absolute_path = to_absolute_path(&resolved_path)
|
|
.with_context(|| format!("Invalid path '{path}'"))?;
|
|
local_paths.insert(resolved_path);
|
|
raw_paths.insert(absolute_path);
|
|
}
|
|
}
|
|
Ok((
|
|
raw_paths.into_iter().collect(),
|
|
local_paths.into_iter().collect(),
|
|
remote_urls.into_iter().collect(),
|
|
external_cmds.into_iter().collect(),
|
|
protocol_paths.into_iter().collect(),
|
|
with_last_reply,
|
|
))
|
|
}
|
|
|
|
async fn load_documents(
|
|
loaders: &HashMap<String, String>,
|
|
local_paths: Vec<String>,
|
|
remote_urls: Vec<String>,
|
|
external_cmds: Vec<String>,
|
|
protocol_paths: Vec<String>,
|
|
) -> Result<(
|
|
Vec<(&'static str, String, String)>,
|
|
Vec<String>,
|
|
HashMap<String, String>,
|
|
)> {
|
|
let mut files = vec![];
|
|
let mut medias = vec![];
|
|
let mut data_urls = HashMap::new();
|
|
|
|
for cmd in external_cmds {
|
|
let output = duct::cmd(&SHELL.cmd, &[&SHELL.arg, &cmd])
|
|
.stderr_to_stdout()
|
|
.unchecked()
|
|
.read()
|
|
.unwrap_or_else(|err| err.to_string());
|
|
files.push(("CMD", cmd, output));
|
|
}
|
|
|
|
let local_files = expand_glob_paths(&local_paths, true).await?;
|
|
for file_path in local_files {
|
|
if is_image(&file_path) {
|
|
let contents = read_media_to_data_url(&file_path)
|
|
.with_context(|| format!("Unable to read media '{file_path}'"))?;
|
|
data_urls.insert(sha256(&contents), file_path);
|
|
medias.push(contents)
|
|
} else {
|
|
let document = load_file(loaders, &file_path)
|
|
.await
|
|
.with_context(|| format!("Unable to read file '{file_path}'"))?;
|
|
files.push(("FILE", file_path, document.contents));
|
|
}
|
|
}
|
|
|
|
for file_url in remote_urls {
|
|
let (contents, extension) = fetch_with_loaders(loaders, &file_url, true)
|
|
.await
|
|
.with_context(|| format!("Failed to load url '{file_url}'"))?;
|
|
if extension == MEDIA_URL_EXTENSION {
|
|
data_urls.insert(sha256(&contents), file_url);
|
|
medias.push(contents)
|
|
} else {
|
|
files.push(("URL", file_url, contents));
|
|
}
|
|
}
|
|
|
|
for protocol_path in protocol_paths {
|
|
let documents = load_protocol_path(loaders, &protocol_path)
|
|
.with_context(|| format!("Failed to load from '{protocol_path}'"))?;
|
|
files.extend(
|
|
documents
|
|
.into_iter()
|
|
.map(|document| ("FROM", document.path, document.contents)),
|
|
);
|
|
}
|
|
|
|
Ok((files, medias, data_urls))
|
|
}
|
|
|
|
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
|
|
if data_url.starts_with("data:") {
|
|
let hash = sha256(&data_url);
|
|
if let Some(path) = data_urls.get(&hash) {
|
|
return path.to_string();
|
|
}
|
|
data_url
|
|
} else {
|
|
data_url
|
|
}
|
|
}
|
|
|
|
fn is_image(path: &str) -> bool {
|
|
get_patch_extension(path)
|
|
.map(|v| IMAGE_EXTS.contains(&v.as_str()))
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
fn read_media_to_data_url(image_path: &str) -> Result<String> {
|
|
let extension = get_patch_extension(image_path).unwrap_or_default();
|
|
let mime_type = match extension.as_str() {
|
|
"png" => "image/png",
|
|
"jpg" | "jpeg" => "image/jpeg",
|
|
"webp" => "image/webp",
|
|
"gif" => "image/gif",
|
|
_ => bail!("Unexpected media type"),
|
|
};
|
|
let mut file = File::open(image_path)?;
|
|
let mut buffer = Vec::new();
|
|
file.read_to_end(&mut buffer)?;
|
|
|
|
let encoded_image = base64_encode(buffer);
|
|
let data_url = format!("data:{mime_type};base64,{encoded_image}");
|
|
|
|
Ok(data_url)
|
|
}
|