use crate::{client::*, config::*, function::*, rag::*, utils::*}; use anyhow::{anyhow, bail, Result}; use bytes::Bytes; use chrono::{Timelike, Utc}; use futures_util::StreamExt; use http::{Method, Response, StatusCode}; use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; use hyper::{ body::{Frame, Incoming}, service::service_fn, }; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; use serde::Deserialize; use serde_json::{json, Value}; use std::{ convert::Infallible, net::IpAddr, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, }; use tokio::{ net::TcpListener, sync::{ mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, oneshot, }, }; use tokio_graceful::Shutdown; use tokio_stream::wrappers::UnboundedReceiverStream; const DEFAULT_MODEL_NAME: &str = "default"; const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html"); const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html"); type AppResponse = Response>; pub async fn run(config: GlobalConfig, addr: Option) -> Result<()> { let addr = match addr { Some(addr) => { if let Ok(port) = addr.parse::() { format!("127.0.0.1:{port}") } else if let Ok(ip) = addr.parse::() { format!("{ip}:8000") } else { addr } } None => config.read().serve_addr(), }; let server = Arc::new(Server::new(&config)); let listener = TcpListener::bind(&addr).await?; let stop_server = server.run(listener).await?; println!("Chat Completions API: http://{addr}/v1/chat/completions"); println!("Embeddings API: http://{addr}/v1/embeddings"); println!("Rerank API: http://{addr}/v1/rerank"); println!("LLM Playground: http://{addr}/playground"); println!("LLM Arena: http://{addr}/arena?num=2"); shutdown_signal().await; let _ = stop_server.send(()); Ok(()) } struct Server { config: Config, models: Vec, roles: Vec, rags: Vec, } impl Server { fn new(config: &GlobalConfig) -> Self { let mut config = config.read().clone(); config.functions = Functions::default(); let mut models = list_all_models(&config); let mut default_model = config.model.clone(); default_model.data_mut().name = DEFAULT_MODEL_NAME.into(); models.insert(0, &default_model); let models: Vec = models .into_iter() .enumerate() .map(|(i, model)| { let id = if i == 0 { DEFAULT_MODEL_NAME.into() } else { model.id() }; let mut value = json!(model.data()); if let Some(value_obj) = value.as_object_mut() { value_obj.insert("id".into(), id.into()); value_obj.insert("object".into(), "model".into()); value_obj.insert("owned_by".into(), model.client_name().into()); value_obj.remove("name"); } value }) .collect(); Self { config, models, roles: Config::all_roles(), rags: Config::list_rags(), } } async fn run(self: Arc, listener: TcpListener) -> Result> { let (tx, rx) = oneshot::channel(); tokio::spawn(async move { let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() }); let guard = shutdown.guard_weak(); loop { tokio::select! { res = listener.accept() => { let Ok((cnx, _)) = res else { continue; }; let stream = TokioIo::new(cnx); let server = self.clone(); shutdown.spawn_task(async move { let hyper_service = service_fn(move |request: hyper::Request| { server.clone().handle(request) }); let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) .serve_connection_with_upgrades(stream, hyper_service) .await; }); } _ = guard.cancelled() => { break; } } } }); Ok(tx) } async fn handle( self: Arc, req: hyper::Request, ) -> std::result::Result { let method = req.method().clone(); let uri = req.uri().clone(); let path = uri.path(); if method == Method::OPTIONS { let mut res = Response::default(); *res.status_mut() = StatusCode::NO_CONTENT; set_cors_header(&mut res); return Ok(res); } let mut status = StatusCode::OK; let res = if path == "/v1/chat/completions" { self.chat_completions(req).await } else if path == "/v1/embeddings" { self.embeddings(req).await } else if path == "/v1/rerank" { self.rerank(req).await } else if path == "/v1/models" { self.list_models() } else if path == "/v1/roles" { self.list_roles() } else if path == "/v1/rags" { self.list_rags() } else if path == "/v1/rags/search" { self.search_rag(req).await } else if path == "/playground" || path == "/playground.html" { self.playground_page() } else if path == "/arena" || path == "/arena.html" { self.arena_page() } else { status = StatusCode::NOT_FOUND; Err(anyhow!("Not Found")) }; let mut res = match res { Ok(res) => { info!("{method} {uri} {}", status.as_u16()); res } Err(err) => { if status == StatusCode::OK { status = StatusCode::BAD_REQUEST; } error!("{method} {uri} {} {err}", status.as_u16()); ret_err(err) } }; *res.status_mut() = status; set_cors_header(&mut res); Ok(res) } fn playground_page(&self) -> Result { let res = Response::builder() .header("Content-Type", "text/html; charset=utf-8") .body(Full::new(Bytes::from(PLAYGROUND_HTML)).boxed())?; Ok(res) } fn arena_page(&self) -> Result { let res = Response::builder() .header("Content-Type", "text/html; charset=utf-8") .body(Full::new(Bytes::from(ARENA_HTML)).boxed())?; Ok(res) } fn list_models(&self) -> Result { let data = json!({ "data": self.models }); let res = Response::builder() .header("Content-Type", "application/json; charset=utf-8") .body(Full::new(Bytes::from(data.to_string())).boxed())?; Ok(res) } fn list_roles(&self) -> Result { let data = json!({ "data": self.roles }); let res = Response::builder() .header("Content-Type", "application/json; charset=utf-8") .body(Full::new(Bytes::from(data.to_string())).boxed())?; Ok(res) } fn list_rags(&self) -> Result { let data = json!({ "data": self.rags }); let res = Response::builder() .header("Content-Type", "application/json; charset=utf-8") .body(Full::new(Bytes::from(data.to_string())).boxed())?; Ok(res) } async fn search_rag(&self, req: hyper::Request) -> Result { let req_body = req.collect().await?.to_bytes(); let req_body: Value = serde_json::from_slice(&req_body) .map_err(|err| anyhow!("Invalid request json, {err}"))?; debug!("search rag request: {req_body}"); let SearchRagReqBody { name, input } = serde_json::from_value(req_body) .map_err(|err| anyhow!("Invalid request body, {err}"))?; let config = Arc::new(RwLock::new(self.config.clone())); let abort_signal = create_abort_signal(); let rag_path = config.read().rag_file(&name); let rag = Rag::load(&config, &name, &rag_path)?; let rag_result = Config::search_rag(&config, &rag, &input, abort_signal).await?; let data = json!({ "data": rag_result }); let res = Response::builder() .header("Content-Type", "application/json; charset=utf-8") .body(Full::new(Bytes::from(data.to_string())).boxed())?; Ok(res) } async fn chat_completions(&self, req: hyper::Request) -> Result { let req_body = req.collect().await?.to_bytes(); let req_body: Value = serde_json::from_slice(&req_body) .map_err(|err| anyhow!("Invalid request json, {err}"))?; debug!("chat completions request: {req_body}"); let req_body = serde_json::from_value(req_body) .map_err(|err| anyhow!("Invalid request body, {err}"))?; let ChatCompletionsReqBody { model, messages, temperature, top_p, max_tokens, stream, tools, } = req_body; let mut messages = parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?; let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?; let config = self.config.clone(); let default_model = config.model.clone(); let config = Arc::new(RwLock::new(config)); let (model_name, change) = if model == DEFAULT_MODEL_NAME { (default_model.id(), true) } else if default_model.id() == model { (model, false) } else { (model, true) }; if change { config.write().set_model(&model_name)?; } let mut client = init_client(&config, None)?; if max_tokens.is_some() { client.model_mut().set_max_tokens(max_tokens, true); } let abort_signal = create_abort_signal(); let http_client = client.build_client()?; let completion_id = generate_completion_id(); let created = Utc::now().timestamp(); patch_messages(&mut messages, client.model()); let data: ChatCompletionsData = ChatCompletionsData { messages, temperature, top_p, functions, stream, }; if stream { let (tx, mut rx) = unbounded_channel(); tokio::spawn(async move { let is_first = Arc::new(AtomicBool::new(true)); let (sse_tx, sse_rx) = unbounded_channel(); let mut handler = SseHandler::new(sse_tx, abort_signal); async fn map_event( mut sse_rx: UnboundedReceiver, tx: &UnboundedSender, is_first: Arc, ) { while let Some(reply_event) = sse_rx.recv().await { if is_first.load(Ordering::SeqCst) { let _ = tx.send(ResEvent::First(None)); is_first.store(false, Ordering::SeqCst) } match reply_event { SseEvent::Text(text) => { let _ = tx.send(ResEvent::Text(text)); } SseEvent::Done => { let _ = tx.send(ResEvent::Done); sse_rx.close(); } } } } async fn chat_completions( client: &dyn Client, http_client: &reqwest::Client, handler: &mut SseHandler, mut data: ChatCompletionsData, tx: &UnboundedSender, is_first: Arc, ) { if client.model().no_stream() { data.stream = false; let ret = client.chat_completions_inner(http_client, data).await; match ret { Ok(output) => { let ChatCompletionsOutput { text, tool_calls, .. } = output; let _ = tx.send(ResEvent::First(None)); is_first.store(false, Ordering::SeqCst); let _ = tx.send(ResEvent::Text(text)); if !tool_calls.is_empty() { let _ = tx.send(ResEvent::ToolCalls(tool_calls)); } } Err(err) => { let _ = tx.send(ResEvent::First(Some(format!("{err:?}")))); is_first.store(false, Ordering::SeqCst) } }; } else { let ret = client .chat_completions_streaming_inner(http_client, handler, data) .await; let first = match ret { Ok(()) => None, Err(err) => Some(format!("{err:?}")), }; if is_first.load(Ordering::SeqCst) { let _ = tx.send(ResEvent::First(first)); is_first.store(false, Ordering::SeqCst) } let tool_calls = handler.tool_calls().to_vec(); if !tool_calls.is_empty() { let _ = tx.send(ResEvent::ToolCalls(tool_calls)); } } handler.done(); } tokio::join!( map_event(sse_rx, &tx, is_first.clone()), chat_completions( client.as_ref(), &http_client, &mut handler, data, &tx, is_first ), ); }); let first_event = rx.recv().await; if let Some(ResEvent::First(Some(err))) = first_event { bail!("{err}"); } let shared: Arc<(String, String, i64, AtomicBool)> = Arc::new((completion_id, model_name, created, AtomicBool::new(false))); let stream = UnboundedReceiverStream::new(rx); let stream = stream.filter_map(move |res_event| { let shared = shared.clone(); async move { let (completion_id, model, created, has_tool_calls) = shared.as_ref(); match res_event { ResEvent::Text(text) => { Some(Ok(create_text_frame(completion_id, model, *created, &text))) } ResEvent::ToolCalls(tool_calls) => { has_tool_calls.store(true, Ordering::SeqCst); Some(Ok(create_tool_calls_frame( completion_id, model, *created, &tool_calls, ))) } ResEvent::Done => Some(Ok(create_done_frame( completion_id, model, *created, has_tool_calls.load(Ordering::SeqCst), ))), _ => None, } } }); let res = Response::builder() .status(StatusCode::OK) .header("Content-Type", "text/event-stream") .header("Cache-Control", "no-cache") .header("Connection", "keep-alive") .body(BodyExt::boxed(StreamBody::new(stream)))?; Ok(res) } else { let output = client.chat_completions_inner(&http_client, data).await?; let res = Response::builder() .header("Content-Type", "application/json") .body( Full::new(ret_non_stream( &completion_id, &model_name, created, &output, )) .boxed(), )?; Ok(res) } } async fn embeddings(&self, req: hyper::Request) -> Result { let req_body = req.collect().await?.to_bytes(); let req_body: Value = serde_json::from_slice(&req_body) .map_err(|err| anyhow!("Invalid request json, {err}"))?; debug!("embeddings request: {req_body}"); let req_body = serde_json::from_value(req_body) .map_err(|err| anyhow!("Invalid request body, {err}"))?; let EmbeddingsReqBody { input, model: embedding_model_id, } = req_body; let config = Arc::new(RwLock::new(self.config.clone())); let embedding_model = Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?; let texts = match input { EmbeddingsReqBodyInput::Single(v) => vec![v], EmbeddingsReqBodyInput::Multiple(v) => v, }; let client = init_client(&config, Some(embedding_model))?; let data = client .embeddings(&EmbeddingsData { query: false, texts, }) .await?; let data: Vec<_> = data .into_iter() .enumerate() .map(|(i, v)| { json!({ "object": "embedding", "embedding": v, "index": i, }) }) .collect(); let output = json!({ "object": "list", "data": data, "model": embedding_model_id, "usage": { "prompt_tokens": 0, "total_tokens": 0, } }); let res = Response::builder() .header("Content-Type", "application/json") .body(Full::new(Bytes::from(output.to_string())).boxed())?; Ok(res) } async fn rerank(&self, req: hyper::Request) -> Result { let req_body = req.collect().await?.to_bytes(); let req_body: Value = serde_json::from_slice(&req_body) .map_err(|err| anyhow!("Invalid request json, {err}"))?; debug!("rerank request: {req_body}"); let req_body = serde_json::from_value(req_body) .map_err(|err| anyhow!("Invalid request body, {err}"))?; let RerankReqBody { model: reranker_model_id, documents, query, top_n, } = req_body; let top_n = top_n.unwrap_or(documents.len()); let config = Arc::new(RwLock::new(self.config.clone())); let reranker_model = Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?; let client = init_client(&config, Some(reranker_model))?; let data = client .rerank(&RerankData { query, documents: documents.clone(), top_n, }) .await?; let results: Vec<_> = data .into_iter() .map(|v| { json!({ "index": v.index, "relevance_score": v.relevance_score, "document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(), }) }) .collect(); let output = json!({ "id": uuid::Uuid::new_v4().to_string(), "results": results, }); let res = Response::builder() .header("Content-Type", "application/json") .body(Full::new(Bytes::from(output.to_string())).boxed())?; Ok(res) } } #[derive(Debug, Deserialize)] struct SearchRagReqBody { name: String, input: String, } #[derive(Debug, Deserialize)] struct ChatCompletionsReqBody { model: String, messages: Vec, temperature: Option, top_p: Option, max_tokens: Option, #[serde(default)] stream: bool, tools: Option>, } #[derive(Debug, Deserialize)] struct EmbeddingsReqBody { input: EmbeddingsReqBodyInput, model: String, } #[derive(Debug, Deserialize)] #[serde(untagged)] enum EmbeddingsReqBodyInput { Single(String), Multiple(Vec), } #[derive(Debug, Deserialize)] struct RerankReqBody { documents: Vec, query: String, model: String, top_n: Option, } #[derive(Debug)] enum ResEvent { First(Option), Text(String), ToolCalls(Vec), Done, } async fn shutdown_signal() { tokio::signal::ctrl_c() .await .expect("Failed to install CTRL+C signal handler") } fn generate_completion_id() -> String { let random_id = Utc::now().nanosecond(); format!("chatcmpl-{random_id}") } fn set_cors_header(res: &mut AppResponse) { res.headers_mut().insert( hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, hyper::header::HeaderValue::from_static("*"), ); res.headers_mut().insert( hyper::header::ACCESS_CONTROL_ALLOW_METHODS, hyper::header::HeaderValue::from_static("GET,POST,PUT,PATCH,DELETE"), ); res.headers_mut().insert( hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, hyper::header::HeaderValue::from_static("Content-Type,Authorization"), ); } fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame { let delta = if content.is_empty() { json!({ "role": "assistant", "content": content }) } else { json!({ "content": content }) }; let choice = json!({ "index": 0, "delta": delta, "finish_reason": null, }); let value = build_chat_completion_chunk_json(id, model, created, &choice); Frame::data(Bytes::from(format!("data: {value}\n\n"))) } fn create_tool_calls_frame( id: &str, model: &str, created: i64, tool_calls: &[ToolCall], ) -> Frame { let chunks = tool_calls .iter() .enumerate() .flat_map(|(i, call)| { let choice1 = json!({ "index": 0, "delta": { "role": "assistant", "content": null, "tool_calls": [ { "index": i, "id": call.id, "type": "function", "function": { "name": call.name, "arguments": "" } } ] }, "finish_reason": null }); let choice2 = json!({ "index": 0, "delta": { "tool_calls": [ { "index": i, "function": { "arguments": call.arguments.to_string(), } } ] }, "finish_reason": null }); vec![ build_chat_completion_chunk_json(id, model, created, &choice1), build_chat_completion_chunk_json(id, model, created, &choice2), ] }) .map(|v| format!("data: {v}\n\n")) .collect::>() .join(""); Frame::data(Bytes::from(chunks)) } fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame { let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" }; let choice = json!({ "index": 0, "delta": {}, "finish_reason": finish_reason, }); let value = build_chat_completion_chunk_json(id, model, created, &choice); Frame::data(Bytes::from(format!("data: {value}\n\ndata: [DONE]\n\n"))) } fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value { json!({ "id": id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": [choice], }) } fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes { let id = output.id.as_deref().unwrap_or(id); let input_tokens = output.input_tokens.unwrap_or_default(); let output_tokens = output.output_tokens.unwrap_or_default(); let total_tokens = input_tokens + output_tokens; let choice = if output.tool_calls.is_empty() { json!({ "index": 0, "message": { "role": "assistant", "content": output.text, }, "logprobs": null, "finish_reason": "stop", }) } else { let content = if output.text.is_empty() { Value::Null } else { output.text.clone().into() }; let tool_calls: Vec<_> = output .tool_calls .iter() .map(|call| { json!({ "id": call.id, "type": "function", "function": { "name": call.name, "arguments": call.arguments.to_string(), } }) }) .collect(); json!({ "index": 0, "message": { "role": "assistant", "content": content, "tool_calls": tool_calls, }, "logprobs": null, "finish_reason": "tool_calls", }) }; let res_body = json!({ "id": id, "object": "chat.completion", "created": created, "model": model, "choices": [choice], "usage": { "prompt_tokens": input_tokens, "completion_tokens": output_tokens, "total_tokens": total_tokens, }, }); Bytes::from(res_body.to_string()) } fn ret_err(err: T) -> AppResponse { let data = json!({ "error": { "message": err.to_string(), "type": "invalid_request_error", }, }); Response::builder() .header("Content-Type", "application/json") .body(Full::new(Bytes::from(data.to_string())).boxed()) .unwrap() } fn parse_messages(message: Vec) -> Result> { let mut output = vec![]; let mut tool_results = None; for (i, message) in message.into_iter().enumerate() { let err = || anyhow!("Failed to parse '.messages[{i}]'"); let role = message["role"].as_str().ok_or_else(err)?; let content = match message.get("content") { Some(value) => { if let Some(value) = value.as_str() { MessageContent::Text(value.to_string()) } else if value.is_array() { let value = serde_json::from_value(value.clone()).map_err(|_| err())?; MessageContent::Array(value) } else if value.is_null() { MessageContent::Text(String::new()) } else { return Err(err()); } } None => MessageContent::Text(String::new()), }; match role { "system" | "user" => { let role = match role { "system" => MessageRole::System, "user" => MessageRole::User, _ => unreachable!(), }; output.push(Message::new(role, content)) } "assistant" => { let role = MessageRole::Assistant; match message["tool_calls"].as_array() { Some(tool_calls) => { if tool_results.is_some() { return Err(err()); } let mut list = vec![]; for tool_call in tool_calls { if let (id, Some(name), Some(arguments)) = ( tool_call["id"].as_str().map(|v| v.to_string()), tool_call["function"]["name"].as_str(), tool_call["function"]["arguments"].as_str(), ) { let arguments = serde_json::from_str(arguments).map_err(|_| err())?; list.push((id, name.to_string(), arguments)); } else { return Err(err()); } } tool_results = Some((content.to_text(), list, vec![])); } None => output.push(Message::new(role, content)), } } "tool" => match tool_results.take() { Some((text, tool_calls, mut tool_values)) => { let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string()); let content = content.to_text(); let value: Value = serde_json::from_str(&content) .ok() .unwrap_or_else(|| content.into()); tool_values.push((value, tool_call_id)); if tool_calls.len() == tool_values.len() { let mut list = vec![]; for ((id, name, arguments), (value, tool_call_id)) in tool_calls.into_iter().zip(tool_values.into_iter()) { if id != tool_call_id { return Err(err()); } list.push(ToolResult::new(ToolCall::new(name, arguments, id), value)) } output.push(Message::new( MessageRole::Assistant, MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)), )); tool_results = None; } else { tool_results = Some((text, tool_calls, tool_values)); } } None => return Err(err()), }, _ => { return Err(err()); } } } if tool_results.is_some() { bail!("Invalid messages"); } Ok(output) } fn parse_tools(tools: Option>) -> Result>> { let tools = match tools { Some(v) => v, None => return Ok(None), }; let mut functions = vec![]; for (i, tool) in tools.into_iter().enumerate() { if let (Some("function"), Some(function)) = ( tool["type"].as_str(), tool["function"] .as_object() .and_then(|v| serde_json::from_value(json!(v)).ok()), ) { functions.push(function); } else { bail!("Failed to parse '.tools[{i}]'") } } Ok(Some(functions)) }