testing
This commit is contained in:
+57
-28
@@ -9,7 +9,7 @@ use crate::utils::{AbortSignal, base64_encode, is_loader_protocol, sha256};
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use indexmap::IndexSet;
|
||||
use std::{collections::HashMap, fs::File, io::Read};
|
||||
use std::{collections::HashMap, fs::File, io::Read, sync::Arc};
|
||||
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
|
||||
|
||||
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
|
||||
@@ -17,7 +17,11 @@ const SUMMARY_MAX_WIDTH: usize = 80;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Input {
|
||||
config: GlobalConfig,
|
||||
app_config: Arc<AppConfig>,
|
||||
stream_enabled: bool,
|
||||
session: Option<Session>,
|
||||
rag: Option<Arc<crate::rag::Rag>>,
|
||||
functions: Option<Vec<FunctionDeclaration>>,
|
||||
text: String,
|
||||
raw: (String, Vec<String>),
|
||||
patched_text: Option<String>,
|
||||
@@ -34,10 +38,15 @@ pub struct Input {
|
||||
}
|
||||
|
||||
impl Input {
|
||||
pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
pub fn from_str(ctx: &RequestContext, text: &str, role: Option<Role>) -> Self {
|
||||
let (role, with_session, with_agent) = resolve_role(ctx, role);
|
||||
let captured = capture_input_config(ctx, &role);
|
||||
Self {
|
||||
config: config.clone(),
|
||||
app_config: Arc::clone(&ctx.app.config),
|
||||
stream_enabled: captured.stream_enabled,
|
||||
session: captured.session,
|
||||
rag: captured.rag,
|
||||
functions: captured.functions,
|
||||
text: text.to_string(),
|
||||
raw: (text.to_string(), vec![]),
|
||||
patched_text: None,
|
||||
@@ -55,12 +64,12 @@ impl Input {
|
||||
}
|
||||
|
||||
pub async fn from_files(
|
||||
config: &GlobalConfig,
|
||||
ctx: &RequestContext,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
) -> Result<Self> {
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let loaders = ctx.app.config.document_loaders.clone();
|
||||
let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
|
||||
resolve_paths(&loaders, paths)?;
|
||||
let mut last_reply = None;
|
||||
@@ -78,7 +87,7 @@ impl Input {
|
||||
texts.push(raw_text.to_string());
|
||||
};
|
||||
if with_last_reply {
|
||||
if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
|
||||
if let Some(LastMessage { input, output, .. }) = ctx.last_message.as_ref() {
|
||||
if !output.is_empty() {
|
||||
last_reply = Some(output.clone())
|
||||
} else if let Some(v) = input.last_reply.as_ref() {
|
||||
@@ -102,9 +111,14 @@ impl Input {
|
||||
));
|
||||
}
|
||||
}
|
||||
let (role, with_session, with_agent) = resolve_role(&config.read(), role);
|
||||
let (role, with_session, with_agent) = resolve_role(ctx, role);
|
||||
let captured = capture_input_config(ctx, &role);
|
||||
Ok(Self {
|
||||
config: config.clone(),
|
||||
app_config: Arc::clone(&ctx.app.config),
|
||||
stream_enabled: captured.stream_enabled,
|
||||
session: captured.session,
|
||||
rag: captured.rag,
|
||||
functions: captured.functions,
|
||||
text: texts.join("\n"),
|
||||
raw: (raw_text.to_string(), raw_paths),
|
||||
patched_text: None,
|
||||
@@ -122,14 +136,14 @@ impl Input {
|
||||
}
|
||||
|
||||
pub async fn from_files_with_spinner(
|
||||
config: &GlobalConfig,
|
||||
ctx: &RequestContext,
|
||||
raw_text: &str,
|
||||
paths: Vec<String>,
|
||||
role: Option<Role>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
abortable_run_with_spinner(
|
||||
Input::from_files(config, raw_text, paths, role),
|
||||
Input::from_files(ctx, raw_text, paths, role),
|
||||
"Loading files",
|
||||
abort_signal,
|
||||
)
|
||||
@@ -164,7 +178,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> bool {
|
||||
self.config.read().stream && !self.role().model().no_stream()
|
||||
self.stream_enabled && !self.role().model().no_stream()
|
||||
}
|
||||
|
||||
pub fn continue_output(&self) -> Option<&str> {
|
||||
@@ -183,10 +197,9 @@ impl Input {
|
||||
self.regenerate
|
||||
}
|
||||
|
||||
pub fn set_regenerate(&mut self) {
|
||||
let role = self.config.read().extract_role();
|
||||
if role.name() == self.role().name() {
|
||||
self.role = role;
|
||||
pub fn set_regenerate(&mut self, current_role: Role) {
|
||||
if current_role.name() == self.role().name() {
|
||||
self.role = current_role;
|
||||
}
|
||||
self.regenerate = true;
|
||||
self.tool_calls = None;
|
||||
@@ -196,9 +209,9 @@ impl Input {
|
||||
if self.text.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let rag = self.config.read().rag.clone();
|
||||
if let Some(rag) = rag {
|
||||
let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
|
||||
if let Some(rag) = &self.rag {
|
||||
let result =
|
||||
Config::search_rag(&self.app_config, rag, &self.text, abort_signal).await?;
|
||||
self.patched_text = Some(result);
|
||||
self.rag_name = Some(rag.name().to_string());
|
||||
}
|
||||
@@ -220,7 +233,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn create_client(&self) -> Result<Box<dyn Client>> {
|
||||
init_client(&self.config, Some(self.role().model().clone()))
|
||||
init_client(&self.app_config, self.role().model().clone())
|
||||
}
|
||||
|
||||
pub async fn fetch_chat_text(&self) -> Result<String> {
|
||||
@@ -240,7 +253,7 @@ impl Input {
|
||||
model.guard_max_input_tokens(&messages)?;
|
||||
let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
|
||||
let functions = if model.supports_function_calling() {
|
||||
let fns = self.config.read().select_functions(self.role());
|
||||
let fns = self.functions.clone();
|
||||
if let Some(vec) = &fns {
|
||||
for def in vec {
|
||||
debug!("Function definition: {:?}", def.name);
|
||||
@@ -260,7 +273,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn build_messages(&self) -> Result<Vec<Message>> {
|
||||
let mut messages = if let Some(session) = self.session(&self.config.read().session) {
|
||||
let mut messages = if let Some(session) = self.session(&self.session) {
|
||||
session.build_messages(self)
|
||||
} else {
|
||||
self.role().build_messages(self)
|
||||
@@ -275,7 +288,7 @@ impl Input {
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self) -> String {
|
||||
if let Some(session) = self.session(&self.config.read().session) {
|
||||
if let Some(session) = self.session(&self.session) {
|
||||
session.echo_messages(self)
|
||||
} else {
|
||||
self.role().echo_messages(self)
|
||||
@@ -384,17 +397,33 @@ impl Input {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
|
||||
fn resolve_role(ctx: &RequestContext, role: Option<Role>) -> (Role, bool, bool) {
|
||||
match role {
|
||||
Some(v) => (v, false, false),
|
||||
None => (
|
||||
config.extract_role(),
|
||||
config.session.is_some(),
|
||||
config.agent.is_some(),
|
||||
ctx.extract_role(ctx.app.config.as_ref()),
|
||||
ctx.session.is_some(),
|
||||
ctx.agent.is_some(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
struct CapturedInputConfig {
|
||||
stream_enabled: bool,
|
||||
session: Option<Session>,
|
||||
rag: Option<Arc<crate::rag::Rag>>,
|
||||
functions: Option<Vec<FunctionDeclaration>>,
|
||||
}
|
||||
|
||||
fn capture_input_config(ctx: &RequestContext, role: &Role) -> CapturedInputConfig {
|
||||
CapturedInputConfig {
|
||||
stream_enabled: ctx.app.config.stream,
|
||||
session: ctx.session.clone(),
|
||||
rag: ctx.rag.clone(),
|
||||
functions: ctx.select_functions(role),
|
||||
}
|
||||
}
|
||||
|
||||
type ResolvePathsOutput = (
|
||||
Vec<String>,
|
||||
Vec<String>,
|
||||
|
||||
Reference in New Issue
Block a user