feat: Removed the server functionality

This commit is contained in:
2025-11-03 14:25:55 -07:00
parent b49a27f886
commit 474c5bc76f
17 changed files with 21 additions and 1070 deletions
+1 -2
View File
@@ -16,7 +16,7 @@ vault_password_file: null # Path to a file containing the password for th
function_calling: true # Enables or disables function calling (Globally). function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset mapping_tools: # Alias for a tool or toolset
fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write' fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
use_tools: null # Which tools to use by default. (e.g. 'fs,web_search') use_tools: null # Which tools to use by default. (e.g. 'fs,web_search_loki')
# ---- MCP Servers ---- # ---- MCP Servers ----
mcp_servers: true # Enables or disables MCP servers (globally). mcp_servers: true # Enables or disables MCP servers (globally).
@@ -85,7 +85,6 @@ right_prompt:
'{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}' '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
# ---- misc ---- # ---- misc ----
serve_addr: 127.0.0.1:8000 # Server listening address
user_agent: null # Set User-Agent HTTP header, use `auto` for loki/<current-version> user_agent: null # Set User-Agent HTTP header, use `auto` for loki/<current-version>
save_shell_history: true # Whether to save shell execution command to the history file save_shell_history: true # Whether to save shell execution command to the history file
# URL to sync model changes from # URL to sync model changes from
-3
View File
@@ -61,9 +61,6 @@ pub struct Cli {
/// Execute a macro /// Execute a macro
#[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))] #[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))]
pub macro_name: Option<String>, pub macro_name: Option<String>,
/// Serve the LLM API and WebAPP
#[arg(long, value_name = "PORT|IP|IP:PORT")]
pub serve: Option<Option<String>>,
/// Execute commands in natural language /// Execute commands in natural language
#[arg(short = 'e', long)] #[arg(short = 'e', long)]
pub execute: bool, pub execute: bool,
+1 -7
View File
@@ -518,13 +518,7 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
bail!("Invalid response data: {data}"); bail!("Invalid response data: {data}");
} }
let output = ChatCompletionsOutput { let output = ChatCompletionsOutput { text, tool_calls };
text,
tool_calls,
id: None,
input_tokens: data["usage"]["inputTokens"].as_u64(),
output_tokens: data["usage"]["outputTokens"].as_u64(),
};
Ok(output) Ok(output)
} }
+2 -5
View File
@@ -2,10 +2,10 @@ use super::*;
use crate::utils::strip_think_tag; use crate::utils::strip_think_tag;
use anyhow::{bail, Context, Result}; use anyhow::{Context, Result, bail};
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, Value}; use serde_json::{Value, json};
const API_BASE: &str = "https://api.anthropic.com/v1"; const API_BASE: &str = "https://api.anthropic.com/v1";
@@ -353,9 +353,6 @@ pub fn claude_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
let output = ChatCompletionsOutput { let output = ChatCompletionsOutput {
text: text.to_string(), text: text.to_string(),
tool_calls, tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
}; };
Ok(output) Ok(output)
} }
+1 -7
View File
@@ -244,12 +244,6 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
if text.is_empty() && tool_calls.is_empty() { if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}"); bail!("Invalid response data: {data}");
} }
let output = ChatCompletionsOutput { let output = ChatCompletionsOutput { text, tool_calls };
text,
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(),
};
Ok(output) Ok(output)
} }
+1 -7
View File
@@ -21,7 +21,7 @@ use std::sync::LazyLock;
use std::time::Duration; use std::time::Duration;
use tokio::sync::mpsc::unbounded_channel; use tokio::sync::mpsc::unbounded_channel;
const MODELS_YAML: &str = include_str!("../../models.yaml"); pub const MODELS_YAML: &str = include_str!("../../models.yaml");
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| { pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
Config::local_models_override() Config::local_models_override()
@@ -47,8 +47,6 @@ pub trait Client: Sync + Send {
fn model(&self) -> &Model; fn model(&self) -> &Model;
fn model_mut(&mut self) -> &mut Model;
fn build_client(&self) -> Result<ReqwestClient> { fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder(); let mut builder = ReqwestClient::builder();
let extra = self.extra_config(); let extra = self.extra_config();
@@ -291,9 +289,6 @@ pub struct ChatCompletionsData {
pub struct ChatCompletionsOutput { pub struct ChatCompletionsOutput {
pub text: String, pub text: String,
pub tool_calls: Vec<ToolCall>, pub tool_calls: Vec<ToolCall>,
pub id: Option<String>,
pub input_tokens: Option<u64>,
pub output_tokens: Option<u64>,
} }
impl ChatCompletionsOutput { impl ChatCompletionsOutput {
@@ -341,7 +336,6 @@ pub type RerankOutput = Vec<RerankResult>;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct RerankResult { pub struct RerankResult {
pub index: usize, pub index: usize,
pub relevance_score: f64,
} }
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool); pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool);
-4
View File
@@ -159,10 +159,6 @@ macro_rules! client_common_fns {
fn model(&self) -> &Model { fn model(&self) -> &Model {
&self.model &self.model
} }
fn model_mut(&mut self) -> &mut Model {
&mut self.model
}
}; };
} }
-8
View File
@@ -118,14 +118,6 @@ impl Model {
} }
} }
pub fn data(&self) -> &ModelData {
&self.data
}
pub fn data_mut(&mut self) -> &mut ModelData {
&mut self.data
}
pub fn description(&self) -> String { pub fn description(&self) -> String {
match self.model_type() { match self.model_type() {
ModelType::Chat => { ModelType::Chat => {
+1 -7
View File
@@ -389,13 +389,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
} else { } else {
text.to_string() text.to_string()
}; };
let output = ChatCompletionsOutput { let output = ChatCompletionsOutput { text, tool_calls };
text,
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),
};
Ok(output) Ok(output)
} }
-5
View File
@@ -56,7 +56,6 @@ impl SseHandler {
} }
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> { pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
// debug!("HandleCall: {:?}", call);
self.tool_calls.push(call); self.tool_calls.push(call);
Ok(()) Ok(())
} }
@@ -65,10 +64,6 @@ impl SseHandler {
self.abort_signal.clone() self.abort_signal.clone()
} }
pub fn tool_calls(&self) -> &[ToolCall] {
&self.tool_calls
}
pub fn take(self) -> (String, Vec<ToolCall>) { pub fn take(self) -> (String, Vec<ToolCall>) {
let Self { let Self {
buffer, tool_calls, .. buffer, tool_calls, ..
+1 -7
View File
@@ -296,13 +296,7 @@ fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsO
bail!("Invalid response data: {data}"); bail!("Invalid response data: {data}");
} }
} }
let output = ChatCompletionsOutput { let output = ChatCompletionsOutput { text, tool_calls };
text,
tool_calls,
id: None,
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
};
Ok(output) Ok(output)
} }
+4 -45
View File
@@ -78,8 +78,6 @@ const MCP_FILE_NAME: &str = "mcp.json";
const CLIENTS_FIELD: &str = "clients"; const CLIENTS_FIELD: &str = "clients";
const SERVE_ADDR: &str = "127.0.0.1:8000";
const SYNC_MODELS_URL: &str = const SYNC_MODELS_URL: &str =
"https://raw.githubusercontent.com/Dark-Alex-17/loki/refs/heads/main/models.yaml"; "https://raw.githubusercontent.com/Dark-Alex-17/loki/refs/heads/main/models.yaml";
@@ -578,30 +576,17 @@ impl Config {
flags flags
} }
pub fn serve_addr(&self) -> String { pub fn log_config() -> Result<(LevelFilter, Option<PathBuf>)> {
self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into())
}
pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option<PathBuf>)> {
let log_level = env::var(get_env_name("log_level")) let log_level = env::var(get_env_name("log_level"))
.ok() .ok()
.and_then(|v| v.parse().ok()) .and_then(|v| v.parse().ok())
.unwrap_or(match cfg!(debug_assertions) { .unwrap_or(match cfg!(debug_assertions) {
true => LevelFilter::Debug, true => LevelFilter::Debug,
false => { false => LevelFilter::Info,
if is_serve {
LevelFilter::Off
} else {
LevelFilter::Info
}
}
}); });
let log_path = match env::var(get_env_name("log_path")) { let log_path = match env::var(get_env_name("log_path")) {
Ok(v) => Some(PathBuf::from(v)), Ok(v) => Some(PathBuf::from(v)),
Err(_) => match is_serve { Err(_) => Some(Config::log_path()),
true => None,
false => Some(Config::log_path()),
},
}; };
Ok((log_level, log_path)) Ok((log_level, log_path))
} }
@@ -744,7 +729,7 @@ impl Config {
display_path(&self.vault_password_file()), display_path(&self.vault_password_file()),
), ),
]; ];
if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) { if let Ok((_, Some(log_path))) = Self::log_config() {
items.push(("log_path", display_path(&log_path))); items.push(("log_path", display_path(&log_path)));
} }
let output = items let output = items
@@ -1279,23 +1264,6 @@ impl Config {
Ok(()) Ok(())
} }
pub fn all_roles() -> Vec<Role> {
let mut roles: HashMap<String, Role> = Role::list_builtin_roles()
.iter()
.map(|v| (v.name().to_string(), v.clone()))
.collect();
let names = Self::list_roles(false);
for name in names {
if let Ok(content) = read_to_string(Self::role_file(&name)) {
let role = Role::new(&name, &content);
roles.insert(name, role);
}
}
let mut roles: Vec<_> = roles.into_values().collect();
roles.sort_unstable_by(|a, b| a.name().cmp(b.name()));
roles
}
pub fn list_roles(with_builtin: bool) -> Vec<String> { pub fn list_roles(with_builtin: bool) -> Vec<String> {
let mut names = HashSet::new(); let mut names = HashSet::new();
if let Ok(rd) = read_dir(Self::roles_dir()) { if let Ok(rd) = read_dir(Self::roles_dir()) {
@@ -1921,7 +1889,6 @@ impl Config {
let prelude = match self.working_mode { let prelude = match self.working_mode {
WorkingMode::Repl => self.repl_prelude.as_ref(), WorkingMode::Repl => self.repl_prelude.as_ref(),
WorkingMode::Cmd => self.cmd_prelude.as_ref(), WorkingMode::Cmd => self.cmd_prelude.as_ref(),
WorkingMode::Serve => return Ok(()),
}; };
let prelude = match prelude { let prelude = match prelude {
Some(v) => { Some(v) => {
@@ -2835,10 +2802,6 @@ impl Config {
if let Some(v) = read_env_value::<String>(&get_env_name("right_prompt")) { if let Some(v) = read_env_value::<String>(&get_env_name("right_prompt")) {
self.right_prompt = v; self.right_prompt = v;
} }
if let Some(v) = read_env_value::<String>(&get_env_name("serve_addr")) {
self.serve_addr = v;
}
if let Some(v) = read_env_value::<String>(&get_env_name("user_agent")) { if let Some(v) = read_env_value::<String>(&get_env_name("user_agent")) {
self.user_agent = v; self.user_agent = v;
} }
@@ -2947,7 +2910,6 @@ pub fn load_env_file() -> Result<()> {
pub enum WorkingMode { pub enum WorkingMode {
Cmd, Cmd,
Repl, Repl,
Serve,
} }
impl WorkingMode { impl WorkingMode {
@@ -2957,9 +2919,6 @@ impl WorkingMode {
pub fn is_repl(&self) -> bool { pub fn is_repl(&self) -> bool {
*self == WorkingMode::Repl *self == WorkingMode::Repl
} }
pub fn is_serve(&self) -> bool {
*self == WorkingMode::Serve
}
} }
#[async_recursion::async_recursion] #[async_recursion::async_recursion]
-6
View File
@@ -111,12 +111,6 @@ impl Role {
.collect() .collect()
} }
pub fn list_builtin_roles() -> Vec<Self> {
RolesAsset::iter()
.filter_map(|v| Role::builtin(&v).ok())
.collect()
}
pub fn has_args(&self) -> bool { pub fn has_args(&self) -> bool {
self.name.contains('#') self.name.contains('#')
} }
+5 -18
View File
@@ -5,7 +5,6 @@ mod function;
mod rag; mod rag;
mod render; mod render;
mod repl; mod repl;
mod serve;
#[macro_use] #[macro_use]
mod utils; mod utils;
mod mcp; mod mcp;
@@ -58,9 +57,7 @@ async fn main() -> Result<()> {
} }
let text = cli.text()?; let text = cli.text()?;
let working_mode = if cli.serve.is_some() { let working_mode = if text.is_none() && cli.file.is_empty() {
WorkingMode::Serve
} else if text.is_none() && cli.file.is_empty() {
WorkingMode::Repl WorkingMode::Repl
} else { } else {
WorkingMode::Cmd WorkingMode::Cmd
@@ -80,7 +77,7 @@ async fn main() -> Result<()> {
|| cli.delete_secret.is_some() || cli.delete_secret.is_some()
|| cli.list_secrets; || cli.list_secrets;
let log_path = setup_logger(working_mode.is_serve())?; let log_path = setup_logger()?;
if vault_flags { if vault_flags {
return Vault::handle_vault_flags(cli, Config::init_bare()?); return Vault::handle_vault_flags(cli, Config::init_bare()?);
@@ -219,9 +216,6 @@ async fn run(
println!("{info}"); println!("{info}");
return Ok(()); return Ok(());
} }
if let Some(addr) = cli.serve {
return serve::run(config, addr).await;
}
let is_repl = config.read().working_mode.is_repl(); let is_repl = config.read().working_mode.is_repl();
if cli.rebuild_rag { if cli.rebuild_rag {
Config::rebuild_rag(&config, abort_signal.clone()).await?; Config::rebuild_rag(&config, abort_signal.clone()).await?;
@@ -429,22 +423,15 @@ async fn create_input(
Ok(input) Ok(input)
} }
fn setup_logger(is_serve: bool) -> Result<Option<PathBuf>> { fn setup_logger() -> Result<Option<PathBuf>> {
let (log_level, log_path) = Config::log_config(is_serve)?; let (log_level, log_path) = Config::log_config()?;
if log_level == LevelFilter::Off { if log_level == LevelFilter::Off {
return Ok(None); return Ok(None);
} }
let encoder = Box::new(PatternEncoder::new( let encoder = Box::new(PatternEncoder::new(
"{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}", "{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}",
)); ));
let log_filter = match env::var(get_env_name("log_filter")) { let log_filter = env::var(get_env_name("log_filter")).ok();
Ok(v) => Some(v),
Err(_) => match is_serve {
true => Some(format!("{}::serve", env!("CARGO_CRATE_NAME"))),
false => None,
},
};
match log_path.clone() { match log_path.clone() {
None => { None => {
let console_appender = ConsoleAppender::builder().encoder(encoder).build(); let console_appender = ConsoleAppender::builder().encoder(encoder).build();
+2 -2
View File
@@ -41,8 +41,8 @@ impl Prompt for ReplPrompt {
PromptHistorySearchStatus::Passing => "", PromptHistorySearchStatus::Passing => "",
PromptHistorySearchStatus::Failing => "failing ", PromptHistorySearchStatus::Failing => "failing ",
}; };
// NOTE: magic strings, given there is logic on how these compose I am not sure if it // NOTE: magic strings, given there is logic on how these are composed, I'm unsure if it's
// is worth extracting in to static constant // worth extracting into a static constant
Cow::Owned(format!( Cow::Owned(format!(
"({}reverse-search: {}) ", "({}reverse-search: {}) ",
prefix, history_search.term prefix, history_search.term
-935
View File
@@ -1,935 +0,0 @@
use crate::{client::*, config::*, function::*, rag::*, utils::*};
use anyhow::{anyhow, bail, Result};
use bytes::Bytes;
use chrono::{Timelike, Utc};
use futures_util::StreamExt;
use http::{Method, Response, StatusCode};
use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
use hyper::{
body::{Frame, Incoming},
service::service_fn,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use serde::Deserialize;
use serde_json::{json, Value};
use std::{
convert::Infallible,
net::IpAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tokio::{
net::TcpListener,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot,
},
};
use tokio_graceful::Shutdown;
use tokio_stream::wrappers::UnboundedReceiverStream;
/// Model id clients may send to select the server's configured default model.
const DEFAULT_MODEL_NAME: &str = "default";
/// Playground page, embedded at compile time and served by `playground_page`.
const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html");
/// Arena page, embedded at compile time and served by `arena_page`.
const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html");
/// Response type produced by every handler: a boxed body of `Bytes` whose error type is `Infallible`.
type AppResponse = Response<BoxBody<Bytes, Infallible>>;
/// Starts the HTTP server and waits for CTRL+C before shutting it down.
///
/// `addr` accepts three shapes:
/// * a bare port (`8000`) - bound on `127.0.0.1`;
/// * a bare IP address (`0.0.0.0`, `::1`) - bound on port `8000`;
/// * anything else is passed to the listener unchanged (e.g. `IP:PORT`).
///
/// When `addr` is `None` the address comes from `Config::serve_addr`.
///
/// # Errors
/// Returns an error if the listener cannot be bound or the server fails to start.
pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
    let addr = match addr {
        Some(addr) => {
            if let Ok(port) = addr.parse::<u16>() {
                format!("127.0.0.1:{port}")
            } else if let Ok(ip) = addr.parse::<IpAddr>() {
                // Format through `SocketAddr` so IPv6 addresses are bracketed
                // (`[::1]:8000`); a plain `{ip}:8000` would yield the
                // unparsable `::1:8000`. IPv4 output is unchanged.
                std::net::SocketAddr::new(ip, 8000).to_string()
            } else {
                addr
            }
        }
        None => config.read().serve_addr(),
    };
    let server = Arc::new(Server::new(&config));
    let listener = TcpListener::bind(&addr).await?;
    let stop_server = server.run(listener).await?;
    println!("Chat Completions API: http://{addr}/v1/chat/completions");
    println!("Embeddings API:       http://{addr}/v1/embeddings");
    println!("Rerank API:           http://{addr}/v1/rerank");
    println!("LLM Playground:       http://{addr}/playground");
    println!("LLM Arena:            http://{addr}/arena?num=2");
    shutdown_signal().await;
    // The accept loop may already be gone; a failed send is harmless.
    let _ = stop_server.send(());
    Ok(())
}
/// State shared by all request handlers; built once by `Server::new`.
struct Server {
/// Clone of the global config taken at startup, with `functions` reset to the default.
config: Config,
/// Pre-rendered JSON model entries returned by `list_models` (index 0 is the `default` alias).
models: Vec<Value>,
/// Roles collected via `Config::all_roles()` at startup.
roles: Vec<Role>,
/// RAG names collected via `Config::list_rags()` at startup.
rags: Vec<String>,
}
impl Server {
/// Builds the server state from a snapshot of the global config.
///
/// The snapshot's function declarations are reset, and the model list is
/// rendered to JSON up front. The first entry is a copy of the currently
/// configured model exposed under the `default` id.
fn new(config: &GlobalConfig) -> Self {
    let mut config = config.read().clone();
    config.functions = Functions::default();

    let mut default_model = config.model.clone();
    default_model.data_mut().name = DEFAULT_MODEL_NAME.into();

    let mut catalog = list_all_models(&config);
    catalog.insert(0, &default_model);

    let mut models: Vec<Value> = Vec::with_capacity(catalog.len());
    for (position, entry) in catalog.into_iter().enumerate() {
        // Only the synthetic first entry is published under the alias id.
        let id: String = match position {
            0 => DEFAULT_MODEL_NAME.into(),
            _ => entry.id(),
        };
        let mut rendered = json!(entry.data());
        if let Some(fields) = rendered.as_object_mut() {
            fields.insert("id".into(), id.into());
            fields.insert("object".into(), "model".into());
            fields.insert("owned_by".into(), entry.client_name().into());
            fields.remove("name");
        }
        models.push(rendered);
    }

    Self {
        config,
        models,
        roles: Config::all_roles(),
        rags: Config::list_rags(),
    }
}
/// Spawns the accept loop on a background task and returns the sender used
/// to request a graceful stop.
///
/// Each accepted connection is served on its own task registered with the
/// `Shutdown` instance. Failed `accept` calls are ignored and the loop
/// keeps listening.
async fn run(self: Arc<Self>, listener: TcpListener) -> Result<oneshot::Sender<()>> {
    let (stop_tx, stop_rx) = oneshot::channel();
    tokio::spawn(async move {
        let shutdown = Shutdown::new(async { stop_rx.await.unwrap_or_default() });
        let guard = shutdown.guard_weak();
        loop {
            tokio::select! {
                accepted = listener.accept() => {
                    let connection = match accepted {
                        Ok((connection, _)) => connection,
                        Err(_) => continue,
                    };
                    let io = TokioIo::new(connection);
                    let server = self.clone();
                    shutdown.spawn_task(async move {
                        let service = service_fn(move |request: hyper::Request<Incoming>| {
                            server.clone().handle(request)
                        });
                        let builder =
                            hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
                        // Connection-level errors are deliberately dropped.
                        let _ = builder.serve_connection_with_upgrades(io, service).await;
                    });
                }
                _ = guard.cancelled() => {
                    break;
                }
            }
        }
    });
    Ok(stop_tx)
}
/// Dispatches one HTTP request to the matching endpoint handler.
///
/// `OPTIONS` requests are answered immediately with `204 No Content`.
/// Unknown paths yield `404`; any handler error yields `400`. CORS headers
/// are attached to every response.
async fn handle(
    self: Arc<Self>,
    req: hyper::Request<Incoming>,
) -> std::result::Result<AppResponse, hyper::Error> {
    let method = req.method().clone();
    let uri = req.uri().clone();

    if method == Method::OPTIONS {
        let mut preflight = Response::default();
        *preflight.status_mut() = StatusCode::NO_CONTENT;
        set_cors_header(&mut preflight);
        return Ok(preflight);
    }

    let mut status = StatusCode::OK;
    let outcome = match uri.path() {
        "/v1/chat/completions" => self.chat_completions(req).await,
        "/v1/embeddings" => self.embeddings(req).await,
        "/v1/rerank" => self.rerank(req).await,
        "/v1/models" => self.list_models(),
        "/v1/roles" => self.list_roles(),
        "/v1/rags" => self.list_rags(),
        "/v1/rags/search" => self.search_rag(req).await,
        "/playground" | "/playground.html" => self.playground_page(),
        "/arena" | "/arena.html" => self.arena_page(),
        _ => {
            status = StatusCode::NOT_FOUND;
            Err(anyhow!("Not Found"))
        }
    };

    let mut response = match outcome {
        Ok(response) => {
            info!("{method} {uri} {}", status.as_u16());
            response
        }
        Err(err) => {
            // Anything that is not an explicit 404 is reported as a bad request.
            if status == StatusCode::OK {
                status = StatusCode::BAD_REQUEST;
            }
            error!("{method} {uri} {} {err}", status.as_u16());
            ret_err(err)
        }
    };
    *response.status_mut() = status;
    set_cors_header(&mut response);
    Ok(response)
}
/// Serves the embedded playground HTML page.
fn playground_page(&self) -> Result<AppResponse> {
    let body = Full::new(Bytes::from(PLAYGROUND_HTML)).boxed();
    let page = Response::builder()
        .header("Content-Type", "text/html; charset=utf-8")
        .body(body)?;
    Ok(page)
}
/// Serves the embedded arena HTML page.
fn arena_page(&self) -> Result<AppResponse> {
    let body = Full::new(Bytes::from(ARENA_HTML)).boxed();
    let page = Response::builder()
        .header("Content-Type", "text/html; charset=utf-8")
        .body(body)?;
    Ok(page)
}
/// Returns the pre-rendered model list as `{"data": [...]}`.
fn list_models(&self) -> Result<AppResponse> {
    let payload = json!({ "data": self.models }).to_string();
    let body = Full::new(Bytes::from(payload)).boxed();
    Ok(Response::builder()
        .header("Content-Type", "application/json; charset=utf-8")
        .body(body)?)
}
/// Returns the roles captured at startup as `{"data": [...]}`.
fn list_roles(&self) -> Result<AppResponse> {
    let payload = json!({ "data": self.roles }).to_string();
    let body = Full::new(Bytes::from(payload)).boxed();
    Ok(Response::builder()
        .header("Content-Type", "application/json; charset=utf-8")
        .body(body)?)
}
/// Returns the RAG names captured at startup as `{"data": [...]}`.
fn list_rags(&self) -> Result<AppResponse> {
    let payload = json!({ "data": self.rags }).to_string();
    let body = Full::new(Bytes::from(payload)).boxed();
    Ok(Response::builder()
        .header("Content-Type", "application/json; charset=utf-8")
        .body(body)?)
}
async fn search_rag(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
debug!("search rag request: {req_body}");
let SearchRagReqBody { name, input } = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
let config = Arc::new(RwLock::new(self.config.clone()));
let abort_signal = create_abort_signal();
let rag_path = config.read().rag_file(&name);
let rag = Rag::load(&config, &name, &rag_path)?;
let rag_result = Config::search_rag(&config, &rag, &input, abort_signal).await?;
let data = json!({ "data": rag_result });
let res = Response::builder()
.header("Content-Type", "application/json; charset=utf-8")
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
Ok(res)
}
/// Handles a chat-completions request, either as a single JSON response or
/// as a `text/event-stream` of chunks when the body sets `stream: true`.
///
/// The body is deserialized into `ChatCompletionsReqBody`. The model name
/// `default` resolves to the configured model; any other name different
/// from the configured model's id is applied with `set_model` on a
/// per-request copy of the config.
///
/// Errors are returned for an unparsable body, invalid messages/tools, an
/// unknown model, client construction failure, or an upstream failure
/// reported before the first streamed event.
async fn chat_completions(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
// Two-step parse: raw JSON first (so it can be logged), then the typed body.
let req_body = req.collect().await?.to_bytes();
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
debug!("chat completions request: {req_body}");
let req_body = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
let ChatCompletionsReqBody {
model,
messages,
temperature,
top_p,
max_tokens,
stream,
tools,
} = req_body;
let mut messages =
parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;
let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;
// Work on a private copy of the config so model changes stay local to this request.
let config = self.config.clone();
let default_model = config.model.clone();
let config = Arc::new(RwLock::new(config));
// `change` is false only when the requested name already equals the configured model id.
let (model_name, change) = if model == DEFAULT_MODEL_NAME {
(default_model.id(), true)
} else if default_model.id() == model {
(model, false)
} else {
(model, true)
};
if change {
config.write().set_model(&model_name)?;
}
let mut client = init_client(&config, None)?;
// A request-level `max_tokens` overrides the model's own setting.
if max_tokens.is_some() {
client.model_mut().set_max_tokens(max_tokens, true);
}
let abort_signal = create_abort_signal();
let http_client = client.build_client()?;
// Fallback id/timestamp used for every chunk of this completion.
let completion_id = generate_completion_id();
let created = Utc::now().timestamp();
patch_messages(&mut messages, client.model());
let data: ChatCompletionsData = ChatCompletionsData {
messages,
temperature,
top_p,
functions,
stream,
};
if stream {
// `tx`/`rx` carry `ResEvent`s from the background task to the response body stream.
let (tx, mut rx) = unbounded_channel();
tokio::spawn(async move {
// Shared flag ensuring exactly one `ResEvent::First` is emitted by whichever side gets there first.
let is_first = Arc::new(AtomicBool::new(true));
let (sse_tx, sse_rx) = unbounded_channel();
let mut handler = SseHandler::new(sse_tx, abort_signal);
// Translates `SseEvent`s from the handler into `ResEvent`s, emitting
// `First(None)` before the first forwarded event.
async fn map_event(
mut sse_rx: UnboundedReceiver<SseEvent>,
tx: &UnboundedSender<ResEvent>,
is_first: Arc<AtomicBool>,
) {
while let Some(reply_event) = sse_rx.recv().await {
if is_first.load(Ordering::SeqCst) {
let _ = tx.send(ResEvent::First(None));
is_first.store(false, Ordering::SeqCst)
}
match reply_event {
SseEvent::Text(text) => {
let _ = tx.send(ResEvent::Text(text));
}
SseEvent::Done => {
let _ = tx.send(ResEvent::Done);
// Closing the receiver lets the `while let` loop terminate.
sse_rx.close();
}
}
}
}
// Performs the upstream call. Models flagged `no_stream` are called
// non-streaming and their whole output is forwarded as one text event.
async fn chat_completions(
client: &dyn Client,
http_client: &reqwest::Client,
handler: &mut SseHandler,
mut data: ChatCompletionsData,
tx: &UnboundedSender<ResEvent>,
is_first: Arc<AtomicBool>,
) {
if client.model().no_stream() {
data.stream = false;
let ret = client.chat_completions_inner(http_client, data).await;
match ret {
Ok(output) => {
let ChatCompletionsOutput {
text, tool_calls, ..
} = output;
let _ = tx.send(ResEvent::First(None));
is_first.store(false, Ordering::SeqCst);
let _ = tx.send(ResEvent::Text(text));
if !tool_calls.is_empty() {
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
}
}
Err(err) => {
// The error travels in the `First` event so the caller can fail the request.
let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
is_first.store(false, Ordering::SeqCst)
}
};
} else {
let ret = client
.chat_completions_streaming_inner(http_client, handler, data)
.await;
let first = match ret {
Ok(()) => None,
Err(err) => Some(format!("{err:?}")),
};
// Only reported if `map_event` has not already emitted `First`.
if is_first.load(Ordering::SeqCst) {
let _ = tx.send(ResEvent::First(first));
is_first.store(false, Ordering::SeqCst)
}
let tool_calls = handler.tool_calls().to_vec();
if !tool_calls.is_empty() {
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
}
}
handler.done();
}
// Run the event translator and the upstream call concurrently on this task.
tokio::join!(
map_event(sse_rx, &tx, is_first.clone()),
chat_completions(
client.as_ref(),
&http_client,
&mut handler,
data,
&tx,
is_first
),
);
});
// Wait for the first event before building the response, so an early
// upstream error becomes a request error instead of a broken stream.
let first_event = rx.recv().await;
if let Some(ResEvent::First(Some(err))) = first_event {
bail!("{err}");
}
// Tuple of (completion id, model name, created timestamp, "tool calls were sent" flag).
let shared: Arc<(String, String, i64, AtomicBool)> =
Arc::new((completion_id, model_name, created, AtomicBool::new(false)));
let stream = UnboundedReceiverStream::new(rx);
let stream = stream.filter_map(move |res_event| {
let shared = shared.clone();
async move {
let (completion_id, model, created, has_tool_calls) = shared.as_ref();
match res_event {
ResEvent::Text(text) => {
Some(Ok(create_text_frame(completion_id, model, *created, &text)))
}
ResEvent::ToolCalls(tool_calls) => {
// Remembered so the final frame can report the matching finish reason.
has_tool_calls.store(true, Ordering::SeqCst);
Some(Ok(create_tool_calls_frame(
completion_id,
model,
*created,
&tool_calls,
)))
}
ResEvent::Done => Some(Ok(create_done_frame(
completion_id,
model,
*created,
has_tool_calls.load(Ordering::SeqCst),
))),
// Any remaining `First` events produce no frame.
_ => None,
}
}
});
let res = Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(BodyExt::boxed(StreamBody::new(stream)))?;
Ok(res)
} else {
// Non-streaming: one upstream call, one JSON document.
let output = client.chat_completions_inner(&http_client, data).await?;
let res = Response::builder()
.header("Content-Type", "application/json")
.body(
Full::new(ret_non_stream(
&completion_id,
&model_name,
created,
&output,
))
.boxed(),
)?;
Ok(res)
}
}
async fn embeddings(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
debug!("embeddings request: {req_body}");
let req_body = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
let EmbeddingsReqBody {
input,
model: embedding_model_id,
} = req_body;
let config = Arc::new(RwLock::new(self.config.clone()));
let embedding_model =
Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;
let texts = match input {
EmbeddingsReqBodyInput::Single(v) => vec![v],
EmbeddingsReqBodyInput::Multiple(v) => v,
};
let client = init_client(&config, Some(embedding_model))?;
let data = client
.embeddings(&EmbeddingsData {
query: false,
texts,
})
.await?;
let data: Vec<_> = data
.into_iter()
.enumerate()
.map(|(i, v)| {
json!({
"object": "embedding",
"embedding": v,
"index": i,
})
})
.collect();
let output = json!({
"object": "list",
"data": data,
"model": embedding_model_id,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0,
}
});
let res = Response::builder()
.header("Content-Type", "application/json")
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
Ok(res)
}
async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
debug!("rerank request: {req_body}");
let req_body = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
let RerankReqBody {
model: reranker_model_id,
documents,
query,
top_n,
} = req_body;
let top_n = top_n.unwrap_or(documents.len());
let config = Arc::new(RwLock::new(self.config.clone()));
let reranker_model =
Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?;
let client = init_client(&config, Some(reranker_model))?;
let data = client
.rerank(&RerankData {
query,
documents: documents.clone(),
top_n,
})
.await?;
let results: Vec<_> = data
.into_iter()
.map(|v| {
json!({
"index": v.index,
"relevance_score": v.relevance_score,
"document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
})
})
.collect();
let output = json!({
"id": uuid::Uuid::new_v4().to_string(),
"results": results,
});
let res = Response::builder()
.header("Content-Type", "application/json")
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
Ok(res)
}
}
/// JSON body deserialized by `search_rag`.
#[derive(Debug, Deserialize)]
struct SearchRagReqBody {
/// Name of the RAG to load.
name: String,
/// Text to search for.
input: String,
}
/// JSON body deserialized by the `chat_completions` handler.
#[derive(Debug, Deserialize)]
struct ChatCompletionsReqBody {
/// Requested model id; the literal `default` selects the configured model.
model: String,
/// Raw message objects, validated later by `parse_messages`.
messages: Vec<Value>,
temperature: Option<f64>,
top_p: Option<f64>,
/// When present, overrides the model's max-tokens setting for this request.
max_tokens: Option<isize>,
/// Whether to answer as an event stream; defaults to `false` when omitted.
#[serde(default)]
stream: bool,
/// Raw tool declarations, validated later by `parse_tools`.
tools: Option<Vec<Value>>,
}
/// JSON body deserialized by the `embeddings` handler.
#[derive(Debug, Deserialize)]
struct EmbeddingsReqBody {
/// One text or a list of texts to embed.
input: EmbeddingsReqBodyInput,
/// Id of the embedding model to use.
model: String,
}
/// Embedding input, accepted either as a bare string or as an array of
/// strings (untagged: the JSON shape selects the variant).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EmbeddingsReqBodyInput {
/// A single text.
Single(String),
/// Several texts embedded in one call.
Multiple(Vec<String>),
}
/// JSON body deserialized by the `rerank` handler.
#[derive(Debug, Deserialize)]
struct RerankReqBody {
/// Candidate documents to rank.
documents: Vec<String>,
/// Query the documents are ranked against.
query: String,
/// Id of the reranker model to use.
model: String,
/// Maximum number of results; the handler falls back to `documents.len()` when absent.
top_n: Option<usize>,
}
/// Events sent from the background completion task to the streaming response.
#[derive(Debug)]
enum ResEvent {
/// Start-up result: `None` on success, `Some(message)` when the upstream call failed.
First(Option<String>),
/// A piece of generated text.
Text(String),
/// Tool calls produced by the model.
ToolCalls(Vec<ToolCall>),
/// End of the completion.
Done,
}
/// Resolves once CTRL+C is received.
///
/// # Panics
/// Panics if the CTRL+C handler cannot be installed.
async fn shutdown_signal() {
    let received = tokio::signal::ctrl_c().await;
    received.expect("Failed to install CTRL+C signal handler")
}
/// Produces a fallback completion id of the form `chatcmpl-<n>`.
///
/// NOTE(review): `<n>` is the sub-second nanosecond component of the current
/// UTC time, not a random value, so ids are not guaranteed to be unique.
fn generate_completion_id() -> String {
    let now = Utc::now();
    let random_id = now.nanosecond();
    format!("chatcmpl-{random_id}")
}
/// Adds permissive CORS headers (any origin) to `res`, replacing existing values.
fn set_cors_header(res: &mut AppResponse) {
    let cors_headers = [
        (hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
        (
            hyper::header::ACCESS_CONTROL_ALLOW_METHODS,
            "GET,POST,PUT,PATCH,DELETE",
        ),
        (
            hyper::header::ACCESS_CONTROL_ALLOW_HEADERS,
            "Content-Type,Authorization",
        ),
    ];
    let headers = res.headers_mut();
    for (name, value) in cors_headers {
        headers.insert(name, hyper::header::HeaderValue::from_static(value));
    }
}
/// Builds one SSE `data:` frame carrying a text delta.
///
/// An empty `content` produces a delta that also includes
/// `"role": "assistant"`; non-empty content produces a content-only delta.
fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame<Bytes> {
    let delta = match content {
        "" => json!({ "role": "assistant", "content": content }),
        _ => json!({ "content": content }),
    };
    let choice = json!({
        "index": 0,
        "delta": delta,
        "finish_reason": null,
    });
    let value = build_chat_completion_chunk_json(id, model, created, &choice);
    let line = format!("data: {value}\n\n");
    Frame::data(Bytes::from(line))
}
/// Serializes `tool_calls` into a single frame holding two
/// `data: <json>\n\n` chunks per call, in call order.
///
/// For the call at position `i`, the first chunk announces the call (id, type
/// `function`, function name, empty arguments) and the second chunk carries
/// the stringified arguments under the same tool-call index. An empty slice
/// yields a frame with an empty payload.
fn create_tool_calls_frame(
    id: &str,
    model: &str,
    created: i64,
    tool_calls: &[ToolCall],
) -> Frame<Bytes> {
    let mut payload = String::new();
    for (i, call) in tool_calls.iter().enumerate() {
        // Chunk 1: introduces the call and names the function.
        let header_choice = json!({
            "index": 0,
            "delta": {
                "role": "assistant",
                "content": null,
                "tool_calls": [
                    {
                        "index": i,
                        "id": call.id,
                        "type": "function",
                        "function": {
                            "name": call.name,
                            "arguments": ""
                        }
                    }
                ]
            },
            "finish_reason": null
        });
        // Chunk 2: delivers the arguments as a JSON-encoded string.
        let arguments_choice = json!({
            "index": 0,
            "delta": {
                "tool_calls": [
                    {
                        "index": i,
                        "function": {
                            "arguments": call.arguments.to_string(),
                        }
                    }
                ]
            },
            "finish_reason": null
        });
        for choice in [&header_choice, &arguments_choice] {
            let v = build_chat_completion_chunk_json(id, model, created, choice);
            payload.push_str(&format!("data: {v}\n\n"));
        }
    }
    Frame::data(Bytes::from(payload))
}
/// Builds the closing frame: a final chunk with an empty delta, followed by
/// the literal `data: [DONE]` line.
///
/// `finish_reason` is `"tool_calls"` when `has_tool_calls` is set and `"stop"`
/// otherwise.
fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame<Bytes> {
    let finish_reason = if !has_tool_calls { "stop" } else { "tool_calls" };
    let value = build_chat_completion_chunk_json(
        id,
        model,
        created,
        &json!({
            "index": 0,
            "delta": {},
            "finish_reason": finish_reason,
        }),
    );
    let payload = format!("data: {value}\n\ndata: [DONE]\n\n");
    Frame::data(Bytes::from(payload))
}
/// Wraps one `choice` in a `chat.completion.chunk` JSON envelope.
///
/// The envelope carries `id`, `created` and `model` unchanged, and `choices`
/// is always a one-element array holding `choice`.
fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value {
    let choices = [choice];
    json!({
        "id": id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": choices,
    })
}
/// Serializes a finished completion into a `chat.completion` JSON body.
///
/// * The id reported by `output` wins over the `id` argument when present.
/// * Missing token counts are treated as their default value, and
///   `total_tokens` is the sum of the prompt and completion counts.
/// * Without tool calls the single choice carries the text and finishes with
///   `"stop"`. With tool calls it lists them (arguments stringified), uses
///   `null` content when the text is empty, and finishes with `"tool_calls"`.
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
    let id = output.id.as_deref().unwrap_or(id);
    let prompt_tokens = output.input_tokens.unwrap_or_default();
    let completion_tokens = output.output_tokens.unwrap_or_default();
    let total_tokens = prompt_tokens + completion_tokens;

    // Only the `message` object and the finish reason differ between the two
    // response shapes, so the surrounding choice is built once below.
    let (message, finish_reason) = if output.tool_calls.is_empty() {
        let message = json!({
            "role": "assistant",
            "content": output.text,
        });
        (message, "stop")
    } else {
        let content = match output.text.as_str() {
            "" => Value::Null,
            text => Value::from(text),
        };
        let mut tool_calls = Vec::with_capacity(output.tool_calls.len());
        for call in &output.tool_calls {
            tool_calls.push(json!({
                "id": call.id,
                "type": "function",
                "function": {
                    "name": call.name,
                    "arguments": call.arguments.to_string(),
                }
            }));
        }
        let message = json!({
            "role": "assistant",
            "content": content,
            "tool_calls": tool_calls,
        });
        (message, "tool_calls")
    };

    let choice = json!({
        "index": 0,
        "message": message,
        "logprobs": null,
        "finish_reason": finish_reason,
    });
    let body = json!({
        "id": id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [choice],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
    });
    Bytes::from(body.to_string())
}
/// Turns any displayable error into a JSON error response.
///
/// The body is `{"error": {"message": <err>, "type": "invalid_request_error"}}`
/// with a JSON content type. No status code is set here, so the response keeps
/// the builder's default status.
///
/// # Panics
///
/// Panics if the response builder rejects the header or the body.
fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
    let message = err.to_string();
    let payload = json!({
        "error": {
            "message": message,
            "type": "invalid_request_error",
        },
    })
    .to_string();
    let body = Full::new(Bytes::from(payload)).boxed();
    Response::builder()
        .header("Content-Type", "application/json")
        .body(body)
        .unwrap()
}
/// Converts raw JSON chat messages into `Message` values.
///
/// Every entry must have a string `role`. `content` may be a string, an array
/// (deserialized into the payload of `MessageContent::Array`), `null`, or
/// absent; the last two become empty text. Any other content type is an error.
///
/// Handling per role:
/// * `system` / `user`: pushed to the output as-is.
/// * `assistant` without a `tool_calls` array: pushed as-is.
/// * `assistant` with `tool_calls`: opens a pending batch. Each call needs
///   string `function.name` and `function.arguments` (the `id` is optional),
///   and the arguments must parse as JSON. Nothing is pushed yet.
/// * `tool`: adds one result to the pending batch. Once there are as many
///   results as calls, they are paired with the calls by position and a single
///   assistant message holding `MessageContent::ToolCalls` is pushed.
///
/// # Errors
///
/// Returns "Failed to parse '.messages[i]'" for the offending index when:
/// the role is missing or unknown, the content has an unsupported type, a
/// tool call is malformed, a new tool-call batch starts while one is pending,
/// a `tool` message arrives with no pending batch, or a result's
/// `tool_call_id` differs from the id of the call at the same position.
/// Returns "Invalid messages" when a batch is still pending at the end.
fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
let mut output = vec![];
// Pending tool-call batch: (assistant text, declared calls, results so far).
let mut tool_results = None;
for (i, message) in message.into_iter().enumerate() {
let err = || anyhow!("Failed to parse '.messages[{i}]'");
let role = message["role"].as_str().ok_or_else(err)?;
let content = match message.get("content") {
Some(value) => {
if let Some(value) = value.as_str() {
MessageContent::Text(value.to_string())
} else if value.is_array() {
let value = serde_json::from_value(value.clone()).map_err(|_| err())?;
MessageContent::Array(value)
} else if value.is_null() {
// Explicit `null` content is treated like empty text.
MessageContent::Text(String::new())
} else {
return Err(err());
}
}
// A missing `content` key is treated like empty text as well.
None => MessageContent::Text(String::new()),
};
match role {
// NOTE(review): this arm does not look at `tool_results`, so a system/user
// message that arrives while a batch is pending is pushed ahead of that
// batch's assistant message.
"system" | "user" => {
let role = match role {
"system" => MessageRole::System,
"user" => MessageRole::User,
_ => unreachable!(),
};
output.push(Message::new(role, content))
}
"assistant" => {
let role = MessageRole::Assistant;
match message["tool_calls"].as_array() {
Some(tool_calls) => {
// Only one batch may be pending at a time.
if tool_results.is_some() {
return Err(err());
}
let mut list = vec![];
for tool_call in tool_calls {
// `id` stays an `Option`; name and arguments must both be strings.
if let (id, Some(name), Some(arguments)) = (
tool_call["id"].as_str().map(|v| v.to_string()),
tool_call["function"]["name"].as_str(),
tool_call["function"]["arguments"].as_str(),
) {
// Arguments arrive as a JSON-encoded string and are decoded here.
let arguments =
serde_json::from_str(arguments).map_err(|_| err())?;
list.push((id, name.to_string(), arguments));
} else {
return Err(err());
}
}
// Defer the output message until every call has its result.
tool_results = Some((content.to_text(), list, vec![]));
}
None => output.push(Message::new(role, content)),
}
}
"tool" => match tool_results.take() {
Some((text, tool_calls, mut tool_values)) => {
let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string());
let content = content.to_text();
// Use the content as JSON when it parses; otherwise keep it as a JSON string.
let value: Value = serde_json::from_str(&content)
.ok()
.unwrap_or_else(|| content.into());
tool_values.push((value, tool_call_id));
if tool_calls.len() == tool_values.len() {
let mut list = vec![];
// Calls and results are paired by position, so results must arrive in
// the same order as the calls were declared.
for ((id, name, arguments), (value, tool_call_id)) in
tool_calls.into_iter().zip(tool_values.into_iter())
{
// Both sides are `Option<String>`: two absent ids compare equal.
if id != tool_call_id {
return Err(err());
}
list.push(ToolResult::new(ToolCall::new(name, arguments, id), value))
}
output.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)),
));
tool_results = None;
} else {
// Still waiting for more results: put the batch back.
tool_results = Some((text, tool_calls, tool_values));
}
}
// A tool result with no pending batch has nothing to attach to.
None => return Err(err()),
},
_ => {
return Err(err());
}
}
}
// A batch that never received all of its results makes the input invalid.
if tool_results.is_some() {
bail!("Invalid messages");
}
Ok(output)
}
/// Converts the optional raw `tools` array into function declarations.
///
/// Returns `Ok(None)` when no tools were supplied. Otherwise every entry must
/// have `"type": "function"` and a `function` object that deserializes into a
/// `FunctionDeclaration`.
///
/// # Errors
///
/// Fails with "Failed to parse '.tools[i]'" for the first entry that does not
/// meet those requirements.
fn parse_tools(tools: Option<Vec<Value>>) -> Result<Option<Vec<FunctionDeclaration>>> {
    let tools = match tools {
        None => return Ok(None),
        Some(list) => list,
    };
    let mut declarations: Vec<FunctionDeclaration> = Vec::with_capacity(tools.len());
    for (i, tool) in tools.iter().enumerate() {
        let is_function = tool["type"].as_str() == Some("function");
        let parsed = tool["function"]
            .as_object()
            .and_then(|spec| serde_json::from_value(json!(spec)).ok());
        match parsed {
            Some(declaration) if is_function => declarations.push(declaration),
            _ => bail!("Failed to parse '.tools[{i}]'"),
        }
    }
    Ok(Some(declarations))
}
+1 -1
View File
@@ -2,7 +2,7 @@ use std::collections::HashMap;
/// Render REPL prompt /// Render REPL prompt
/// ///
/// The template comprises plain text and `{...}`. /// The template comprises of plain text and `{...}`.
/// ///
/// The syntax of `{...}`: /// The syntax of `{...}`:
/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template` /// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`