// Listing metadata: 936 lines, 32 KiB, Rust.
use crate::{client::*, config::*, function::*, rag::*, utils::*};
|
|
|
|
use anyhow::{anyhow, bail, Result};
|
|
use bytes::Bytes;
|
|
use chrono::{Timelike, Utc};
|
|
use futures_util::StreamExt;
|
|
use http::{Method, Response, StatusCode};
|
|
use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
|
|
use hyper::{
|
|
body::{Frame, Incoming},
|
|
service::service_fn,
|
|
};
|
|
use hyper_util::rt::{TokioExecutor, TokioIo};
|
|
use parking_lot::RwLock;
|
|
use serde::Deserialize;
|
|
use serde_json::{json, Value};
|
|
use std::{
|
|
convert::Infallible,
|
|
net::IpAddr,
|
|
sync::{
|
|
atomic::{AtomicBool, Ordering},
|
|
Arc,
|
|
},
|
|
};
|
|
use tokio::{
|
|
net::TcpListener,
|
|
sync::{
|
|
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
|
|
oneshot,
|
|
},
|
|
};
|
|
use tokio_graceful::Shutdown;
|
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
|
|
const DEFAULT_MODEL_NAME: &str = "default";
|
|
const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html");
|
|
const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html");
|
|
|
|
type AppResponse = Response<BoxBody<Bytes, Infallible>>;
|
|
|
|
pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
|
|
let addr = match addr {
|
|
Some(addr) => {
|
|
if let Ok(port) = addr.parse::<u16>() {
|
|
format!("127.0.0.1:{port}")
|
|
} else if let Ok(ip) = addr.parse::<IpAddr>() {
|
|
format!("{ip}:8000")
|
|
} else {
|
|
addr
|
|
}
|
|
}
|
|
None => config.read().serve_addr(),
|
|
};
|
|
let server = Arc::new(Server::new(&config));
|
|
let listener = TcpListener::bind(&addr).await?;
|
|
let stop_server = server.run(listener).await?;
|
|
println!("Chat Completions API: http://{addr}/v1/chat/completions");
|
|
println!("Embeddings API: http://{addr}/v1/embeddings");
|
|
println!("Rerank API: http://{addr}/v1/rerank");
|
|
println!("LLM Playground: http://{addr}/playground");
|
|
println!("LLM Arena: http://{addr}/arena?num=2");
|
|
shutdown_signal().await;
|
|
let _ = stop_server.send(());
|
|
Ok(())
|
|
}
|
|
|
|
struct Server {
|
|
config: Config,
|
|
models: Vec<Value>,
|
|
roles: Vec<Role>,
|
|
rags: Vec<String>,
|
|
}
|
|
|
|
impl Server {
|
|
fn new(config: &GlobalConfig) -> Self {
|
|
let mut config = config.read().clone();
|
|
config.functions = Functions::default();
|
|
let mut models = list_all_models(&config);
|
|
let mut default_model = config.model.clone();
|
|
default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
|
|
models.insert(0, &default_model);
|
|
let models: Vec<Value> = models
|
|
.into_iter()
|
|
.enumerate()
|
|
.map(|(i, model)| {
|
|
let id = if i == 0 {
|
|
DEFAULT_MODEL_NAME.into()
|
|
} else {
|
|
model.id()
|
|
};
|
|
let mut value = json!(model.data());
|
|
if let Some(value_obj) = value.as_object_mut() {
|
|
value_obj.insert("id".into(), id.into());
|
|
value_obj.insert("object".into(), "model".into());
|
|
value_obj.insert("owned_by".into(), model.client_name().into());
|
|
value_obj.remove("name");
|
|
}
|
|
value
|
|
})
|
|
.collect();
|
|
Self {
|
|
config,
|
|
models,
|
|
roles: Config::all_roles(),
|
|
rags: Config::list_rags(),
|
|
}
|
|
}
|
|
|
|
async fn run(self: Arc<Self>, listener: TcpListener) -> Result<oneshot::Sender<()>> {
|
|
let (tx, rx) = oneshot::channel();
|
|
tokio::spawn(async move {
|
|
let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() });
|
|
let guard = shutdown.guard_weak();
|
|
|
|
loop {
|
|
tokio::select! {
|
|
res = listener.accept() => {
|
|
let Ok((cnx, _)) = res else {
|
|
continue;
|
|
};
|
|
|
|
let stream = TokioIo::new(cnx);
|
|
let server = self.clone();
|
|
shutdown.spawn_task(async move {
|
|
let hyper_service = service_fn(move |request: hyper::Request<Incoming>| {
|
|
server.clone().handle(request)
|
|
});
|
|
let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
|
|
.serve_connection_with_upgrades(stream, hyper_service)
|
|
.await;
|
|
});
|
|
}
|
|
_ = guard.cancelled() => {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
Ok(tx)
|
|
}
|
|
|
|
async fn handle(
|
|
self: Arc<Self>,
|
|
req: hyper::Request<Incoming>,
|
|
) -> std::result::Result<AppResponse, hyper::Error> {
|
|
let method = req.method().clone();
|
|
let uri = req.uri().clone();
|
|
let path = uri.path();
|
|
|
|
if method == Method::OPTIONS {
|
|
let mut res = Response::default();
|
|
*res.status_mut() = StatusCode::NO_CONTENT;
|
|
set_cors_header(&mut res);
|
|
return Ok(res);
|
|
}
|
|
|
|
let mut status = StatusCode::OK;
|
|
let res = if path == "/v1/chat/completions" {
|
|
self.chat_completions(req).await
|
|
} else if path == "/v1/embeddings" {
|
|
self.embeddings(req).await
|
|
} else if path == "/v1/rerank" {
|
|
self.rerank(req).await
|
|
} else if path == "/v1/models" {
|
|
self.list_models()
|
|
} else if path == "/v1/roles" {
|
|
self.list_roles()
|
|
} else if path == "/v1/rags" {
|
|
self.list_rags()
|
|
} else if path == "/v1/rags/search" {
|
|
self.search_rag(req).await
|
|
} else if path == "/playground" || path == "/playground.html" {
|
|
self.playground_page()
|
|
} else if path == "/arena" || path == "/arena.html" {
|
|
self.arena_page()
|
|
} else {
|
|
status = StatusCode::NOT_FOUND;
|
|
Err(anyhow!("Not Found"))
|
|
};
|
|
let mut res = match res {
|
|
Ok(res) => {
|
|
info!("{method} {uri} {}", status.as_u16());
|
|
res
|
|
}
|
|
Err(err) => {
|
|
if status == StatusCode::OK {
|
|
status = StatusCode::BAD_REQUEST;
|
|
}
|
|
error!("{method} {uri} {} {err}", status.as_u16());
|
|
ret_err(err)
|
|
}
|
|
};
|
|
*res.status_mut() = status;
|
|
set_cors_header(&mut res);
|
|
Ok(res)
|
|
}
|
|
|
|
fn playground_page(&self) -> Result<AppResponse> {
|
|
let res = Response::builder()
|
|
.header("Content-Type", "text/html; charset=utf-8")
|
|
.body(Full::new(Bytes::from(PLAYGROUND_HTML)).boxed())?;
|
|
Ok(res)
|
|
}
|
|
|
|
fn arena_page(&self) -> Result<AppResponse> {
|
|
let res = Response::builder()
|
|
.header("Content-Type", "text/html; charset=utf-8")
|
|
.body(Full::new(Bytes::from(ARENA_HTML)).boxed())?;
|
|
Ok(res)
|
|
}
|
|
|
|
fn list_models(&self) -> Result<AppResponse> {
|
|
let data = json!({ "data": self.models });
|
|
let res = Response::builder()
|
|
.header("Content-Type", "application/json; charset=utf-8")
|
|
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
Ok(res)
|
|
}
|
|
|
|
fn list_roles(&self) -> Result<AppResponse> {
|
|
let data = json!({ "data": self.roles });
|
|
let res = Response::builder()
|
|
.header("Content-Type", "application/json; charset=utf-8")
|
|
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
Ok(res)
|
|
}
|
|
|
|
fn list_rags(&self) -> Result<AppResponse> {
|
|
let data = json!({ "data": self.rags });
|
|
let res = Response::builder()
|
|
.header("Content-Type", "application/json; charset=utf-8")
|
|
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
Ok(res)
|
|
}
|
|
|
|
async fn search_rag(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
let req_body = req.collect().await?.to_bytes();
|
|
let req_body: Value = serde_json::from_slice(&req_body)
|
|
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
|
|
debug!("search rag request: {req_body}");
|
|
let SearchRagReqBody { name, input } = serde_json::from_value(req_body)
|
|
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
let config = Arc::new(RwLock::new(self.config.clone()));
|
|
|
|
let abort_signal = create_abort_signal();
|
|
|
|
let rag_path = config.read().rag_file(&name);
|
|
let rag = Rag::load(&config, &name, &rag_path)?;
|
|
|
|
let rag_result = Config::search_rag(&config, &rag, &input, abort_signal).await?;
|
|
|
|
let data = json!({ "data": rag_result });
|
|
let res = Response::builder()
|
|
.header("Content-Type", "application/json; charset=utf-8")
|
|
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
|
Ok(res)
|
|
}
|
|
|
|
async fn chat_completions(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
let req_body = req.collect().await?.to_bytes();
|
|
let req_body: Value = serde_json::from_slice(&req_body)
|
|
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
|
|
debug!("chat completions request: {req_body}");
|
|
let req_body = serde_json::from_value(req_body)
|
|
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
let ChatCompletionsReqBody {
|
|
model,
|
|
messages,
|
|
temperature,
|
|
top_p,
|
|
max_tokens,
|
|
stream,
|
|
tools,
|
|
} = req_body;
|
|
|
|
let mut messages =
|
|
parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
let config = self.config.clone();
|
|
|
|
let default_model = config.model.clone();
|
|
|
|
let config = Arc::new(RwLock::new(config));
|
|
|
|
let (model_name, change) = if model == DEFAULT_MODEL_NAME {
|
|
(default_model.id(), true)
|
|
} else if default_model.id() == model {
|
|
(model, false)
|
|
} else {
|
|
(model, true)
|
|
};
|
|
|
|
if change {
|
|
config.write().set_model(&model_name)?;
|
|
}
|
|
|
|
let mut client = init_client(&config, None)?;
|
|
if max_tokens.is_some() {
|
|
client.model_mut().set_max_tokens(max_tokens, true);
|
|
}
|
|
let abort_signal = create_abort_signal();
|
|
let http_client = client.build_client()?;
|
|
|
|
let completion_id = generate_completion_id();
|
|
let created = Utc::now().timestamp();
|
|
|
|
patch_messages(&mut messages, client.model());
|
|
|
|
let data: ChatCompletionsData = ChatCompletionsData {
|
|
messages,
|
|
temperature,
|
|
top_p,
|
|
functions,
|
|
stream,
|
|
};
|
|
|
|
if stream {
|
|
let (tx, mut rx) = unbounded_channel();
|
|
tokio::spawn(async move {
|
|
let is_first = Arc::new(AtomicBool::new(true));
|
|
let (sse_tx, sse_rx) = unbounded_channel();
|
|
let mut handler = SseHandler::new(sse_tx, abort_signal);
|
|
async fn map_event(
|
|
mut sse_rx: UnboundedReceiver<SseEvent>,
|
|
tx: &UnboundedSender<ResEvent>,
|
|
is_first: Arc<AtomicBool>,
|
|
) {
|
|
while let Some(reply_event) = sse_rx.recv().await {
|
|
if is_first.load(Ordering::SeqCst) {
|
|
let _ = tx.send(ResEvent::First(None));
|
|
is_first.store(false, Ordering::SeqCst)
|
|
}
|
|
match reply_event {
|
|
SseEvent::Text(text) => {
|
|
let _ = tx.send(ResEvent::Text(text));
|
|
}
|
|
SseEvent::Done => {
|
|
let _ = tx.send(ResEvent::Done);
|
|
sse_rx.close();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
async fn chat_completions(
|
|
client: &dyn Client,
|
|
http_client: &reqwest::Client,
|
|
handler: &mut SseHandler,
|
|
mut data: ChatCompletionsData,
|
|
tx: &UnboundedSender<ResEvent>,
|
|
is_first: Arc<AtomicBool>,
|
|
) {
|
|
if client.model().no_stream() {
|
|
data.stream = false;
|
|
let ret = client.chat_completions_inner(http_client, data).await;
|
|
match ret {
|
|
Ok(output) => {
|
|
let ChatCompletionsOutput {
|
|
text, tool_calls, ..
|
|
} = output;
|
|
let _ = tx.send(ResEvent::First(None));
|
|
is_first.store(false, Ordering::SeqCst);
|
|
let _ = tx.send(ResEvent::Text(text));
|
|
if !tool_calls.is_empty() {
|
|
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
|
|
}
|
|
}
|
|
Err(err) => {
|
|
let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
|
|
is_first.store(false, Ordering::SeqCst)
|
|
}
|
|
};
|
|
} else {
|
|
let ret = client
|
|
.chat_completions_streaming_inner(http_client, handler, data)
|
|
.await;
|
|
let first = match ret {
|
|
Ok(()) => None,
|
|
Err(err) => Some(format!("{err:?}")),
|
|
};
|
|
if is_first.load(Ordering::SeqCst) {
|
|
let _ = tx.send(ResEvent::First(first));
|
|
is_first.store(false, Ordering::SeqCst)
|
|
}
|
|
let tool_calls = handler.tool_calls().to_vec();
|
|
if !tool_calls.is_empty() {
|
|
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
|
|
}
|
|
}
|
|
handler.done();
|
|
}
|
|
tokio::join!(
|
|
map_event(sse_rx, &tx, is_first.clone()),
|
|
chat_completions(
|
|
client.as_ref(),
|
|
&http_client,
|
|
&mut handler,
|
|
data,
|
|
&tx,
|
|
is_first
|
|
),
|
|
);
|
|
});
|
|
|
|
let first_event = rx.recv().await;
|
|
|
|
if let Some(ResEvent::First(Some(err))) = first_event {
|
|
bail!("{err}");
|
|
}
|
|
|
|
let shared: Arc<(String, String, i64, AtomicBool)> =
|
|
Arc::new((completion_id, model_name, created, AtomicBool::new(false)));
|
|
let stream = UnboundedReceiverStream::new(rx);
|
|
let stream = stream.filter_map(move |res_event| {
|
|
let shared = shared.clone();
|
|
async move {
|
|
let (completion_id, model, created, has_tool_calls) = shared.as_ref();
|
|
match res_event {
|
|
ResEvent::Text(text) => {
|
|
Some(Ok(create_text_frame(completion_id, model, *created, &text)))
|
|
}
|
|
ResEvent::ToolCalls(tool_calls) => {
|
|
has_tool_calls.store(true, Ordering::SeqCst);
|
|
Some(Ok(create_tool_calls_frame(
|
|
completion_id,
|
|
model,
|
|
*created,
|
|
&tool_calls,
|
|
)))
|
|
}
|
|
ResEvent::Done => Some(Ok(create_done_frame(
|
|
completion_id,
|
|
model,
|
|
*created,
|
|
has_tool_calls.load(Ordering::SeqCst),
|
|
))),
|
|
_ => None,
|
|
}
|
|
}
|
|
});
|
|
let res = Response::builder()
|
|
.status(StatusCode::OK)
|
|
.header("Content-Type", "text/event-stream")
|
|
.header("Cache-Control", "no-cache")
|
|
.header("Connection", "keep-alive")
|
|
.body(BodyExt::boxed(StreamBody::new(stream)))?;
|
|
Ok(res)
|
|
} else {
|
|
let output = client.chat_completions_inner(&http_client, data).await?;
|
|
let res = Response::builder()
|
|
.header("Content-Type", "application/json")
|
|
.body(
|
|
Full::new(ret_non_stream(
|
|
&completion_id,
|
|
&model_name,
|
|
created,
|
|
&output,
|
|
))
|
|
.boxed(),
|
|
)?;
|
|
Ok(res)
|
|
}
|
|
}
|
|
|
|
async fn embeddings(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
let req_body = req.collect().await?.to_bytes();
|
|
let req_body: Value = serde_json::from_slice(&req_body)
|
|
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
|
|
debug!("embeddings request: {req_body}");
|
|
let req_body = serde_json::from_value(req_body)
|
|
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
let EmbeddingsReqBody {
|
|
input,
|
|
model: embedding_model_id,
|
|
} = req_body;
|
|
|
|
let config = Arc::new(RwLock::new(self.config.clone()));
|
|
|
|
let embedding_model =
|
|
Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;
|
|
|
|
let texts = match input {
|
|
EmbeddingsReqBodyInput::Single(v) => vec![v],
|
|
EmbeddingsReqBodyInput::Multiple(v) => v,
|
|
};
|
|
let client = init_client(&config, Some(embedding_model))?;
|
|
let data = client
|
|
.embeddings(&EmbeddingsData {
|
|
query: false,
|
|
texts,
|
|
})
|
|
.await?;
|
|
let data: Vec<_> = data
|
|
.into_iter()
|
|
.enumerate()
|
|
.map(|(i, v)| {
|
|
json!({
|
|
"object": "embedding",
|
|
"embedding": v,
|
|
"index": i,
|
|
})
|
|
})
|
|
.collect();
|
|
let output = json!({
|
|
"object": "list",
|
|
"data": data,
|
|
"model": embedding_model_id,
|
|
"usage": {
|
|
"prompt_tokens": 0,
|
|
"total_tokens": 0,
|
|
}
|
|
});
|
|
let res = Response::builder()
|
|
.header("Content-Type", "application/json")
|
|
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
|
|
Ok(res)
|
|
}
|
|
|
|
async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
let req_body = req.collect().await?.to_bytes();
|
|
let req_body: Value = serde_json::from_slice(&req_body)
|
|
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
|
|
debug!("rerank request: {req_body}");
|
|
let req_body = serde_json::from_value(req_body)
|
|
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
let RerankReqBody {
|
|
model: reranker_model_id,
|
|
documents,
|
|
query,
|
|
top_n,
|
|
} = req_body;
|
|
|
|
let top_n = top_n.unwrap_or(documents.len());
|
|
|
|
let config = Arc::new(RwLock::new(self.config.clone()));
|
|
|
|
let reranker_model =
|
|
Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?;
|
|
|
|
let client = init_client(&config, Some(reranker_model))?;
|
|
let data = client
|
|
.rerank(&RerankData {
|
|
query,
|
|
documents: documents.clone(),
|
|
top_n,
|
|
})
|
|
.await?;
|
|
|
|
let results: Vec<_> = data
|
|
.into_iter()
|
|
.map(|v| {
|
|
json!({
|
|
"index": v.index,
|
|
"relevance_score": v.relevance_score,
|
|
"document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
|
|
})
|
|
})
|
|
.collect();
|
|
let output = json!({
|
|
"id": uuid::Uuid::new_v4().to_string(),
|
|
"results": results,
|
|
});
|
|
let res = Response::builder()
|
|
.header("Content-Type", "application/json")
|
|
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
|
|
Ok(res)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct SearchRagReqBody {
|
|
name: String,
|
|
input: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatCompletionsReqBody {
|
|
model: String,
|
|
messages: Vec<Value>,
|
|
temperature: Option<f64>,
|
|
top_p: Option<f64>,
|
|
max_tokens: Option<isize>,
|
|
#[serde(default)]
|
|
stream: bool,
|
|
tools: Option<Vec<Value>>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct EmbeddingsReqBody {
|
|
input: EmbeddingsReqBodyInput,
|
|
model: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[serde(untagged)]
|
|
enum EmbeddingsReqBodyInput {
|
|
Single(String),
|
|
Multiple(Vec<String>),
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct RerankReqBody {
|
|
documents: Vec<String>,
|
|
query: String,
|
|
model: String,
|
|
top_n: Option<usize>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
enum ResEvent {
|
|
First(Option<String>),
|
|
Text(String),
|
|
ToolCalls(Vec<ToolCall>),
|
|
Done,
|
|
}
|
|
|
|
async fn shutdown_signal() {
|
|
tokio::signal::ctrl_c()
|
|
.await
|
|
.expect("Failed to install CTRL+C signal handler")
|
|
}
|
|
|
|
fn generate_completion_id() -> String {
|
|
let random_id = Utc::now().nanosecond();
|
|
format!("chatcmpl-{random_id}")
|
|
}
|
|
|
|
fn set_cors_header(res: &mut AppResponse) {
|
|
res.headers_mut().insert(
|
|
hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
|
hyper::header::HeaderValue::from_static("*"),
|
|
);
|
|
res.headers_mut().insert(
|
|
hyper::header::ACCESS_CONTROL_ALLOW_METHODS,
|
|
hyper::header::HeaderValue::from_static("GET,POST,PUT,PATCH,DELETE"),
|
|
);
|
|
res.headers_mut().insert(
|
|
hyper::header::ACCESS_CONTROL_ALLOW_HEADERS,
|
|
hyper::header::HeaderValue::from_static("Content-Type,Authorization"),
|
|
);
|
|
}
|
|
|
|
fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame<Bytes> {
|
|
let delta = if content.is_empty() {
|
|
json!({ "role": "assistant", "content": content })
|
|
} else {
|
|
json!({ "content": content })
|
|
};
|
|
let choice = json!({
|
|
"index": 0,
|
|
"delta": delta,
|
|
"finish_reason": null,
|
|
});
|
|
let value = build_chat_completion_chunk_json(id, model, created, &choice);
|
|
Frame::data(Bytes::from(format!("data: {value}\n\n")))
|
|
}
|
|
|
|
fn create_tool_calls_frame(
|
|
id: &str,
|
|
model: &str,
|
|
created: i64,
|
|
tool_calls: &[ToolCall],
|
|
) -> Frame<Bytes> {
|
|
let chunks = tool_calls
|
|
.iter()
|
|
.enumerate()
|
|
.flat_map(|(i, call)| {
|
|
let choice1 = json!({
|
|
"index": 0,
|
|
"delta": {
|
|
"role": "assistant",
|
|
"content": null,
|
|
"tool_calls": [
|
|
{
|
|
"index": i,
|
|
"id": call.id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": call.name,
|
|
"arguments": ""
|
|
}
|
|
}
|
|
]
|
|
},
|
|
"finish_reason": null
|
|
});
|
|
let choice2 = json!({
|
|
"index": 0,
|
|
"delta": {
|
|
"tool_calls": [
|
|
{
|
|
"index": i,
|
|
"function": {
|
|
"arguments": call.arguments.to_string(),
|
|
}
|
|
}
|
|
]
|
|
},
|
|
"finish_reason": null
|
|
});
|
|
vec![
|
|
build_chat_completion_chunk_json(id, model, created, &choice1),
|
|
build_chat_completion_chunk_json(id, model, created, &choice2),
|
|
]
|
|
})
|
|
.map(|v| format!("data: {v}\n\n"))
|
|
.collect::<Vec<String>>()
|
|
.join("");
|
|
Frame::data(Bytes::from(chunks))
|
|
}
|
|
|
|
fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame<Bytes> {
|
|
let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" };
|
|
let choice = json!({
|
|
"index": 0,
|
|
"delta": {},
|
|
"finish_reason": finish_reason,
|
|
});
|
|
let value = build_chat_completion_chunk_json(id, model, created, &choice);
|
|
Frame::data(Bytes::from(format!("data: {value}\n\ndata: [DONE]\n\n")))
|
|
}
|
|
|
|
fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value {
|
|
json!({
|
|
"id": id,
|
|
"object": "chat.completion.chunk",
|
|
"created": created,
|
|
"model": model,
|
|
"choices": [choice],
|
|
})
|
|
}
|
|
|
|
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
|
|
let id = output.id.as_deref().unwrap_or(id);
|
|
let input_tokens = output.input_tokens.unwrap_or_default();
|
|
let output_tokens = output.output_tokens.unwrap_or_default();
|
|
let total_tokens = input_tokens + output_tokens;
|
|
let choice = if output.tool_calls.is_empty() {
|
|
json!({
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": output.text,
|
|
},
|
|
"logprobs": null,
|
|
"finish_reason": "stop",
|
|
})
|
|
} else {
|
|
let content = if output.text.is_empty() {
|
|
Value::Null
|
|
} else {
|
|
output.text.clone().into()
|
|
};
|
|
let tool_calls: Vec<_> = output
|
|
.tool_calls
|
|
.iter()
|
|
.map(|call| {
|
|
json!({
|
|
"id": call.id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": call.name,
|
|
"arguments": call.arguments.to_string(),
|
|
}
|
|
})
|
|
})
|
|
.collect();
|
|
json!({
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": content,
|
|
"tool_calls": tool_calls,
|
|
},
|
|
"logprobs": null,
|
|
"finish_reason": "tool_calls",
|
|
})
|
|
};
|
|
let res_body = json!({
|
|
"id": id,
|
|
"object": "chat.completion",
|
|
"created": created,
|
|
"model": model,
|
|
"choices": [choice],
|
|
"usage": {
|
|
"prompt_tokens": input_tokens,
|
|
"completion_tokens": output_tokens,
|
|
"total_tokens": total_tokens,
|
|
},
|
|
});
|
|
Bytes::from(res_body.to_string())
|
|
}
|
|
|
|
fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
|
|
let data = json!({
|
|
"error": {
|
|
"message": err.to_string(),
|
|
"type": "invalid_request_error",
|
|
},
|
|
});
|
|
Response::builder()
|
|
.header("Content-Type", "application/json")
|
|
.body(Full::new(Bytes::from(data.to_string())).boxed())
|
|
.unwrap()
|
|
}
|
|
|
|
fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
|
|
let mut output = vec![];
|
|
let mut tool_results = None;
|
|
for (i, message) in message.into_iter().enumerate() {
|
|
let err = || anyhow!("Failed to parse '.messages[{i}]'");
|
|
let role = message["role"].as_str().ok_or_else(err)?;
|
|
let content = match message.get("content") {
|
|
Some(value) => {
|
|
if let Some(value) = value.as_str() {
|
|
MessageContent::Text(value.to_string())
|
|
} else if value.is_array() {
|
|
let value = serde_json::from_value(value.clone()).map_err(|_| err())?;
|
|
MessageContent::Array(value)
|
|
} else if value.is_null() {
|
|
MessageContent::Text(String::new())
|
|
} else {
|
|
return Err(err());
|
|
}
|
|
}
|
|
None => MessageContent::Text(String::new()),
|
|
};
|
|
match role {
|
|
"system" | "user" => {
|
|
let role = match role {
|
|
"system" => MessageRole::System,
|
|
"user" => MessageRole::User,
|
|
_ => unreachable!(),
|
|
};
|
|
output.push(Message::new(role, content))
|
|
}
|
|
"assistant" => {
|
|
let role = MessageRole::Assistant;
|
|
match message["tool_calls"].as_array() {
|
|
Some(tool_calls) => {
|
|
if tool_results.is_some() {
|
|
return Err(err());
|
|
}
|
|
let mut list = vec![];
|
|
for tool_call in tool_calls {
|
|
if let (id, Some(name), Some(arguments)) = (
|
|
tool_call["id"].as_str().map(|v| v.to_string()),
|
|
tool_call["function"]["name"].as_str(),
|
|
tool_call["function"]["arguments"].as_str(),
|
|
) {
|
|
let arguments =
|
|
serde_json::from_str(arguments).map_err(|_| err())?;
|
|
list.push((id, name.to_string(), arguments));
|
|
} else {
|
|
return Err(err());
|
|
}
|
|
}
|
|
tool_results = Some((content.to_text(), list, vec![]));
|
|
}
|
|
None => output.push(Message::new(role, content)),
|
|
}
|
|
}
|
|
"tool" => match tool_results.take() {
|
|
Some((text, tool_calls, mut tool_values)) => {
|
|
let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string());
|
|
let content = content.to_text();
|
|
let value: Value = serde_json::from_str(&content)
|
|
.ok()
|
|
.unwrap_or_else(|| content.into());
|
|
|
|
tool_values.push((value, tool_call_id));
|
|
|
|
if tool_calls.len() == tool_values.len() {
|
|
let mut list = vec![];
|
|
for ((id, name, arguments), (value, tool_call_id)) in
|
|
tool_calls.into_iter().zip(tool_values.into_iter())
|
|
{
|
|
if id != tool_call_id {
|
|
return Err(err());
|
|
}
|
|
list.push(ToolResult::new(ToolCall::new(name, arguments, id), value))
|
|
}
|
|
output.push(Message::new(
|
|
MessageRole::Assistant,
|
|
MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)),
|
|
));
|
|
tool_results = None;
|
|
} else {
|
|
tool_results = Some((text, tool_calls, tool_values));
|
|
}
|
|
}
|
|
None => return Err(err()),
|
|
},
|
|
_ => {
|
|
return Err(err());
|
|
}
|
|
}
|
|
}
|
|
|
|
if tool_results.is_some() {
|
|
bail!("Invalid messages");
|
|
}
|
|
|
|
Ok(output)
|
|
}
|
|
|
|
fn parse_tools(tools: Option<Vec<Value>>) -> Result<Option<Vec<FunctionDeclaration>>> {
|
|
let tools = match tools {
|
|
Some(v) => v,
|
|
None => return Ok(None),
|
|
};
|
|
let mut functions = vec![];
|
|
for (i, tool) in tools.into_iter().enumerate() {
|
|
if let (Some("function"), Some(function)) = (
|
|
tool["type"].as_str(),
|
|
tool["function"]
|
|
.as_object()
|
|
.and_then(|v| serde_json::from_value(json!(v)).ok()),
|
|
) {
|
|
functions.push(function);
|
|
} else {
|
|
bail!("Failed to parse '.tools[{i}]'")
|
|
}
|
|
}
|
|
Ok(Some(functions))
|
|
}
|