feat: Removed the server functionality
This commit is contained in:
+1
-2
@@ -16,7 +16,7 @@ vault_password_file: null # Path to a file containing the password for th
|
|||||||
function_calling: true # Enables or disables function calling (Globally).
|
function_calling: true # Enables or disables function calling (Globally).
|
||||||
mapping_tools: # Alias for a tool or toolset
|
mapping_tools: # Alias for a tool or toolset
|
||||||
fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
|
fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
|
||||||
use_tools: null # Which tools to use by default. (e.g. 'fs,web_search')
|
use_tools: null # Which tools to use by default. (e.g. 'fs,web_search_loki')
|
||||||
|
|
||||||
# ---- MCP Servers ----
|
# ---- MCP Servers ----
|
||||||
mcp_servers: true # Enables or disables MCP servers (globally).
|
mcp_servers: true # Enables or disables MCP servers (globally).
|
||||||
@@ -85,7 +85,6 @@ right_prompt:
|
|||||||
'{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
|
'{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
|
||||||
|
|
||||||
# ---- misc ----
|
# ---- misc ----
|
||||||
serve_addr: 127.0.0.1:8000 # Server listening address
|
|
||||||
user_agent: null # Set User-Agent HTTP header, use `auto` for loki/<current-version>
|
user_agent: null # Set User-Agent HTTP header, use `auto` for loki/<current-version>
|
||||||
save_shell_history: true # Whether to save shell execution command to the history file
|
save_shell_history: true # Whether to save shell execution command to the history file
|
||||||
# URL to sync model changes from
|
# URL to sync model changes from
|
||||||
|
|||||||
@@ -61,9 +61,6 @@ pub struct Cli {
|
|||||||
/// Execute a macro
|
/// Execute a macro
|
||||||
#[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))]
|
#[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))]
|
||||||
pub macro_name: Option<String>,
|
pub macro_name: Option<String>,
|
||||||
/// Serve the LLM API and WebAPP
|
|
||||||
#[arg(long, value_name = "PORT|IP|IP:PORT")]
|
|
||||||
pub serve: Option<Option<String>>,
|
|
||||||
/// Execute commands in natural language
|
/// Execute commands in natural language
|
||||||
#[arg(short = 'e', long)]
|
#[arg(short = 'e', long)]
|
||||||
pub execute: bool,
|
pub execute: bool,
|
||||||
|
|||||||
@@ -518,13 +518,7 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|||||||
bail!("Invalid response data: {data}");
|
bail!("Invalid response data: {data}");
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = ChatCompletionsOutput {
|
let output = ChatCompletionsOutput { text, tool_calls };
|
||||||
text,
|
|
||||||
tool_calls,
|
|
||||||
id: None,
|
|
||||||
input_tokens: data["usage"]["inputTokens"].as_u64(),
|
|
||||||
output_tokens: data["usage"]["outputTokens"].as_u64(),
|
|
||||||
};
|
|
||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ use super::*;
|
|||||||
|
|
||||||
use crate::utils::strip_think_tag;
|
use crate::utils::strip_think_tag;
|
||||||
|
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{Context, Result, bail};
|
||||||
use reqwest::RequestBuilder;
|
use reqwest::RequestBuilder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
const API_BASE: &str = "https://api.anthropic.com/v1";
|
const API_BASE: &str = "https://api.anthropic.com/v1";
|
||||||
|
|
||||||
@@ -301,7 +301,7 @@ pub fn claude_build_chat_completions_body(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
}
|
}
|
||||||
Ok(body)
|
Ok(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,9 +353,6 @@ pub fn claude_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
|
|||||||
let output = ChatCompletionsOutput {
|
let output = ChatCompletionsOutput {
|
||||||
text: text.to_string(),
|
text: text.to_string(),
|
||||||
tool_calls,
|
tool_calls,
|
||||||
id: data["id"].as_str().map(|v| v.to_string()),
|
|
||||||
input_tokens: data["usage"]["input_tokens"].as_u64(),
|
|
||||||
output_tokens: data["usage"]["output_tokens"].as_u64(),
|
|
||||||
};
|
};
|
||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -244,12 +244,6 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|||||||
if text.is_empty() && tool_calls.is_empty() {
|
if text.is_empty() && tool_calls.is_empty() {
|
||||||
bail!("Invalid response data: {data}");
|
bail!("Invalid response data: {data}");
|
||||||
}
|
}
|
||||||
let output = ChatCompletionsOutput {
|
let output = ChatCompletionsOutput { text, tool_calls };
|
||||||
text,
|
|
||||||
tool_calls,
|
|
||||||
id: data["id"].as_str().map(|v| v.to_string()),
|
|
||||||
input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(),
|
|
||||||
output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(),
|
|
||||||
};
|
|
||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ use std::sync::LazyLock;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::sync::mpsc::unbounded_channel;
|
use tokio::sync::mpsc::unbounded_channel;
|
||||||
|
|
||||||
const MODELS_YAML: &str = include_str!("../../models.yaml");
|
pub const MODELS_YAML: &str = include_str!("../../models.yaml");
|
||||||
|
|
||||||
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
|
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
|
||||||
Config::local_models_override()
|
Config::local_models_override()
|
||||||
@@ -47,8 +47,6 @@ pub trait Client: Sync + Send {
|
|||||||
|
|
||||||
fn model(&self) -> &Model;
|
fn model(&self) -> &Model;
|
||||||
|
|
||||||
fn model_mut(&mut self) -> &mut Model;
|
|
||||||
|
|
||||||
fn build_client(&self) -> Result<ReqwestClient> {
|
fn build_client(&self) -> Result<ReqwestClient> {
|
||||||
let mut builder = ReqwestClient::builder();
|
let mut builder = ReqwestClient::builder();
|
||||||
let extra = self.extra_config();
|
let extra = self.extra_config();
|
||||||
@@ -291,9 +289,6 @@ pub struct ChatCompletionsData {
|
|||||||
pub struct ChatCompletionsOutput {
|
pub struct ChatCompletionsOutput {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
pub tool_calls: Vec<ToolCall>,
|
pub tool_calls: Vec<ToolCall>,
|
||||||
pub id: Option<String>,
|
|
||||||
pub input_tokens: Option<u64>,
|
|
||||||
pub output_tokens: Option<u64>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatCompletionsOutput {
|
impl ChatCompletionsOutput {
|
||||||
@@ -341,7 +336,6 @@ pub type RerankOutput = Vec<RerankResult>;
|
|||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct RerankResult {
|
pub struct RerankResult {
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
pub relevance_score: f64,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool);
|
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>, bool);
|
||||||
|
|||||||
@@ -159,10 +159,6 @@ macro_rules! client_common_fns {
|
|||||||
fn model(&self) -> &Model {
|
fn model(&self) -> &Model {
|
||||||
&self.model
|
&self.model
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutable access to the model held by this client.
fn model_mut(&mut self) -> &mut Model {
    &mut self.model
}
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -118,14 +118,6 @@ impl Model {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Borrows the underlying [`ModelData`] of this model.
pub fn data(&self) -> &ModelData {
    &self.data
}
|
|
||||||
|
|
||||||
/// Mutably borrows the underlying [`ModelData`] of this model.
pub fn data_mut(&mut self) -> &mut ModelData {
    &mut self.data
}
|
|
||||||
|
|
||||||
pub fn description(&self) -> String {
|
pub fn description(&self) -> String {
|
||||||
match self.model_type() {
|
match self.model_type() {
|
||||||
ModelType::Chat => {
|
ModelType::Chat => {
|
||||||
|
|||||||
@@ -389,13 +389,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
|
|||||||
} else {
|
} else {
|
||||||
text.to_string()
|
text.to_string()
|
||||||
};
|
};
|
||||||
let output = ChatCompletionsOutput {
|
let output = ChatCompletionsOutput { text, tool_calls };
|
||||||
text,
|
|
||||||
tool_calls,
|
|
||||||
id: data["id"].as_str().map(|v| v.to_string()),
|
|
||||||
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
|
|
||||||
output_tokens: data["usage"]["completion_tokens"].as_u64(),
|
|
||||||
};
|
|
||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ impl SseHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
|
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
|
||||||
// debug!("HandleCall: {:?}", call);
|
|
||||||
self.tool_calls.push(call);
|
self.tool_calls.push(call);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -65,10 +64,6 @@ impl SseHandler {
|
|||||||
self.abort_signal.clone()
|
self.abort_signal.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the tool calls recorded so far, in the order they were pushed.
pub fn tool_calls(&self) -> &[ToolCall] {
    self.tool_calls.as_slice()
}
|
|
||||||
|
|
||||||
pub fn take(self) -> (String, Vec<ToolCall>) {
|
pub fn take(self) -> (String, Vec<ToolCall>) {
|
||||||
let Self {
|
let Self {
|
||||||
buffer, tool_calls, ..
|
buffer, tool_calls, ..
|
||||||
|
|||||||
@@ -296,13 +296,7 @@ fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsO
|
|||||||
bail!("Invalid response data: {data}");
|
bail!("Invalid response data: {data}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let output = ChatCompletionsOutput {
|
let output = ChatCompletionsOutput { text, tool_calls };
|
||||||
text,
|
|
||||||
tool_calls,
|
|
||||||
id: None,
|
|
||||||
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
|
|
||||||
output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
|
|
||||||
};
|
|
||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+4
-45
@@ -78,8 +78,6 @@ const MCP_FILE_NAME: &str = "mcp.json";
|
|||||||
|
|
||||||
const CLIENTS_FIELD: &str = "clients";
|
const CLIENTS_FIELD: &str = "clients";
|
||||||
|
|
||||||
const SERVE_ADDR: &str = "127.0.0.1:8000";
|
|
||||||
|
|
||||||
const SYNC_MODELS_URL: &str =
|
const SYNC_MODELS_URL: &str =
|
||||||
"https://raw.githubusercontent.com/Dark-Alex-17/loki/refs/heads/main/models.yaml";
|
"https://raw.githubusercontent.com/Dark-Alex-17/loki/refs/heads/main/models.yaml";
|
||||||
|
|
||||||
@@ -578,30 +576,17 @@ impl Config {
|
|||||||
flags
|
flags
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn serve_addr(&self) -> String {
|
pub fn log_config() -> Result<(LevelFilter, Option<PathBuf>)> {
|
||||||
self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option<PathBuf>)> {
|
|
||||||
let log_level = env::var(get_env_name("log_level"))
|
let log_level = env::var(get_env_name("log_level"))
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|v| v.parse().ok())
|
.and_then(|v| v.parse().ok())
|
||||||
.unwrap_or(match cfg!(debug_assertions) {
|
.unwrap_or(match cfg!(debug_assertions) {
|
||||||
true => LevelFilter::Debug,
|
true => LevelFilter::Debug,
|
||||||
false => {
|
false => LevelFilter::Info,
|
||||||
if is_serve {
|
|
||||||
LevelFilter::Off
|
|
||||||
} else {
|
|
||||||
LevelFilter::Info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
let log_path = match env::var(get_env_name("log_path")) {
|
let log_path = match env::var(get_env_name("log_path")) {
|
||||||
Ok(v) => Some(PathBuf::from(v)),
|
Ok(v) => Some(PathBuf::from(v)),
|
||||||
Err(_) => match is_serve {
|
Err(_) => Some(Config::log_path()),
|
||||||
true => None,
|
|
||||||
false => Some(Config::log_path()),
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
Ok((log_level, log_path))
|
Ok((log_level, log_path))
|
||||||
}
|
}
|
||||||
@@ -744,7 +729,7 @@ impl Config {
|
|||||||
display_path(&self.vault_password_file()),
|
display_path(&self.vault_password_file()),
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) {
|
if let Ok((_, Some(log_path))) = Self::log_config() {
|
||||||
items.push(("log_path", display_path(&log_path)));
|
items.push(("log_path", display_path(&log_path)));
|
||||||
}
|
}
|
||||||
let output = items
|
let output = items
|
||||||
@@ -1279,23 +1264,6 @@ impl Config {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn all_roles() -> Vec<Role> {
|
|
||||||
let mut roles: HashMap<String, Role> = Role::list_builtin_roles()
|
|
||||||
.iter()
|
|
||||||
.map(|v| (v.name().to_string(), v.clone()))
|
|
||||||
.collect();
|
|
||||||
let names = Self::list_roles(false);
|
|
||||||
for name in names {
|
|
||||||
if let Ok(content) = read_to_string(Self::role_file(&name)) {
|
|
||||||
let role = Role::new(&name, &content);
|
|
||||||
roles.insert(name, role);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mut roles: Vec<_> = roles.into_values().collect();
|
|
||||||
roles.sort_unstable_by(|a, b| a.name().cmp(b.name()));
|
|
||||||
roles
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn list_roles(with_builtin: bool) -> Vec<String> {
|
pub fn list_roles(with_builtin: bool) -> Vec<String> {
|
||||||
let mut names = HashSet::new();
|
let mut names = HashSet::new();
|
||||||
if let Ok(rd) = read_dir(Self::roles_dir()) {
|
if let Ok(rd) = read_dir(Self::roles_dir()) {
|
||||||
@@ -1921,7 +1889,6 @@ impl Config {
|
|||||||
let prelude = match self.working_mode {
|
let prelude = match self.working_mode {
|
||||||
WorkingMode::Repl => self.repl_prelude.as_ref(),
|
WorkingMode::Repl => self.repl_prelude.as_ref(),
|
||||||
WorkingMode::Cmd => self.cmd_prelude.as_ref(),
|
WorkingMode::Cmd => self.cmd_prelude.as_ref(),
|
||||||
WorkingMode::Serve => return Ok(()),
|
|
||||||
};
|
};
|
||||||
let prelude = match prelude {
|
let prelude = match prelude {
|
||||||
Some(v) => {
|
Some(v) => {
|
||||||
@@ -2835,10 +2802,6 @@ impl Config {
|
|||||||
if let Some(v) = read_env_value::<String>(&get_env_name("right_prompt")) {
|
if let Some(v) = read_env_value::<String>(&get_env_name("right_prompt")) {
|
||||||
self.right_prompt = v;
|
self.right_prompt = v;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(v) = read_env_value::<String>(&get_env_name("serve_addr")) {
|
|
||||||
self.serve_addr = v;
|
|
||||||
}
|
|
||||||
if let Some(v) = read_env_value::<String>(&get_env_name("user_agent")) {
|
if let Some(v) = read_env_value::<String>(&get_env_name("user_agent")) {
|
||||||
self.user_agent = v;
|
self.user_agent = v;
|
||||||
}
|
}
|
||||||
@@ -2947,7 +2910,6 @@ pub fn load_env_file() -> Result<()> {
|
|||||||
pub enum WorkingMode {
|
pub enum WorkingMode {
|
||||||
Cmd,
|
Cmd,
|
||||||
Repl,
|
Repl,
|
||||||
Serve,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WorkingMode {
|
impl WorkingMode {
|
||||||
@@ -2957,9 +2919,6 @@ impl WorkingMode {
|
|||||||
pub fn is_repl(&self) -> bool {
|
pub fn is_repl(&self) -> bool {
|
||||||
*self == WorkingMode::Repl
|
*self == WorkingMode::Repl
|
||||||
}
|
}
|
||||||
pub fn is_serve(&self) -> bool {
|
|
||||||
*self == WorkingMode::Serve
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_recursion::async_recursion]
|
#[async_recursion::async_recursion]
|
||||||
|
|||||||
@@ -111,12 +111,6 @@ impl Role {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn list_builtin_roles() -> Vec<Self> {
|
|
||||||
RolesAsset::iter()
|
|
||||||
.filter_map(|v| Role::builtin(&v).ok())
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn has_args(&self) -> bool {
|
pub fn has_args(&self) -> bool {
|
||||||
self.name.contains('#')
|
self.name.contains('#')
|
||||||
}
|
}
|
||||||
|
|||||||
+5
-18
@@ -5,7 +5,6 @@ mod function;
|
|||||||
mod rag;
|
mod rag;
|
||||||
mod render;
|
mod render;
|
||||||
mod repl;
|
mod repl;
|
||||||
mod serve;
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
mod utils;
|
mod utils;
|
||||||
mod mcp;
|
mod mcp;
|
||||||
@@ -58,9 +57,7 @@ async fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let text = cli.text()?;
|
let text = cli.text()?;
|
||||||
let working_mode = if cli.serve.is_some() {
|
let working_mode = if text.is_none() && cli.file.is_empty() {
|
||||||
WorkingMode::Serve
|
|
||||||
} else if text.is_none() && cli.file.is_empty() {
|
|
||||||
WorkingMode::Repl
|
WorkingMode::Repl
|
||||||
} else {
|
} else {
|
||||||
WorkingMode::Cmd
|
WorkingMode::Cmd
|
||||||
@@ -80,7 +77,7 @@ async fn main() -> Result<()> {
|
|||||||
|| cli.delete_secret.is_some()
|
|| cli.delete_secret.is_some()
|
||||||
|| cli.list_secrets;
|
|| cli.list_secrets;
|
||||||
|
|
||||||
let log_path = setup_logger(working_mode.is_serve())?;
|
let log_path = setup_logger()?;
|
||||||
|
|
||||||
if vault_flags {
|
if vault_flags {
|
||||||
return Vault::handle_vault_flags(cli, Config::init_bare()?);
|
return Vault::handle_vault_flags(cli, Config::init_bare()?);
|
||||||
@@ -219,9 +216,6 @@ async fn run(
|
|||||||
println!("{info}");
|
println!("{info}");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
if let Some(addr) = cli.serve {
|
|
||||||
return serve::run(config, addr).await;
|
|
||||||
}
|
|
||||||
let is_repl = config.read().working_mode.is_repl();
|
let is_repl = config.read().working_mode.is_repl();
|
||||||
if cli.rebuild_rag {
|
if cli.rebuild_rag {
|
||||||
Config::rebuild_rag(&config, abort_signal.clone()).await?;
|
Config::rebuild_rag(&config, abort_signal.clone()).await?;
|
||||||
@@ -429,22 +423,15 @@ async fn create_input(
|
|||||||
Ok(input)
|
Ok(input)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn setup_logger(is_serve: bool) -> Result<Option<PathBuf>> {
|
fn setup_logger() -> Result<Option<PathBuf>> {
|
||||||
let (log_level, log_path) = Config::log_config(is_serve)?;
|
let (log_level, log_path) = Config::log_config()?;
|
||||||
if log_level == LevelFilter::Off {
|
if log_level == LevelFilter::Off {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
let encoder = Box::new(PatternEncoder::new(
|
let encoder = Box::new(PatternEncoder::new(
|
||||||
"{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}",
|
"{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}",
|
||||||
));
|
));
|
||||||
let log_filter = match env::var(get_env_name("log_filter")) {
|
let log_filter = env::var(get_env_name("log_filter")).ok();
|
||||||
Ok(v) => Some(v),
|
|
||||||
Err(_) => match is_serve {
|
|
||||||
true => Some(format!("{}::serve", env!("CARGO_CRATE_NAME"))),
|
|
||||||
false => None,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
match log_path.clone() {
|
match log_path.clone() {
|
||||||
None => {
|
None => {
|
||||||
let console_appender = ConsoleAppender::builder().encoder(encoder).build();
|
let console_appender = ConsoleAppender::builder().encoder(encoder).build();
|
||||||
|
|||||||
+2
-2
@@ -41,8 +41,8 @@ impl Prompt for ReplPrompt {
|
|||||||
PromptHistorySearchStatus::Passing => "",
|
PromptHistorySearchStatus::Passing => "",
|
||||||
PromptHistorySearchStatus::Failing => "failing ",
|
PromptHistorySearchStatus::Failing => "failing ",
|
||||||
};
|
};
|
||||||
// NOTE: magic strings, given there is logic on how these compose I am not sure if it
|
// NOTE: magic strings, given there is logic on how these are composed, I'm unsure if it's
|
||||||
// is worth extracting in to static constant
|
// worth extracting into a static constant
|
||||||
Cow::Owned(format!(
|
Cow::Owned(format!(
|
||||||
"({}reverse-search: {}) ",
|
"({}reverse-search: {}) ",
|
||||||
prefix, history_search.term
|
prefix, history_search.term
|
||||||
|
|||||||
-935
@@ -1,935 +0,0 @@
|
|||||||
use crate::{client::*, config::*, function::*, rag::*, utils::*};
|
|
||||||
|
|
||||||
use anyhow::{anyhow, bail, Result};
|
|
||||||
use bytes::Bytes;
|
|
||||||
use chrono::{Timelike, Utc};
|
|
||||||
use futures_util::StreamExt;
|
|
||||||
use http::{Method, Response, StatusCode};
|
|
||||||
use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
|
|
||||||
use hyper::{
|
|
||||||
body::{Frame, Incoming},
|
|
||||||
service::service_fn,
|
|
||||||
};
|
|
||||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json::{json, Value};
|
|
||||||
use std::{
|
|
||||||
convert::Infallible,
|
|
||||||
net::IpAddr,
|
|
||||||
sync::{
|
|
||||||
atomic::{AtomicBool, Ordering},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use tokio::{
|
|
||||||
net::TcpListener,
|
|
||||||
sync::{
|
|
||||||
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
|
|
||||||
oneshot,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use tokio_graceful::Shutdown;
|
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
||||||
|
|
||||||
const DEFAULT_MODEL_NAME: &str = "default";
|
|
||||||
const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html");
|
|
||||||
const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html");
|
|
||||||
|
|
||||||
type AppResponse = Response<BoxBody<Bytes, Infallible>>;
|
|
||||||
|
|
||||||
pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
|
|
||||||
let addr = match addr {
|
|
||||||
Some(addr) => {
|
|
||||||
if let Ok(port) = addr.parse::<u16>() {
|
|
||||||
format!("127.0.0.1:{port}")
|
|
||||||
} else if let Ok(ip) = addr.parse::<IpAddr>() {
|
|
||||||
format!("{ip}:8000")
|
|
||||||
} else {
|
|
||||||
addr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => config.read().serve_addr(),
|
|
||||||
};
|
|
||||||
let server = Arc::new(Server::new(&config));
|
|
||||||
let listener = TcpListener::bind(&addr).await?;
|
|
||||||
let stop_server = server.run(listener).await?;
|
|
||||||
println!("Chat Completions API: http://{addr}/v1/chat/completions");
|
|
||||||
println!("Embeddings API: http://{addr}/v1/embeddings");
|
|
||||||
println!("Rerank API: http://{addr}/v1/rerank");
|
|
||||||
println!("LLM Playground: http://{addr}/playground");
|
|
||||||
println!("LLM Arena: http://{addr}/arena?num=2");
|
|
||||||
shutdown_signal().await;
|
|
||||||
let _ = stop_server.send(());
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Server {
|
|
||||||
config: Config,
|
|
||||||
models: Vec<Value>,
|
|
||||||
roles: Vec<Role>,
|
|
||||||
rags: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Server {
|
|
||||||
fn new(config: &GlobalConfig) -> Self {
|
|
||||||
let mut config = config.read().clone();
|
|
||||||
config.functions = Functions::default();
|
|
||||||
let mut models = list_all_models(&config);
|
|
||||||
let mut default_model = config.model.clone();
|
|
||||||
default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
|
|
||||||
models.insert(0, &default_model);
|
|
||||||
let models: Vec<Value> = models
|
|
||||||
.into_iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, model)| {
|
|
||||||
let id = if i == 0 {
|
|
||||||
DEFAULT_MODEL_NAME.into()
|
|
||||||
} else {
|
|
||||||
model.id()
|
|
||||||
};
|
|
||||||
let mut value = json!(model.data());
|
|
||||||
if let Some(value_obj) = value.as_object_mut() {
|
|
||||||
value_obj.insert("id".into(), id.into());
|
|
||||||
value_obj.insert("object".into(), "model".into());
|
|
||||||
value_obj.insert("owned_by".into(), model.client_name().into());
|
|
||||||
value_obj.remove("name");
|
|
||||||
}
|
|
||||||
value
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Self {
|
|
||||||
config,
|
|
||||||
models,
|
|
||||||
roles: Config::all_roles(),
|
|
||||||
rags: Config::list_rags(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(self: Arc<Self>, listener: TcpListener) -> Result<oneshot::Sender<()>> {
|
|
||||||
let (tx, rx) = oneshot::channel();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() });
|
|
||||||
let guard = shutdown.guard_weak();
|
|
||||||
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
res = listener.accept() => {
|
|
||||||
let Ok((cnx, _)) = res else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
let stream = TokioIo::new(cnx);
|
|
||||||
let server = self.clone();
|
|
||||||
shutdown.spawn_task(async move {
|
|
||||||
let hyper_service = service_fn(move |request: hyper::Request<Incoming>| {
|
|
||||||
server.clone().handle(request)
|
|
||||||
});
|
|
||||||
let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
|
|
||||||
.serve_connection_with_upgrades(stream, hyper_service)
|
|
||||||
.await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
_ = guard.cancelled() => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
Ok(tx)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle(
|
|
||||||
self: Arc<Self>,
|
|
||||||
req: hyper::Request<Incoming>,
|
|
||||||
) -> std::result::Result<AppResponse, hyper::Error> {
|
|
||||||
let method = req.method().clone();
|
|
||||||
let uri = req.uri().clone();
|
|
||||||
let path = uri.path();
|
|
||||||
|
|
||||||
if method == Method::OPTIONS {
|
|
||||||
let mut res = Response::default();
|
|
||||||
*res.status_mut() = StatusCode::NO_CONTENT;
|
|
||||||
set_cors_header(&mut res);
|
|
||||||
return Ok(res);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut status = StatusCode::OK;
|
|
||||||
let res = if path == "/v1/chat/completions" {
|
|
||||||
self.chat_completions(req).await
|
|
||||||
} else if path == "/v1/embeddings" {
|
|
||||||
self.embeddings(req).await
|
|
||||||
} else if path == "/v1/rerank" {
|
|
||||||
self.rerank(req).await
|
|
||||||
} else if path == "/v1/models" {
|
|
||||||
self.list_models()
|
|
||||||
} else if path == "/v1/roles" {
|
|
||||||
self.list_roles()
|
|
||||||
} else if path == "/v1/rags" {
|
|
||||||
self.list_rags()
|
|
||||||
} else if path == "/v1/rags/search" {
|
|
||||||
self.search_rag(req).await
|
|
||||||
} else if path == "/playground" || path == "/playground.html" {
|
|
||||||
self.playground_page()
|
|
||||||
} else if path == "/arena" || path == "/arena.html" {
|
|
||||||
self.arena_page()
|
|
||||||
} else {
|
|
||||||
status = StatusCode::NOT_FOUND;
|
|
||||||
Err(anyhow!("Not Found"))
|
|
||||||
};
|
|
||||||
let mut res = match res {
|
|
||||||
Ok(res) => {
|
|
||||||
info!("{method} {uri} {}", status.as_u16());
|
|
||||||
res
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
if status == StatusCode::OK {
|
|
||||||
status = StatusCode::BAD_REQUEST;
|
|
||||||
}
|
|
||||||
error!("{method} {uri} {} {err}", status.as_u16());
|
|
||||||
ret_err(err)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
*res.status_mut() = status;
|
|
||||||
set_cors_header(&mut res);
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn playground_page(&self) -> Result<AppResponse> {
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "text/html; charset=utf-8")
|
|
||||||
.body(Full::new(Bytes::from(PLAYGROUND_HTML)).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn arena_page(&self) -> Result<AppResponse> {
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "text/html; charset=utf-8")
|
|
||||||
.body(Full::new(Bytes::from(ARENA_HTML)).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn list_models(&self) -> Result<AppResponse> {
|
|
||||||
let data = json!({ "data": self.models });
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "application/json; charset=utf-8")
|
|
||||||
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn list_roles(&self) -> Result<AppResponse> {
|
|
||||||
let data = json!({ "data": self.roles });
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "application/json; charset=utf-8")
|
|
||||||
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn list_rags(&self) -> Result<AppResponse> {
|
|
||||||
let data = json!({ "data": self.rags });
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "application/json; charset=utf-8")
|
|
||||||
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn search_rag(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
||||||
let req_body = req.collect().await?.to_bytes();
|
|
||||||
let req_body: Value = serde_json::from_slice(&req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
||||||
|
|
||||||
debug!("search rag request: {req_body}");
|
|
||||||
let SearchRagReqBody { name, input } = serde_json::from_value(req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
||||||
|
|
||||||
let config = Arc::new(RwLock::new(self.config.clone()));
|
|
||||||
|
|
||||||
let abort_signal = create_abort_signal();
|
|
||||||
|
|
||||||
let rag_path = config.read().rag_file(&name);
|
|
||||||
let rag = Rag::load(&config, &name, &rag_path)?;
|
|
||||||
|
|
||||||
let rag_result = Config::search_rag(&config, &rag, &input, abort_signal).await?;
|
|
||||||
|
|
||||||
let data = json!({ "data": rag_result });
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "application/json; charset=utf-8")
|
|
||||||
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completions(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
||||||
let req_body = req.collect().await?.to_bytes();
|
|
||||||
let req_body: Value = serde_json::from_slice(&req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
||||||
|
|
||||||
debug!("chat completions request: {req_body}");
|
|
||||||
let req_body = serde_json::from_value(req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
||||||
|
|
||||||
let ChatCompletionsReqBody {
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
max_tokens,
|
|
||||||
stream,
|
|
||||||
tools,
|
|
||||||
} = req_body;
|
|
||||||
|
|
||||||
let mut messages =
|
|
||||||
parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
||||||
|
|
||||||
let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
||||||
|
|
||||||
let config = self.config.clone();
|
|
||||||
|
|
||||||
let default_model = config.model.clone();
|
|
||||||
|
|
||||||
let config = Arc::new(RwLock::new(config));
|
|
||||||
|
|
||||||
let (model_name, change) = if model == DEFAULT_MODEL_NAME {
|
|
||||||
(default_model.id(), true)
|
|
||||||
} else if default_model.id() == model {
|
|
||||||
(model, false)
|
|
||||||
} else {
|
|
||||||
(model, true)
|
|
||||||
};
|
|
||||||
|
|
||||||
if change {
|
|
||||||
config.write().set_model(&model_name)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut client = init_client(&config, None)?;
|
|
||||||
if max_tokens.is_some() {
|
|
||||||
client.model_mut().set_max_tokens(max_tokens, true);
|
|
||||||
}
|
|
||||||
let abort_signal = create_abort_signal();
|
|
||||||
let http_client = client.build_client()?;
|
|
||||||
|
|
||||||
let completion_id = generate_completion_id();
|
|
||||||
let created = Utc::now().timestamp();
|
|
||||||
|
|
||||||
patch_messages(&mut messages, client.model());
|
|
||||||
|
|
||||||
let data: ChatCompletionsData = ChatCompletionsData {
|
|
||||||
messages,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
functions,
|
|
||||||
stream,
|
|
||||||
};
|
|
||||||
|
|
||||||
if stream {
|
|
||||||
let (tx, mut rx) = unbounded_channel();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let is_first = Arc::new(AtomicBool::new(true));
|
|
||||||
let (sse_tx, sse_rx) = unbounded_channel();
|
|
||||||
let mut handler = SseHandler::new(sse_tx, abort_signal);
|
|
||||||
async fn map_event(
|
|
||||||
mut sse_rx: UnboundedReceiver<SseEvent>,
|
|
||||||
tx: &UnboundedSender<ResEvent>,
|
|
||||||
is_first: Arc<AtomicBool>,
|
|
||||||
) {
|
|
||||||
while let Some(reply_event) = sse_rx.recv().await {
|
|
||||||
if is_first.load(Ordering::SeqCst) {
|
|
||||||
let _ = tx.send(ResEvent::First(None));
|
|
||||||
is_first.store(false, Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
match reply_event {
|
|
||||||
SseEvent::Text(text) => {
|
|
||||||
let _ = tx.send(ResEvent::Text(text));
|
|
||||||
}
|
|
||||||
SseEvent::Done => {
|
|
||||||
let _ = tx.send(ResEvent::Done);
|
|
||||||
sse_rx.close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
async fn chat_completions(
|
|
||||||
client: &dyn Client,
|
|
||||||
http_client: &reqwest::Client,
|
|
||||||
handler: &mut SseHandler,
|
|
||||||
mut data: ChatCompletionsData,
|
|
||||||
tx: &UnboundedSender<ResEvent>,
|
|
||||||
is_first: Arc<AtomicBool>,
|
|
||||||
) {
|
|
||||||
if client.model().no_stream() {
|
|
||||||
data.stream = false;
|
|
||||||
let ret = client.chat_completions_inner(http_client, data).await;
|
|
||||||
match ret {
|
|
||||||
Ok(output) => {
|
|
||||||
let ChatCompletionsOutput {
|
|
||||||
text, tool_calls, ..
|
|
||||||
} = output;
|
|
||||||
let _ = tx.send(ResEvent::First(None));
|
|
||||||
is_first.store(false, Ordering::SeqCst);
|
|
||||||
let _ = tx.send(ResEvent::Text(text));
|
|
||||||
if !tool_calls.is_empty() {
|
|
||||||
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
|
|
||||||
is_first.store(false, Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
let ret = client
|
|
||||||
.chat_completions_streaming_inner(http_client, handler, data)
|
|
||||||
.await;
|
|
||||||
let first = match ret {
|
|
||||||
Ok(()) => None,
|
|
||||||
Err(err) => Some(format!("{err:?}")),
|
|
||||||
};
|
|
||||||
if is_first.load(Ordering::SeqCst) {
|
|
||||||
let _ = tx.send(ResEvent::First(first));
|
|
||||||
is_first.store(false, Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
let tool_calls = handler.tool_calls().to_vec();
|
|
||||||
if !tool_calls.is_empty() {
|
|
||||||
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
handler.done();
|
|
||||||
}
|
|
||||||
tokio::join!(
|
|
||||||
map_event(sse_rx, &tx, is_first.clone()),
|
|
||||||
chat_completions(
|
|
||||||
client.as_ref(),
|
|
||||||
&http_client,
|
|
||||||
&mut handler,
|
|
||||||
data,
|
|
||||||
&tx,
|
|
||||||
is_first
|
|
||||||
),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
let first_event = rx.recv().await;
|
|
||||||
|
|
||||||
if let Some(ResEvent::First(Some(err))) = first_event {
|
|
||||||
bail!("{err}");
|
|
||||||
}
|
|
||||||
|
|
||||||
let shared: Arc<(String, String, i64, AtomicBool)> =
|
|
||||||
Arc::new((completion_id, model_name, created, AtomicBool::new(false)));
|
|
||||||
let stream = UnboundedReceiverStream::new(rx);
|
|
||||||
let stream = stream.filter_map(move |res_event| {
|
|
||||||
let shared = shared.clone();
|
|
||||||
async move {
|
|
||||||
let (completion_id, model, created, has_tool_calls) = shared.as_ref();
|
|
||||||
match res_event {
|
|
||||||
ResEvent::Text(text) => {
|
|
||||||
Some(Ok(create_text_frame(completion_id, model, *created, &text)))
|
|
||||||
}
|
|
||||||
ResEvent::ToolCalls(tool_calls) => {
|
|
||||||
has_tool_calls.store(true, Ordering::SeqCst);
|
|
||||||
Some(Ok(create_tool_calls_frame(
|
|
||||||
completion_id,
|
|
||||||
model,
|
|
||||||
*created,
|
|
||||||
&tool_calls,
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
ResEvent::Done => Some(Ok(create_done_frame(
|
|
||||||
completion_id,
|
|
||||||
model,
|
|
||||||
*created,
|
|
||||||
has_tool_calls.load(Ordering::SeqCst),
|
|
||||||
))),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
let res = Response::builder()
|
|
||||||
.status(StatusCode::OK)
|
|
||||||
.header("Content-Type", "text/event-stream")
|
|
||||||
.header("Cache-Control", "no-cache")
|
|
||||||
.header("Connection", "keep-alive")
|
|
||||||
.body(BodyExt::boxed(StreamBody::new(stream)))?;
|
|
||||||
Ok(res)
|
|
||||||
} else {
|
|
||||||
let output = client.chat_completions_inner(&http_client, data).await?;
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.body(
|
|
||||||
Full::new(ret_non_stream(
|
|
||||||
&completion_id,
|
|
||||||
&model_name,
|
|
||||||
created,
|
|
||||||
&output,
|
|
||||||
))
|
|
||||||
.boxed(),
|
|
||||||
)?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn embeddings(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
||||||
let req_body = req.collect().await?.to_bytes();
|
|
||||||
let req_body: Value = serde_json::from_slice(&req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
||||||
|
|
||||||
debug!("embeddings request: {req_body}");
|
|
||||||
let req_body = serde_json::from_value(req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
||||||
|
|
||||||
let EmbeddingsReqBody {
|
|
||||||
input,
|
|
||||||
model: embedding_model_id,
|
|
||||||
} = req_body;
|
|
||||||
|
|
||||||
let config = Arc::new(RwLock::new(self.config.clone()));
|
|
||||||
|
|
||||||
let embedding_model =
|
|
||||||
Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;
|
|
||||||
|
|
||||||
let texts = match input {
|
|
||||||
EmbeddingsReqBodyInput::Single(v) => vec![v],
|
|
||||||
EmbeddingsReqBodyInput::Multiple(v) => v,
|
|
||||||
};
|
|
||||||
let client = init_client(&config, Some(embedding_model))?;
|
|
||||||
let data = client
|
|
||||||
.embeddings(&EmbeddingsData {
|
|
||||||
query: false,
|
|
||||||
texts,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
let data: Vec<_> = data
|
|
||||||
.into_iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, v)| {
|
|
||||||
json!({
|
|
||||||
"object": "embedding",
|
|
||||||
"embedding": v,
|
|
||||||
"index": i,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let output = json!({
|
|
||||||
"object": "list",
|
|
||||||
"data": data,
|
|
||||||
"model": embedding_model_id,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"total_tokens": 0,
|
|
||||||
}
|
|
||||||
});
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
||||||
let req_body = req.collect().await?.to_bytes();
|
|
||||||
let req_body: Value = serde_json::from_slice(&req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
||||||
|
|
||||||
debug!("rerank request: {req_body}");
|
|
||||||
let req_body = serde_json::from_value(req_body)
|
|
||||||
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
||||||
|
|
||||||
let RerankReqBody {
|
|
||||||
model: reranker_model_id,
|
|
||||||
documents,
|
|
||||||
query,
|
|
||||||
top_n,
|
|
||||||
} = req_body;
|
|
||||||
|
|
||||||
let top_n = top_n.unwrap_or(documents.len());
|
|
||||||
|
|
||||||
let config = Arc::new(RwLock::new(self.config.clone()));
|
|
||||||
|
|
||||||
let reranker_model =
|
|
||||||
Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?;
|
|
||||||
|
|
||||||
let client = init_client(&config, Some(reranker_model))?;
|
|
||||||
let data = client
|
|
||||||
.rerank(&RerankData {
|
|
||||||
query,
|
|
||||||
documents: documents.clone(),
|
|
||||||
top_n,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let results: Vec<_> = data
|
|
||||||
.into_iter()
|
|
||||||
.map(|v| {
|
|
||||||
json!({
|
|
||||||
"index": v.index,
|
|
||||||
"relevance_score": v.relevance_score,
|
|
||||||
"document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let output = json!({
|
|
||||||
"id": uuid::Uuid::new_v4().to_string(),
|
|
||||||
"results": results,
|
|
||||||
});
|
|
||||||
let res = Response::builder()
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct SearchRagReqBody {
|
|
||||||
name: String,
|
|
||||||
input: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ChatCompletionsReqBody {
|
|
||||||
model: String,
|
|
||||||
messages: Vec<Value>,
|
|
||||||
temperature: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
max_tokens: Option<isize>,
|
|
||||||
#[serde(default)]
|
|
||||||
stream: bool,
|
|
||||||
tools: Option<Vec<Value>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct EmbeddingsReqBody {
|
|
||||||
input: EmbeddingsReqBodyInput,
|
|
||||||
model: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
enum EmbeddingsReqBodyInput {
|
|
||||||
Single(String),
|
|
||||||
Multiple(Vec<String>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct RerankReqBody {
|
|
||||||
documents: Vec<String>,
|
|
||||||
query: String,
|
|
||||||
model: String,
|
|
||||||
top_n: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum ResEvent {
|
|
||||||
First(Option<String>),
|
|
||||||
Text(String),
|
|
||||||
ToolCalls(Vec<ToolCall>),
|
|
||||||
Done,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn shutdown_signal() {
|
|
||||||
tokio::signal::ctrl_c()
|
|
||||||
.await
|
|
||||||
.expect("Failed to install CTRL+C signal handler")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn generate_completion_id() -> String {
|
|
||||||
let random_id = Utc::now().nanosecond();
|
|
||||||
format!("chatcmpl-{random_id}")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_cors_header(res: &mut AppResponse) {
|
|
||||||
res.headers_mut().insert(
|
|
||||||
hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
|
||||||
hyper::header::HeaderValue::from_static("*"),
|
|
||||||
);
|
|
||||||
res.headers_mut().insert(
|
|
||||||
hyper::header::ACCESS_CONTROL_ALLOW_METHODS,
|
|
||||||
hyper::header::HeaderValue::from_static("GET,POST,PUT,PATCH,DELETE"),
|
|
||||||
);
|
|
||||||
res.headers_mut().insert(
|
|
||||||
hyper::header::ACCESS_CONTROL_ALLOW_HEADERS,
|
|
||||||
hyper::header::HeaderValue::from_static("Content-Type,Authorization"),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame<Bytes> {
|
|
||||||
let delta = if content.is_empty() {
|
|
||||||
json!({ "role": "assistant", "content": content })
|
|
||||||
} else {
|
|
||||||
json!({ "content": content })
|
|
||||||
};
|
|
||||||
let choice = json!({
|
|
||||||
"index": 0,
|
|
||||||
"delta": delta,
|
|
||||||
"finish_reason": null,
|
|
||||||
});
|
|
||||||
let value = build_chat_completion_chunk_json(id, model, created, &choice);
|
|
||||||
Frame::data(Bytes::from(format!("data: {value}\n\n")))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_tool_calls_frame(
|
|
||||||
id: &str,
|
|
||||||
model: &str,
|
|
||||||
created: i64,
|
|
||||||
tool_calls: &[ToolCall],
|
|
||||||
) -> Frame<Bytes> {
|
|
||||||
let chunks = tool_calls
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.flat_map(|(i, call)| {
|
|
||||||
let choice1 = json!({
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": null,
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"index": i,
|
|
||||||
"id": call.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": call.name,
|
|
||||||
"arguments": ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"finish_reason": null
|
|
||||||
});
|
|
||||||
let choice2 = json!({
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"index": i,
|
|
||||||
"function": {
|
|
||||||
"arguments": call.arguments.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"finish_reason": null
|
|
||||||
});
|
|
||||||
vec![
|
|
||||||
build_chat_completion_chunk_json(id, model, created, &choice1),
|
|
||||||
build_chat_completion_chunk_json(id, model, created, &choice2),
|
|
||||||
]
|
|
||||||
})
|
|
||||||
.map(|v| format!("data: {v}\n\n"))
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join("");
|
|
||||||
Frame::data(Bytes::from(chunks))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame<Bytes> {
|
|
||||||
let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" };
|
|
||||||
let choice = json!({
|
|
||||||
"index": 0,
|
|
||||||
"delta": {},
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
});
|
|
||||||
let value = build_chat_completion_chunk_json(id, model, created, &choice);
|
|
||||||
Frame::data(Bytes::from(format!("data: {value}\n\ndata: [DONE]\n\n")))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value {
|
|
||||||
json!({
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": [choice],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
|
|
||||||
let id = output.id.as_deref().unwrap_or(id);
|
|
||||||
let input_tokens = output.input_tokens.unwrap_or_default();
|
|
||||||
let output_tokens = output.output_tokens.unwrap_or_default();
|
|
||||||
let total_tokens = input_tokens + output_tokens;
|
|
||||||
let choice = if output.tool_calls.is_empty() {
|
|
||||||
json!({
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": output.text,
|
|
||||||
},
|
|
||||||
"logprobs": null,
|
|
||||||
"finish_reason": "stop",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
let content = if output.text.is_empty() {
|
|
||||||
Value::Null
|
|
||||||
} else {
|
|
||||||
output.text.clone().into()
|
|
||||||
};
|
|
||||||
let tool_calls: Vec<_> = output
|
|
||||||
.tool_calls
|
|
||||||
.iter()
|
|
||||||
.map(|call| {
|
|
||||||
json!({
|
|
||||||
"id": call.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": call.name,
|
|
||||||
"arguments": call.arguments.to_string(),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
json!({
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": content,
|
|
||||||
"tool_calls": tool_calls,
|
|
||||||
},
|
|
||||||
"logprobs": null,
|
|
||||||
"finish_reason": "tool_calls",
|
|
||||||
})
|
|
||||||
};
|
|
||||||
let res_body = json!({
|
|
||||||
"id": id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": [choice],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": input_tokens,
|
|
||||||
"completion_tokens": output_tokens,
|
|
||||||
"total_tokens": total_tokens,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
Bytes::from(res_body.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
|
|
||||||
let data = json!({
|
|
||||||
"error": {
|
|
||||||
"message": err.to_string(),
|
|
||||||
"type": "invalid_request_error",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
Response::builder()
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.body(Full::new(Bytes::from(data.to_string())).boxed())
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
|
|
||||||
let mut output = vec![];
|
|
||||||
let mut tool_results = None;
|
|
||||||
for (i, message) in message.into_iter().enumerate() {
|
|
||||||
let err = || anyhow!("Failed to parse '.messages[{i}]'");
|
|
||||||
let role = message["role"].as_str().ok_or_else(err)?;
|
|
||||||
let content = match message.get("content") {
|
|
||||||
Some(value) => {
|
|
||||||
if let Some(value) = value.as_str() {
|
|
||||||
MessageContent::Text(value.to_string())
|
|
||||||
} else if value.is_array() {
|
|
||||||
let value = serde_json::from_value(value.clone()).map_err(|_| err())?;
|
|
||||||
MessageContent::Array(value)
|
|
||||||
} else if value.is_null() {
|
|
||||||
MessageContent::Text(String::new())
|
|
||||||
} else {
|
|
||||||
return Err(err());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => MessageContent::Text(String::new()),
|
|
||||||
};
|
|
||||||
match role {
|
|
||||||
"system" | "user" => {
|
|
||||||
let role = match role {
|
|
||||||
"system" => MessageRole::System,
|
|
||||||
"user" => MessageRole::User,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
output.push(Message::new(role, content))
|
|
||||||
}
|
|
||||||
"assistant" => {
|
|
||||||
let role = MessageRole::Assistant;
|
|
||||||
match message["tool_calls"].as_array() {
|
|
||||||
Some(tool_calls) => {
|
|
||||||
if tool_results.is_some() {
|
|
||||||
return Err(err());
|
|
||||||
}
|
|
||||||
let mut list = vec![];
|
|
||||||
for tool_call in tool_calls {
|
|
||||||
if let (id, Some(name), Some(arguments)) = (
|
|
||||||
tool_call["id"].as_str().map(|v| v.to_string()),
|
|
||||||
tool_call["function"]["name"].as_str(),
|
|
||||||
tool_call["function"]["arguments"].as_str(),
|
|
||||||
) {
|
|
||||||
let arguments =
|
|
||||||
serde_json::from_str(arguments).map_err(|_| err())?;
|
|
||||||
list.push((id, name.to_string(), arguments));
|
|
||||||
} else {
|
|
||||||
return Err(err());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tool_results = Some((content.to_text(), list, vec![]));
|
|
||||||
}
|
|
||||||
None => output.push(Message::new(role, content)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"tool" => match tool_results.take() {
|
|
||||||
Some((text, tool_calls, mut tool_values)) => {
|
|
||||||
let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string());
|
|
||||||
let content = content.to_text();
|
|
||||||
let value: Value = serde_json::from_str(&content)
|
|
||||||
.ok()
|
|
||||||
.unwrap_or_else(|| content.into());
|
|
||||||
|
|
||||||
tool_values.push((value, tool_call_id));
|
|
||||||
|
|
||||||
if tool_calls.len() == tool_values.len() {
|
|
||||||
let mut list = vec![];
|
|
||||||
for ((id, name, arguments), (value, tool_call_id)) in
|
|
||||||
tool_calls.into_iter().zip(tool_values.into_iter())
|
|
||||||
{
|
|
||||||
if id != tool_call_id {
|
|
||||||
return Err(err());
|
|
||||||
}
|
|
||||||
list.push(ToolResult::new(ToolCall::new(name, arguments, id), value))
|
|
||||||
}
|
|
||||||
output.push(Message::new(
|
|
||||||
MessageRole::Assistant,
|
|
||||||
MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)),
|
|
||||||
));
|
|
||||||
tool_results = None;
|
|
||||||
} else {
|
|
||||||
tool_results = Some((text, tool_calls, tool_values));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => return Err(err()),
|
|
||||||
},
|
|
||||||
_ => {
|
|
||||||
return Err(err());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tool_results.is_some() {
|
|
||||||
bail!("Invalid messages");
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_tools(tools: Option<Vec<Value>>) -> Result<Option<Vec<FunctionDeclaration>>> {
|
|
||||||
let tools = match tools {
|
|
||||||
Some(v) => v,
|
|
||||||
None => return Ok(None),
|
|
||||||
};
|
|
||||||
let mut functions = vec![];
|
|
||||||
for (i, tool) in tools.into_iter().enumerate() {
|
|
||||||
if let (Some("function"), Some(function)) = (
|
|
||||||
tool["type"].as_str(),
|
|
||||||
tool["function"]
|
|
||||||
.as_object()
|
|
||||||
.and_then(|v| serde_json::from_value(json!(v)).ok()),
|
|
||||||
) {
|
|
||||||
functions.push(function);
|
|
||||||
} else {
|
|
||||||
bail!("Failed to parse '.tools[{i}]'")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Some(functions))
|
|
||||||
}
|
|
||||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||||||
|
|
||||||
/// Render REPL prompt
|
/// Render REPL prompt
|
||||||
///
|
///
|
||||||
/// The template comprises plain text and `{...}`.
|
/// The template comprises plain text and `{...}`.
|
||||||
///
|
///
|
||||||
/// The syntax of `{...}`:
|
/// The syntax of `{...}`:
|
||||||
/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`
|
/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`
|
||||||
|
|||||||
Reference in New Issue
Block a user