diff --git a/config.example.yaml b/config.example.yaml index c417264..501f38c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -16,7 +16,7 @@ vault_password_file: null # Path to a file containing the password for th function_calling: true # Enables or disables function calling (Globally). mapping_tools: # Alias for a tool or toolset fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write' -use_tools: null # Which tools to use by default. (e.g. 'fs,web_search') +use_tools: null # Which tools to use by default. (e.g. 'fs,web_search_loki') # ---- MCP Servers ---- mcp_servers: true # Enables or disables MCP servers (globally). @@ -85,7 +85,6 @@ right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}' # ---- misc ---- -serve_addr: 127.0.0.1:8000 # Server listening address user_agent: null # Set User-Agent HTTP header, use `auto` for loki/ save_shell_history: true # Whether to save shell execution command to the history file # URL to sync model changes from diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 03d7697..649591a 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -61,9 +61,6 @@ pub struct Cli { /// Execute a macro #[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))] pub macro_name: Option, - /// Serve the LLM API and WebAPP - #[arg(long, value_name = "PORT|IP|IP:PORT")] - pub serve: Option>, /// Execute commands in natural language #[arg(short = 'e', long)] pub execute: bool, diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index cedd3a2..6305219 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -518,13 +518,7 @@ fn extract_chat_completions(data: &Value) -> Result { bail!("Invalid response data: {data}"); } - let output = ChatCompletionsOutput { - text, - tool_calls, - id: None, - input_tokens: data["usage"]["inputTokens"].as_u64(), - output_tokens: data["usage"]["outputTokens"].as_u64(), - }; + let output = ChatCompletionsOutput { text, tool_calls }; Ok(output) } diff --git a/src/client/claude.rs b/src/client/claude.rs index 12e2559..60e1f5f 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -2,10 +2,10 @@ use super::*; use crate::utils::strip_think_tag; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result, bail}; use reqwest::RequestBuilder; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; const API_BASE: &str = "https://api.anthropic.com/v1"; @@ -301,7 +301,7 @@ pub fn claude_build_chat_completions_body( } }) .collect(); - } + } Ok(body) } @@ -353,9 +353,6 @@ pub fn claude_extract_chat_completions(data: &Value) -> Result Result { if text.is_empty() && tool_calls.is_empty() { bail!("Invalid response data: {data}"); } - let output = ChatCompletionsOutput { - text, - tool_calls, - id: data["id"].as_str().map(|v| v.to_string()), - input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(), - output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(), - }; + let output = ChatCompletionsOutput { text, tool_calls }; Ok(output) } diff --git a/src/client/common.rs b/src/client/common.rs index e903815..8339a6a 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -21,7 +21,7 @@ use std::sync::LazyLock; use std::time::Duration; use tokio::sync::mpsc::unbounded_channel; -const MODELS_YAML: &str = include_str!("../../models.yaml"); +pub const MODELS_YAML: &str = include_str!("../../models.yaml"); pub static ALL_PROVIDER_MODELS: LazyLock> = LazyLock::new(|| { Config::local_models_override() @@ -47,8 +47,6 @@ pub trait Client: Sync + Send { fn model(&self) -> &Model; - fn model_mut(&mut self) -> &mut Model; - fn build_client(&self) -> Result { let mut builder = ReqwestClient::builder(); let extra = self.extra_config(); @@ -291,9 +289,6 @@ pub struct ChatCompletionsData { pub struct ChatCompletionsOutput { pub text: String, pub tool_calls: Vec, - pub id: Option, - pub input_tokens: Option, - pub output_tokens: Option, } impl ChatCompletionsOutput { @@ -341,7 +336,6 @@ pub type RerankOutput = Vec; #[derive(Debug, Deserialize)] pub struct RerankResult { pub index: usize, - pub relevance_score: f64, } pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool); diff --git a/src/client/macros.rs b/src/client/macros.rs index 0b76974..72fc8be 100644 --- a/src/client/macros.rs +++ b/src/client/macros.rs @@ -159,10 +159,6 @@ macro_rules! client_common_fns { fn model(&self) -> &Model { &self.model } - - fn model_mut(&mut self) -> &mut Model { - &mut self.model - } }; } diff --git a/src/client/model.rs b/src/client/model.rs index 925e9c9..ab288af 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -118,14 +118,6 @@ impl Model { } } - pub fn data(&self) -> &ModelData { - &self.data - } - - pub fn data_mut(&mut self) -> &mut ModelData { - &mut self.data - } - pub fn description(&self) -> String { match self.model_type() { ModelType::Chat => { diff --git a/src/client/openai.rs b/src/client/openai.rs index b2432d4..f1dc2e0 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -389,13 +389,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result Result<()> { - // debug!("HandleCall: {:?}", call); self.tool_calls.push(call); Ok(()) } @@ -65,10 +64,6 @@ impl SseHandler { self.abort_signal.clone() } - pub fn tool_calls(&self) -> &[ToolCall] { - &self.tool_calls - } - pub fn take(self) -> (String, Vec) { let Self { buffer, tool_calls, .. diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 43e15bc..4cfb90d 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -296,13 +296,7 @@ fn gemini_extract_chat_completions_text(data: &Value) -> Result String { - self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into()) - } - - pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option)> { + pub fn log_config() -> Result<(LevelFilter, Option)> { let log_level = env::var(get_env_name("log_level")) .ok() .and_then(|v| v.parse().ok()) .unwrap_or(match cfg!(debug_assertions) { true => LevelFilter::Debug, - false => { - if is_serve { - LevelFilter::Off - } else { - LevelFilter::Info - } - } + false => LevelFilter::Info, }); let log_path = match env::var(get_env_name("log_path")) { Ok(v) => Some(PathBuf::from(v)), - Err(_) => match is_serve { - true => None, - false => Some(Config::log_path()), - }, + Err(_) => Some(Config::log_path()), }; Ok((log_level, log_path)) } @@ -744,7 +729,7 @@ impl Config { display_path(&self.vault_password_file()), ), ]; - if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) { + if let Ok((_, Some(log_path))) = Self::log_config() { items.push(("log_path", display_path(&log_path))); } let output = items @@ -1279,23 +1264,6 @@ impl Config { Ok(()) } - pub fn all_roles() -> Vec { - let mut roles: HashMap = Role::list_builtin_roles() - .iter() - .map(|v| (v.name().to_string(), v.clone())) - .collect(); - let names = Self::list_roles(false); - for name in names { - if let Ok(content) = read_to_string(Self::role_file(&name)) { - let role = Role::new(&name, &content); - roles.insert(name, role); - } - } - let mut roles: Vec<_> = roles.into_values().collect(); - roles.sort_unstable_by(|a, b| a.name().cmp(b.name())); - roles - } - pub fn list_roles(with_builtin: bool) -> Vec { let mut names = HashSet::new(); if let Ok(rd) = read_dir(Self::roles_dir()) { @@ -1921,7 +1889,6 @@ impl Config { let prelude = match self.working_mode { WorkingMode::Repl => self.repl_prelude.as_ref(), WorkingMode::Cmd => self.cmd_prelude.as_ref(), - WorkingMode::Serve => return Ok(()), }; let prelude = match prelude { Some(v) => { @@ -2835,10 +2802,6 @@ impl Config { if let Some(v) = read_env_value::(&get_env_name("right_prompt")) { self.right_prompt = v; } - - if let Some(v) = read_env_value::(&get_env_name("serve_addr")) { - self.serve_addr = v; - } if let Some(v) = read_env_value::(&get_env_name("user_agent")) { self.user_agent = v; } @@ -2947,7 +2910,6 @@ pub fn load_env_file() -> Result<()> { pub enum WorkingMode { Cmd, Repl, - Serve, } impl WorkingMode { @@ -2957,9 +2919,6 @@ impl WorkingMode { pub fn is_repl(&self) -> bool { *self == WorkingMode::Repl } - pub fn is_serve(&self) -> bool { - *self == WorkingMode::Serve - } } #[async_recursion::async_recursion] diff --git a/src/config/role.rs b/src/config/role.rs index ca496a4..5b3ef77 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -111,12 +111,6 @@ impl Role { .collect() } - pub fn list_builtin_roles() -> Vec { - RolesAsset::iter() - .filter_map(|v| Role::builtin(&v).ok()) - .collect() - } - pub fn has_args(&self) -> bool { self.name.contains('#') } diff --git a/src/main.rs b/src/main.rs index addbfbf..6df0480 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,6 @@ mod function; mod rag; mod render; mod repl; -mod serve; #[macro_use] mod utils; mod mcp; @@ -58,9 +57,7 @@ async fn main() -> Result<()> { } let text = cli.text()?; - let working_mode = if cli.serve.is_some() { - WorkingMode::Serve - } else if text.is_none() && cli.file.is_empty() { + let working_mode = if text.is_none() && cli.file.is_empty() { WorkingMode::Repl } else { WorkingMode::Cmd @@ -80,7 +77,7 @@ async fn main() -> Result<()> { || cli.delete_secret.is_some() || cli.list_secrets; - let log_path = setup_logger(working_mode.is_serve())?; + let log_path = setup_logger()?; if vault_flags { return Vault::handle_vault_flags(cli, Config::init_bare()?); @@ -219,9 +216,6 @@ async fn run( println!("{info}"); return Ok(()); } - if let Some(addr) = cli.serve { - return serve::run(config, addr).await; - } let is_repl = config.read().working_mode.is_repl(); if cli.rebuild_rag { Config::rebuild_rag(&config, abort_signal.clone()).await?; @@ -429,22 +423,15 @@ async fn create_input( Ok(input) } -fn setup_logger(is_serve: bool) -> Result> { - let (log_level, log_path) = Config::log_config(is_serve)?; +fn setup_logger() -> Result> { + let (log_level, log_path) = Config::log_config()?; if log_level == LevelFilter::Off { return Ok(None); } let encoder = Box::new(PatternEncoder::new( "{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}", )); - let log_filter = match env::var(get_env_name("log_filter")) { - Ok(v) => Some(v), - Err(_) => match is_serve { - true => Some(format!("{}::serve", env!("CARGO_CRATE_NAME"))), - false => None, - }, - }; - + let log_filter = env::var(get_env_name("log_filter")).ok(); match log_path.clone() { None => { let console_appender = ConsoleAppender::builder().encoder(encoder).build(); diff --git a/src/repl/prompt.rs b/src/repl/prompt.rs index 1264af4..c57f981 100644 --- a/src/repl/prompt.rs +++ b/src/repl/prompt.rs @@ -41,8 +41,8 @@ impl Prompt for ReplPrompt { PromptHistorySearchStatus::Passing => "", PromptHistorySearchStatus::Failing => "failing ", }; - // NOTE: magic strings, given there is logic on how these compose I am not sure if it - // is worth extracting in to static constant + // NOTE: magic strings, given there is logic on how these are composed, I'm unsure if it's + // worth extracting into a static constant Cow::Owned(format!( "({}reverse-search: {}) ", prefix, history_search.term diff --git a/src/serve.rs b/src/serve.rs deleted file mode 100644 index d6e0afe..0000000 --- a/src/serve.rs +++ /dev/null @@ -1,935 +0,0 @@ -use crate::{client::*, config::*, function::*, rag::*, utils::*}; - -use anyhow::{anyhow, bail, Result}; -use bytes::Bytes; -use chrono::{Timelike, Utc}; -use futures_util::StreamExt; -use http::{Method, Response, StatusCode}; -use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; -use hyper::{ - body::{Frame, Incoming}, - service::service_fn, -}; -use hyper_util::rt::{TokioExecutor, TokioIo}; -use parking_lot::RwLock; -use serde::Deserialize; -use serde_json::{json, Value}; -use std::{ - convert::Infallible, - net::IpAddr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; -use tokio::{ - net::TcpListener, - sync::{ - mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, - oneshot, - }, -}; -use tokio_graceful::Shutdown; -use tokio_stream::wrappers::UnboundedReceiverStream; - -const DEFAULT_MODEL_NAME: &str = "default"; -const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html"); -const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html"); - -type AppResponse = Response>; - -pub async fn run(config: GlobalConfig, addr: Option) -> Result<()> { - let addr = match addr { - Some(addr) => { - if let Ok(port) = addr.parse::() { - format!("127.0.0.1:{port}") - } else if let Ok(ip) = addr.parse::() { - format!("{ip}:8000") - } else { - addr - } - } - None => config.read().serve_addr(), - }; - let server = Arc::new(Server::new(&config)); - let listener = TcpListener::bind(&addr).await?; - let stop_server = server.run(listener).await?; - println!("Chat Completions API: http://{addr}/v1/chat/completions"); - println!("Embeddings API: http://{addr}/v1/embeddings"); - println!("Rerank API: http://{addr}/v1/rerank"); - println!("LLM Playground: http://{addr}/playground"); - println!("LLM Arena: http://{addr}/arena?num=2"); - shutdown_signal().await; - let _ = stop_server.send(()); - Ok(()) -} - -struct Server { - config: Config, - models: Vec, - roles: Vec, - rags: Vec, -} - -impl Server { - fn new(config: &GlobalConfig) -> Self { - let mut config = config.read().clone(); - config.functions = Functions::default(); - let mut models = list_all_models(&config); - let mut default_model = config.model.clone(); - default_model.data_mut().name = DEFAULT_MODEL_NAME.into(); - models.insert(0, &default_model); - let models: Vec = models - .into_iter() - .enumerate() - .map(|(i, model)| { - let id = if i == 0 { - DEFAULT_MODEL_NAME.into() - } else { - model.id() - }; - let mut value = json!(model.data()); - if let Some(value_obj) = value.as_object_mut() { - value_obj.insert("id".into(), id.into()); - value_obj.insert("object".into(), "model".into()); - value_obj.insert("owned_by".into(), model.client_name().into()); - value_obj.remove("name"); - } - value - }) - .collect(); - Self { - config, - models, - roles: Config::all_roles(), - rags: Config::list_rags(), - } - } - - async fn run(self: Arc, listener: TcpListener) -> Result> { - let (tx, rx) = oneshot::channel(); - tokio::spawn(async move { - let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() }); - let guard = shutdown.guard_weak(); - - loop { - tokio::select! { - res = listener.accept() => { - let Ok((cnx, _)) = res else { - continue; - }; - - let stream = TokioIo::new(cnx); - let server = self.clone(); - shutdown.spawn_task(async move { - let hyper_service = service_fn(move |request: hyper::Request| { - server.clone().handle(request) - }); - let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(stream, hyper_service) - .await; - }); - } - _ = guard.cancelled() => { - break; - } - } - } - }); - Ok(tx) - } - - async fn handle( - self: Arc, - req: hyper::Request, - ) -> std::result::Result { - let method = req.method().clone(); - let uri = req.uri().clone(); - let path = uri.path(); - - if method == Method::OPTIONS { - let mut res = Response::default(); - *res.status_mut() = StatusCode::NO_CONTENT; - set_cors_header(&mut res); - return Ok(res); - } - - let mut status = StatusCode::OK; - let res = if path == "/v1/chat/completions" { - self.chat_completions(req).await - } else if path == "/v1/embeddings" { - self.embeddings(req).await - } else if path == "/v1/rerank" { - self.rerank(req).await - } else if path == "/v1/models" { - self.list_models() - } else if path == "/v1/roles" { - self.list_roles() - } else if path == "/v1/rags" { - self.list_rags() - } else if path == "/v1/rags/search" { - self.search_rag(req).await - } else if path == "/playground" || path == "/playground.html" { - self.playground_page() - } else if path == "/arena" || path == "/arena.html" { - self.arena_page() - } else { - status = StatusCode::NOT_FOUND; - Err(anyhow!("Not Found")) - }; - let mut res = match res { - Ok(res) => { - info!("{method} {uri} {}", status.as_u16()); - res - } - Err(err) => { - if status == StatusCode::OK { - status = StatusCode::BAD_REQUEST; - } - error!("{method} {uri} {} {err}", status.as_u16()); - ret_err(err) - } - }; - *res.status_mut() = status; - set_cors_header(&mut res); - Ok(res) - } - - fn playground_page(&self) -> Result { - let res = Response::builder() - .header("Content-Type", "text/html; charset=utf-8") - .body(Full::new(Bytes::from(PLAYGROUND_HTML)).boxed())?; - Ok(res) - } - - fn arena_page(&self) -> Result { - let res = Response::builder() - .header("Content-Type", "text/html; charset=utf-8") - .body(Full::new(Bytes::from(ARENA_HTML)).boxed())?; - Ok(res) - } - - fn list_models(&self) -> Result { - let data = json!({ "data": self.models }); - let res = Response::builder() - .header("Content-Type", "application/json; charset=utf-8") - .body(Full::new(Bytes::from(data.to_string())).boxed())?; - Ok(res) - } - - fn list_roles(&self) -> Result { - let data = json!({ "data": self.roles }); - let res = Response::builder() - .header("Content-Type", "application/json; charset=utf-8") - .body(Full::new(Bytes::from(data.to_string())).boxed())?; - Ok(res) - } - - fn list_rags(&self) -> Result { - let data = json!({ "data": self.rags }); - let res = Response::builder() - .header("Content-Type", "application/json; charset=utf-8") - .body(Full::new(Bytes::from(data.to_string())).boxed())?; - Ok(res) - } - - async fn search_rag(&self, req: hyper::Request) -> Result { - let req_body = req.collect().await?.to_bytes(); - let req_body: Value = serde_json::from_slice(&req_body) - .map_err(|err| anyhow!("Invalid request json, {err}"))?; - - debug!("search rag request: {req_body}"); - let SearchRagReqBody { name, input } = serde_json::from_value(req_body) - .map_err(|err| anyhow!("Invalid request body, {err}"))?; - - let config = Arc::new(RwLock::new(self.config.clone())); - - let abort_signal = create_abort_signal(); - - let rag_path = config.read().rag_file(&name); - let rag = Rag::load(&config, &name, &rag_path)?; - - let rag_result = Config::search_rag(&config, &rag, &input, abort_signal).await?; - - let data = json!({ "data": rag_result }); - let res = Response::builder() - .header("Content-Type", "application/json; charset=utf-8") - .body(Full::new(Bytes::from(data.to_string())).boxed())?; - Ok(res) - } - - async fn chat_completions(&self, req: hyper::Request) -> Result { - let req_body = req.collect().await?.to_bytes(); - let req_body: Value = serde_json::from_slice(&req_body) - .map_err(|err| anyhow!("Invalid request json, {err}"))?; - - debug!("chat completions request: {req_body}"); - let req_body = serde_json::from_value(req_body) - .map_err(|err| anyhow!("Invalid request body, {err}"))?; - - let ChatCompletionsReqBody { - model, - messages, - temperature, - top_p, - max_tokens, - stream, - tools, - } = req_body; - - let mut messages = - parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?; - - let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?; - - let config = self.config.clone(); - - let default_model = config.model.clone(); - - let config = Arc::new(RwLock::new(config)); - - let (model_name, change) = if model == DEFAULT_MODEL_NAME { - (default_model.id(), true) - } else if default_model.id() == model { - (model, false) - } else { - (model, true) - }; - - if change { - config.write().set_model(&model_name)?; - } - - let mut client = init_client(&config, None)?; - if max_tokens.is_some() { - client.model_mut().set_max_tokens(max_tokens, true); - } - let abort_signal = create_abort_signal(); - let http_client = client.build_client()?; - - let completion_id = generate_completion_id(); - let created = Utc::now().timestamp(); - - patch_messages(&mut messages, client.model()); - - let data: ChatCompletionsData = ChatCompletionsData { - messages, - temperature, - top_p, - functions, - stream, - }; - - if stream { - let (tx, mut rx) = unbounded_channel(); - tokio::spawn(async move { - let is_first = Arc::new(AtomicBool::new(true)); - let (sse_tx, sse_rx) = unbounded_channel(); - let mut handler = SseHandler::new(sse_tx, abort_signal); - async fn map_event( - mut sse_rx: UnboundedReceiver, - tx: &UnboundedSender, - is_first: Arc, - ) { - while let Some(reply_event) = sse_rx.recv().await { - if is_first.load(Ordering::SeqCst) { - let _ = tx.send(ResEvent::First(None)); - is_first.store(false, Ordering::SeqCst) - } - match reply_event { - SseEvent::Text(text) => { - let _ = tx.send(ResEvent::Text(text)); - } - SseEvent::Done => { - let _ = tx.send(ResEvent::Done); - sse_rx.close(); - } - } - } - } - async fn chat_completions( - client: &dyn Client, - http_client: &reqwest::Client, - handler: &mut SseHandler, - mut data: ChatCompletionsData, - tx: &UnboundedSender, - is_first: Arc, - ) { - if client.model().no_stream() { - data.stream = false; - let ret = client.chat_completions_inner(http_client, data).await; - match ret { - Ok(output) => { - let ChatCompletionsOutput { - text, tool_calls, .. - } = output; - let _ = tx.send(ResEvent::First(None)); - is_first.store(false, Ordering::SeqCst); - let _ = tx.send(ResEvent::Text(text)); - if !tool_calls.is_empty() { - let _ = tx.send(ResEvent::ToolCalls(tool_calls)); - } - } - Err(err) => { - let _ = tx.send(ResEvent::First(Some(format!("{err:?}")))); - is_first.store(false, Ordering::SeqCst) - } - }; - } else { - let ret = client - .chat_completions_streaming_inner(http_client, handler, data) - .await; - let first = match ret { - Ok(()) => None, - Err(err) => Some(format!("{err:?}")), - }; - if is_first.load(Ordering::SeqCst) { - let _ = tx.send(ResEvent::First(first)); - is_first.store(false, Ordering::SeqCst) - } - let tool_calls = handler.tool_calls().to_vec(); - if !tool_calls.is_empty() { - let _ = tx.send(ResEvent::ToolCalls(tool_calls)); - } - } - handler.done(); - } - tokio::join!( - map_event(sse_rx, &tx, is_first.clone()), - chat_completions( - client.as_ref(), - &http_client, - &mut handler, - data, - &tx, - is_first - ), - ); - }); - - let first_event = rx.recv().await; - - if let Some(ResEvent::First(Some(err))) = first_event { - bail!("{err}"); - } - - let shared: Arc<(String, String, i64, AtomicBool)> = - Arc::new((completion_id, model_name, created, AtomicBool::new(false))); - let stream = UnboundedReceiverStream::new(rx); - let stream = stream.filter_map(move |res_event| { - let shared = shared.clone(); - async move { - let (completion_id, model, created, has_tool_calls) = shared.as_ref(); - match res_event { - ResEvent::Text(text) => { - Some(Ok(create_text_frame(completion_id, model, *created, &text))) - } - ResEvent::ToolCalls(tool_calls) => { - has_tool_calls.store(true, Ordering::SeqCst); - Some(Ok(create_tool_calls_frame( - completion_id, - model, - *created, - &tool_calls, - ))) - } - ResEvent::Done => Some(Ok(create_done_frame( - completion_id, - model, - *created, - has_tool_calls.load(Ordering::SeqCst), - ))), - _ => None, - } - } - }); - let res = Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "text/event-stream") - .header("Cache-Control", "no-cache") - .header("Connection", "keep-alive") - .body(BodyExt::boxed(StreamBody::new(stream)))?; - Ok(res) - } else { - let output = client.chat_completions_inner(&http_client, data).await?; - let res = Response::builder() - .header("Content-Type", "application/json") - .body( - Full::new(ret_non_stream( - &completion_id, - &model_name, - created, - &output, - )) - .boxed(), - )?; - Ok(res) - } - } - - async fn embeddings(&self, req: hyper::Request) -> Result { - let req_body = req.collect().await?.to_bytes(); - let req_body: Value = serde_json::from_slice(&req_body) - .map_err(|err| anyhow!("Invalid request json, {err}"))?; - - debug!("embeddings request: {req_body}"); - let req_body = serde_json::from_value(req_body) - .map_err(|err| anyhow!("Invalid request body, {err}"))?; - - let EmbeddingsReqBody { - input, - model: embedding_model_id, - } = req_body; - - let config = Arc::new(RwLock::new(self.config.clone())); - - let embedding_model = - Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?; - - let texts = match input { - EmbeddingsReqBodyInput::Single(v) => vec![v], - EmbeddingsReqBodyInput::Multiple(v) => v, - }; - let client = init_client(&config, Some(embedding_model))?; - let data = client - .embeddings(&EmbeddingsData { - query: false, - texts, - }) - .await?; - let data: Vec<_> = data - .into_iter() - .enumerate() - .map(|(i, v)| { - json!({ - "object": "embedding", - "embedding": v, - "index": i, - }) - }) - .collect(); - let output = json!({ - "object": "list", - "data": data, - "model": embedding_model_id, - "usage": { - "prompt_tokens": 0, - "total_tokens": 0, - } - }); - let res = Response::builder() - .header("Content-Type", "application/json") - .body(Full::new(Bytes::from(output.to_string())).boxed())?; - Ok(res) - } - - async fn rerank(&self, req: hyper::Request) -> Result { - let req_body = req.collect().await?.to_bytes(); - let req_body: Value = serde_json::from_slice(&req_body) - .map_err(|err| anyhow!("Invalid request json, {err}"))?; - - debug!("rerank request: {req_body}"); - let req_body = serde_json::from_value(req_body) - .map_err(|err| anyhow!("Invalid request body, {err}"))?; - - let RerankReqBody { - model: reranker_model_id, - documents, - query, - top_n, - } = req_body; - - let top_n = top_n.unwrap_or(documents.len()); - - let config = Arc::new(RwLock::new(self.config.clone())); - - let reranker_model = - Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?; - - let client = init_client(&config, Some(reranker_model))?; - let data = client - .rerank(&RerankData { - query, - documents: documents.clone(), - top_n, - }) - .await?; - - let results: Vec<_> = data - .into_iter() - .map(|v| { - json!({ - "index": v.index, - "relevance_score": v.relevance_score, - "document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(), - }) - }) - .collect(); - let output = json!({ - "id": uuid::Uuid::new_v4().to_string(), - "results": results, - }); - let res = Response::builder() - .header("Content-Type", "application/json") - .body(Full::new(Bytes::from(output.to_string())).boxed())?; - Ok(res) - } -} - -#[derive(Debug, Deserialize)] -struct SearchRagReqBody { - name: String, - input: String, -} - -#[derive(Debug, Deserialize)] -struct ChatCompletionsReqBody { - model: String, - messages: Vec, - temperature: Option, - top_p: Option, - max_tokens: Option, - #[serde(default)] - stream: bool, - tools: Option>, -} - -#[derive(Debug, Deserialize)] -struct EmbeddingsReqBody { - input: EmbeddingsReqBodyInput, - model: String, -} - -#[derive(Debug, Deserialize)] -#[serde(untagged)] -enum EmbeddingsReqBodyInput { - Single(String), - Multiple(Vec), -} - -#[derive(Debug, Deserialize)] -struct RerankReqBody { - documents: Vec, - query: String, - model: String, - top_n: Option, -} - -#[derive(Debug)] -enum ResEvent { - First(Option), - Text(String), - ToolCalls(Vec), - Done, -} - -async fn shutdown_signal() { - tokio::signal::ctrl_c() - .await - .expect("Failed to install CTRL+C signal handler") -} - -fn generate_completion_id() -> String { - let random_id = Utc::now().nanosecond(); - format!("chatcmpl-{random_id}") -} - -fn set_cors_header(res: &mut AppResponse) { - res.headers_mut().insert( - hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, - hyper::header::HeaderValue::from_static("*"), - ); - res.headers_mut().insert( - hyper::header::ACCESS_CONTROL_ALLOW_METHODS, - hyper::header::HeaderValue::from_static("GET,POST,PUT,PATCH,DELETE"), - ); - res.headers_mut().insert( - hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, - hyper::header::HeaderValue::from_static("Content-Type,Authorization"), - ); -} - -fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame { - let delta = if content.is_empty() { - json!({ "role": "assistant", "content": content }) - } else { - json!({ "content": content }) - }; - let choice = json!({ - "index": 0, - "delta": delta, - "finish_reason": null, - }); - let value = build_chat_completion_chunk_json(id, model, created, &choice); - Frame::data(Bytes::from(format!("data: {value}\n\n"))) -} - -fn create_tool_calls_frame( - id: &str, - model: &str, - created: i64, - tool_calls: &[ToolCall], -) -> Frame { - let chunks = tool_calls - .iter() - .enumerate() - .flat_map(|(i, call)| { - let choice1 = json!({ - "index": 0, - "delta": { - "role": "assistant", - "content": null, - "tool_calls": [ - { - "index": i, - "id": call.id, - "type": "function", - "function": { - "name": call.name, - "arguments": "" - } - } - ] - }, - "finish_reason": null - }); - let choice2 = json!({ - "index": 0, - "delta": { - "tool_calls": [ - { - "index": i, - "function": { - "arguments": call.arguments.to_string(), - } - } - ] - }, - "finish_reason": null - }); - vec![ - build_chat_completion_chunk_json(id, model, created, &choice1), - build_chat_completion_chunk_json(id, model, created, &choice2), - ] - }) - .map(|v| format!("data: {v}\n\n")) - .collect::>() - .join(""); - Frame::data(Bytes::from(chunks)) -} - -fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame { - let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" }; - let choice = json!({ - "index": 0, - "delta": {}, - "finish_reason": finish_reason, - }); - let value = build_chat_completion_chunk_json(id, model, created, &choice); - Frame::data(Bytes::from(format!("data: {value}\n\ndata: [DONE]\n\n"))) -} - -fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value { - json!({ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [choice], - }) -} - -fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes { - let id = output.id.as_deref().unwrap_or(id); - let input_tokens = output.input_tokens.unwrap_or_default(); - let output_tokens = output.output_tokens.unwrap_or_default(); - let total_tokens = input_tokens + output_tokens; - let choice = if output.tool_calls.is_empty() { - json!({ - "index": 0, - "message": { - "role": "assistant", - "content": output.text, - }, - "logprobs": null, - "finish_reason": "stop", - }) - } else { - let content = if output.text.is_empty() { - Value::Null - } else { - output.text.clone().into() - }; - let tool_calls: Vec<_> = output - .tool_calls - .iter() - .map(|call| { - json!({ - "id": call.id, - "type": "function", - "function": { - "name": call.name, - "arguments": call.arguments.to_string(), - } - }) - }) - .collect(); - json!({ - "index": 0, - "message": { - "role": "assistant", - "content": content, - "tool_calls": tool_calls, - }, - "logprobs": null, - "finish_reason": "tool_calls", - }) - }; - let res_body = json!({ - "id": id, - "object": "chat.completion", - "created": created, - "model": model, - "choices": [choice], - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": total_tokens, - }, - }); - Bytes::from(res_body.to_string()) -} - -fn ret_err(err: T) -> AppResponse { - let data = json!({ - "error": { - "message": err.to_string(), - "type": "invalid_request_error", - }, - }); - Response::builder() - .header("Content-Type", "application/json") - .body(Full::new(Bytes::from(data.to_string())).boxed()) - .unwrap() -} - -fn parse_messages(message: Vec) -> Result> { - let mut output = vec![]; - let mut tool_results = None; - for (i, message) in message.into_iter().enumerate() { - let err = || anyhow!("Failed to parse '.messages[{i}]'"); - let role = message["role"].as_str().ok_or_else(err)?; - let content = match message.get("content") { - Some(value) => { - if let Some(value) = value.as_str() { - MessageContent::Text(value.to_string()) - } else if value.is_array() { - let value = serde_json::from_value(value.clone()).map_err(|_| err())?; - MessageContent::Array(value) - } else if value.is_null() { - MessageContent::Text(String::new()) - } else { - return Err(err()); - } - } - None => MessageContent::Text(String::new()), - }; - match role { - "system" | "user" => { - let role = match role { - "system" => MessageRole::System, - "user" => MessageRole::User, - _ => unreachable!(), - }; - output.push(Message::new(role, content)) - } - "assistant" => { - let role = MessageRole::Assistant; - match message["tool_calls"].as_array() { - Some(tool_calls) => { - if tool_results.is_some() { - return Err(err()); - } - let mut list = vec![]; - for tool_call in tool_calls { - if let (id, Some(name), Some(arguments)) = ( - tool_call["id"].as_str().map(|v| v.to_string()), - tool_call["function"]["name"].as_str(), - tool_call["function"]["arguments"].as_str(), - ) { - let arguments = - serde_json::from_str(arguments).map_err(|_| err())?; - list.push((id, name.to_string(), arguments)); - } else { - return Err(err()); - } - } - tool_results = Some((content.to_text(), list, vec![])); - } - None => output.push(Message::new(role, content)), - } - } - "tool" => match tool_results.take() { - Some((text, tool_calls, mut tool_values)) => { - let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string()); - let content = content.to_text(); - let value: Value = serde_json::from_str(&content) - .ok() - .unwrap_or_else(|| content.into()); - - tool_values.push((value, tool_call_id)); - - if tool_calls.len() == tool_values.len() { - let mut list = vec![]; - for ((id, name, arguments), (value, tool_call_id)) in - tool_calls.into_iter().zip(tool_values.into_iter()) - { - if id != tool_call_id { - return Err(err()); - } - list.push(ToolResult::new(ToolCall::new(name, arguments, id), value)) - } - output.push(Message::new( - MessageRole::Assistant, - MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)), - )); - tool_results = None; - } else { - tool_results = Some((text, tool_calls, tool_values)); - } - } - None => return Err(err()), - }, - _ => { - return Err(err()); - } - } - } - - if tool_results.is_some() { - bail!("Invalid messages"); - } - - Ok(output) -} - -fn parse_tools(tools: Option>) -> Result>> { - let tools = match tools { - Some(v) => v, - None => return Ok(None), - }; - let mut functions = vec![]; - for (i, tool) in tools.into_iter().enumerate() { - if let (Some("function"), Some(function)) = ( - tool["type"].as_str(), - tool["function"] - .as_object() - .and_then(|v| serde_json::from_value(json!(v)).ok()), - ) { - functions.push(function); - } else { - bail!("Failed to parse '.tools[{i}]'") - } - } - Ok(Some(functions)) -} diff --git a/src/utils/render_prompt.rs b/src/utils/render_prompt.rs index 12661fa..921929a 100644 --- a/src/utils/render_prompt.rs +++ b/src/utils/render_prompt.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; /// Render REPL prompt /// -/// The template comprises plain text and `{...}`. +/// The template comprises of plain text and `{...}`. /// /// The syntax of `{...}`: /// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`