Baseline project
This commit is contained in:
@@ -0,0 +1,545 @@
|
||||
use super::*;
|
||||
|
||||
use crate::client::{
|
||||
init_client, patch_messages, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
|
||||
MessageContentPart, MessageContentToolCalls, MessageRole, Model,
|
||||
};
|
||||
use crate::function::ToolResult;
|
||||
use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use indexmap::IndexSet;
|
||||
use std::{collections::HashMap, fs::File, io::Read};
|
||||
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
|
||||
|
||||
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
|
||||
const SUMMARY_MAX_WIDTH: usize = 80;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Input {
|
||||
config: GlobalConfig,
|
||||
text: String,
|
||||
raw: (String, Vec<String>),
|
||||
patched_text: Option<String>,
|
||||
last_reply: Option<String>,
|
||||
continue_output: Option<String>,
|
||||
regenerate: bool,
|
||||
medias: Vec<String>,
|
||||
data_urls: HashMap<String, String>,
|
||||
tool_calls: Option<MessageContentToolCalls>,
|
||||
role: Role,
|
||||
rag_name: Option<String>,
|
||||
with_session: bool,
|
||||
with_agent: bool,
|
||||
}
|
||||
|
||||
impl Input {
|
||||
pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
Self {
|
||||
config: config.clone(),
|
||||
text: text.to_string(),
|
||||
raw: (text.to_string(), vec![]),
|
||||
patched_text: None,
|
||||
last_reply: None,
|
||||
continue_output: None,
|
||||
regenerate: false,
|
||||
medias: Default::default(),
|
||||
data_urls: Default::default(),
|
||||
tool_calls: None,
|
||||
role,
|
||||
rag_name: None,
|
||||
with_session,
|
||||
with_agent,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_files(
|
||||
config: &GlobalConfig,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
) -> Result<Self> {
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
|
||||
resolve_paths(&loaders, paths)?;
|
||||
let mut last_reply = None;
|
||||
let (documents, medias, data_urls) = load_documents(
|
||||
&loaders,
|
||||
local_paths,
|
||||
remote_urls,
|
||||
external_cmds,
|
||||
protocol_paths,
|
||||
)
|
||||
.await
|
||||
.context("Failed to load files")?;
|
||||
let mut texts = vec![];
|
||||
if !raw_text.is_empty() {
|
||||
texts.push(raw_text.to_string());
|
||||
};
|
||||
if with_last_reply {
|
||||
if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
|
||||
if !output.is_empty() {
|
||||
last_reply = Some(output.clone())
|
||||
} else if let Some(v) = input.last_reply.as_ref() {
|
||||
last_reply = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = last_reply.clone() {
|
||||
texts.push(format!("\n{v}"));
|
||||
}
|
||||
}
|
||||
if last_reply.is_none() && documents.is_empty() && medias.is_empty() {
|
||||
bail!("No last reply found");
|
||||
}
|
||||
}
|
||||
let documents_len = documents.len();
|
||||
for (kind, path, contents) in documents {
|
||||
if documents_len == 1 && raw_text.is_empty() {
|
||||
texts.push(format!("\n{contents}"));
|
||||
} else {
|
||||
texts.push(format!(
|
||||
"\n============ {kind}: {path} ============\n{contents}"
|
||||
));
|
||||
}
|
||||
}
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
Ok(Self {
|
||||
config: config.clone(),
|
||||
text: texts.join("\n"),
|
||||
raw: (raw_text.to_string(), raw_paths),
|
||||
patched_text: None,
|
||||
last_reply,
|
||||
continue_output: None,
|
||||
regenerate: false,
|
||||
medias,
|
||||
data_urls,
|
||||
tool_calls: Default::default(),
|
||||
role,
|
||||
rag_name: None,
|
||||
with_session,
|
||||
with_agent,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn from_files_with_spinner(
|
||||
config: &GlobalConfig,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
abortable_run_with_spinner(
|
||||
Input::from_files(config, raw_text, paths, role),
|
||||
"Loading files",
|
||||
abort_signal,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.text.is_empty() && self.medias.is_empty()
|
||||
}
|
||||
|
||||
pub fn data_urls(&self) -> HashMap<String, String> {
|
||||
self.data_urls.clone()
|
||||
}
|
||||
|
||||
pub fn tool_calls(&self) -> &Option<MessageContentToolCalls> {
|
||||
&self.tool_calls
|
||||
}
|
||||
|
||||
pub fn text(&self) -> String {
|
||||
match self.patched_text.clone() {
|
||||
Some(text) => text,
|
||||
None => self.text.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_patch(&mut self) {
|
||||
self.patched_text = None;
|
||||
}
|
||||
|
||||
pub fn set_text(&mut self, text: String) {
|
||||
self.text = text;
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> bool {
|
||||
self.config.read().stream && !self.role().model().no_stream()
|
||||
}
|
||||
|
||||
pub fn continue_output(&self) -> Option<&str> {
|
||||
self.continue_output.as_deref()
|
||||
}
|
||||
|
||||
pub fn set_continue_output(&mut self, output: &str) {
|
||||
let output = match &self.continue_output {
|
||||
Some(v) => format!("{v}{output}"),
|
||||
None => output.to_string(),
|
||||
};
|
||||
self.continue_output = Some(output);
|
||||
}
|
||||
|
||||
pub fn regenerate(&self) -> bool {
|
||||
self.regenerate
|
||||
}
|
||||
|
||||
pub fn set_regenerate(&mut self) {
|
||||
let role = self.config.read().extract_role();
|
||||
if role.name() == self.role().name() {
|
||||
self.role = role;
|
||||
}
|
||||
self.regenerate = true;
|
||||
self.tool_calls = None;
|
||||
}
|
||||
|
||||
pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
|
||||
if self.text.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let rag = self.config.read().rag.clone();
|
||||
if let Some(rag) = rag {
|
||||
let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
|
||||
self.patched_text = Some(result);
|
||||
self.rag_name = Some(rag.name().to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn rag_name(&self) -> Option<&str> {
|
||||
self.rag_name.as_deref()
|
||||
}
|
||||
|
||||
pub fn merge_tool_results(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
|
||||
match self.tool_calls.as_mut() {
|
||||
Some(exist_tool_results) => {
|
||||
exist_tool_results.merge(tool_results, output);
|
||||
}
|
||||
None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)),
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn create_client(&self) -> Result<Box<dyn Client>> {
|
||||
init_client(&self.config, Some(self.role().model().clone()))
|
||||
}
|
||||
|
||||
pub async fn fetch_chat_text(&self) -> Result<String> {
|
||||
let client = self.create_client()?;
|
||||
let text = client.chat_completions(self.clone()).await?.text;
|
||||
let text = strip_think_tag(&text).to_string();
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
pub fn prepare_completion_data(
|
||||
&self,
|
||||
model: &Model,
|
||||
stream: bool,
|
||||
) -> Result<ChatCompletionsData> {
|
||||
let mut messages = self.build_messages()?;
|
||||
patch_messages(&mut messages, model);
|
||||
model.guard_max_input_tokens(&messages)?;
|
||||
let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
|
||||
let functions = self.config.read().select_functions(self.role());
|
||||
if let Some(vec) = &functions {
|
||||
for def in vec {
|
||||
debug!("Function definition: {:?}", def.name);
|
||||
}
|
||||
}
|
||||
Ok(ChatCompletionsData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
functions,
|
||||
stream,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn build_messages(&self) -> Result<Vec<Message>> {
|
||||
let mut messages = if let Some(session) = self.session(&self.config.read().session) {
|
||||
session.build_messages(self)
|
||||
} else {
|
||||
self.role().build_messages(self)
|
||||
};
|
||||
if let Some(tool_calls) = &self.tool_calls {
|
||||
messages.push(Message::new(
|
||||
MessageRole::Assistant,
|
||||
MessageContent::ToolCalls(tool_calls.clone()),
|
||||
))
|
||||
}
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self) -> String {
|
||||
if let Some(session) = self.session(&self.config.read().session) {
|
||||
session.echo_messages(self)
|
||||
} else {
|
||||
self.role().echo_messages(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn role(&self) -> &Role {
|
||||
&self.role
|
||||
}
|
||||
|
||||
pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
|
||||
if self.with_session {
|
||||
session.as_ref()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
|
||||
if self.with_session {
|
||||
session.as_mut()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_agent(&self) -> bool {
|
||||
self.with_agent
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> String {
|
||||
let text: String = self
|
||||
.text
|
||||
.trim()
|
||||
.chars()
|
||||
.map(|c| if c.is_control() { ' ' } else { c })
|
||||
.collect();
|
||||
if text.width_cjk() > SUMMARY_MAX_WIDTH {
|
||||
let mut sum_width = 0;
|
||||
let mut chars = vec![];
|
||||
for c in text.chars() {
|
||||
sum_width += c.width_cjk().unwrap_or(1);
|
||||
if sum_width > SUMMARY_MAX_WIDTH - 3 {
|
||||
chars.extend(['.', '.', '.']);
|
||||
break;
|
||||
}
|
||||
chars.push(c);
|
||||
}
|
||||
chars.into_iter().collect()
|
||||
} else {
|
||||
text
|
||||
}
|
||||
}
|
||||
|
||||
pub fn raw(&self) -> String {
|
||||
let (text, files) = &self.raw;
|
||||
let mut segments = files.to_vec();
|
||||
if !segments.is_empty() {
|
||||
segments.insert(0, ".file".into());
|
||||
}
|
||||
if !text.is_empty() {
|
||||
if !segments.is_empty() {
|
||||
segments.push("--".into());
|
||||
}
|
||||
segments.push(text.clone());
|
||||
}
|
||||
segments.join(" ")
|
||||
}
|
||||
|
||||
pub fn render(&self) -> String {
|
||||
let text = self.text();
|
||||
if self.medias.is_empty() {
|
||||
return text;
|
||||
}
|
||||
let tail_text = if text.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" -- {text}")
|
||||
};
|
||||
let files: Vec<String> = self
|
||||
.medias
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|url| resolve_data_url(&self.data_urls, url))
|
||||
.collect();
|
||||
format!(".file {}{}", files.join(" "), tail_text)
|
||||
}
|
||||
|
||||
pub fn message_content(&self) -> MessageContent {
|
||||
if self.medias.is_empty() {
|
||||
MessageContent::Text(self.text())
|
||||
} else {
|
||||
let mut list: Vec<MessageContentPart> = self
|
||||
.medias
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|url| MessageContentPart::ImageUrl {
|
||||
image_url: ImageUrl { url },
|
||||
})
|
||||
.collect();
|
||||
if !self.text.is_empty() {
|
||||
list.insert(0, MessageContentPart::Text { text: self.text() });
|
||||
}
|
||||
MessageContent::Array(list)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
|
||||
match role {
|
||||
Some(v) => (v, false, false),
|
||||
None => (
|
||||
config.extract_role(),
|
||||
config.session.is_some(),
|
||||
config.agent.is_some(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
type ResolvePathsOutput = (
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
bool,
|
||||
);
|
||||
|
||||
fn resolve_paths(
|
||||
loaders: &HashMap<String, String>,
|
||||
paths: Vec<String>,
|
||||
) -> Result<ResolvePathsOutput> {
|
||||
let mut raw_paths = IndexSet::new();
|
||||
let mut local_paths = IndexSet::new();
|
||||
let mut remote_urls = IndexSet::new();
|
||||
let mut external_cmds = IndexSet::new();
|
||||
let mut protocol_paths = IndexSet::new();
|
||||
let mut with_last_reply = false;
|
||||
for path in paths {
|
||||
if path == "%%" {
|
||||
with_last_reply = true;
|
||||
raw_paths.insert(path);
|
||||
} else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') {
|
||||
external_cmds.insert(path[1..path.len() - 1].to_string());
|
||||
raw_paths.insert(path);
|
||||
} else if is_url(&path) {
|
||||
if path.strip_suffix("**").is_some() {
|
||||
bail!("Invalid website '{path}'");
|
||||
}
|
||||
remote_urls.insert(path.clone());
|
||||
raw_paths.insert(path);
|
||||
} else if is_loader_protocol(loaders, &path) {
|
||||
protocol_paths.insert(path.clone());
|
||||
raw_paths.insert(path);
|
||||
} else {
|
||||
let resolved_path = resolve_home_dir(&path);
|
||||
let absolute_path = to_absolute_path(&resolved_path)
|
||||
.with_context(|| format!("Invalid path '{path}'"))?;
|
||||
local_paths.insert(resolved_path);
|
||||
raw_paths.insert(absolute_path);
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
raw_paths.into_iter().collect(),
|
||||
local_paths.into_iter().collect(),
|
||||
remote_urls.into_iter().collect(),
|
||||
external_cmds.into_iter().collect(),
|
||||
protocol_paths.into_iter().collect(),
|
||||
with_last_reply,
|
||||
))
|
||||
}
|
||||
|
||||
async fn load_documents(
|
||||
loaders: &HashMap<String, String>,
|
||||
local_paths: Vec<String>,
|
||||
remote_urls: Vec<String>,
|
||||
external_cmds: Vec<String>,
|
||||
protocol_paths: Vec<String>,
|
||||
) -> Result<(
|
||||
Vec<(&'static str, String, String)>,
|
||||
Vec<String>,
|
||||
HashMap<String, String>,
|
||||
)> {
|
||||
let mut files = vec![];
|
||||
let mut medias = vec![];
|
||||
let mut data_urls = HashMap::new();
|
||||
|
||||
for cmd in external_cmds {
|
||||
let output = duct::cmd(&SHELL.cmd, &[&SHELL.arg, &cmd])
|
||||
.stderr_to_stdout()
|
||||
.unchecked()
|
||||
.read()
|
||||
.unwrap_or_else(|err| err.to_string());
|
||||
files.push(("CMD", cmd, output));
|
||||
}
|
||||
|
||||
let local_files = expand_glob_paths(&local_paths, true).await?;
|
||||
for file_path in local_files {
|
||||
if is_image(&file_path) {
|
||||
let contents = read_media_to_data_url(&file_path)
|
||||
.with_context(|| format!("Unable to read media '{file_path}'"))?;
|
||||
data_urls.insert(sha256(&contents), file_path);
|
||||
medias.push(contents)
|
||||
} else {
|
||||
let document = load_file(loaders, &file_path)
|
||||
.await
|
||||
.with_context(|| format!("Unable to read file '{file_path}'"))?;
|
||||
files.push(("FILE", file_path, document.contents));
|
||||
}
|
||||
}
|
||||
|
||||
for file_url in remote_urls {
|
||||
let (contents, extension) = fetch_with_loaders(loaders, &file_url, true)
|
||||
.await
|
||||
.with_context(|| format!("Failed to load url '{file_url}'"))?;
|
||||
if extension == MEDIA_URL_EXTENSION {
|
||||
data_urls.insert(sha256(&contents), file_url);
|
||||
medias.push(contents)
|
||||
} else {
|
||||
files.push(("URL", file_url, contents));
|
||||
}
|
||||
}
|
||||
|
||||
for protocol_path in protocol_paths {
|
||||
let documents = load_protocol_path(loaders, &protocol_path)
|
||||
.with_context(|| format!("Failed to load from '{protocol_path}'"))?;
|
||||
files.extend(
|
||||
documents
|
||||
.into_iter()
|
||||
.map(|document| ("FROM", document.path, document.contents)),
|
||||
);
|
||||
}
|
||||
|
||||
Ok((files, medias, data_urls))
|
||||
}
|
||||
|
||||
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
|
||||
if data_url.starts_with("data:") {
|
||||
let hash = sha256(&data_url);
|
||||
if let Some(path) = data_urls.get(&hash) {
|
||||
return path.to_string();
|
||||
}
|
||||
data_url
|
||||
} else {
|
||||
data_url
|
||||
}
|
||||
}
|
||||
|
||||
fn is_image(path: &str) -> bool {
|
||||
get_patch_extension(path)
|
||||
.map(|v| IMAGE_EXTS.contains(&v.as_str()))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn read_media_to_data_url(image_path: &str) -> Result<String> {
|
||||
let extension = get_patch_extension(image_path).unwrap_or_default();
|
||||
let mime_type = match extension.as_str() {
|
||||
"png" => "image/png",
|
||||
"jpg" | "jpeg" => "image/jpeg",
|
||||
"webp" => "image/webp",
|
||||
"gif" => "image/gif",
|
||||
_ => bail!("Unexpected media type"),
|
||||
};
|
||||
let mut file = File::open(image_path)?;
|
||||
let mut buffer = Vec::new();
|
||||
file.read_to_end(&mut buffer)?;
|
||||
|
||||
let encoded_image = base64_encode(buffer);
|
||||
let data_url = format!("data:{mime_type};base64,{encoded_image}");
|
||||
|
||||
Ok(data_url)
|
||||
}
|
||||
Reference in New Issue
Block a user