Baseline project
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Utc;
|
||||
use indexmap::IndexMap;
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Process-wide cache of access tokens, keyed by client name.
/// Each value is `(token, expires_at)` where `expires_at` is compared
/// against `Utc::now().timestamp()`, i.e. unix seconds — TODO confirm
/// the unit against the code that supplies `expires_at`.
static ACCESS_TOKENS: LazyLock<RwLock<IndexMap<String, (String, i64)>>> =
    LazyLock::new(|| RwLock::new(IndexMap::new()));
|
||||
|
||||
pub fn get_access_token(client_name: &str) -> Result<String> {
|
||||
ACCESS_TOKENS
|
||||
.read()
|
||||
.get(client_name)
|
||||
.map(|(token, _)| token.clone())
|
||||
.ok_or_else(|| anyhow!("Invalid access token"))
|
||||
}
|
||||
|
||||
pub fn is_valid_access_token(client_name: &str) -> bool {
|
||||
let access_tokens = ACCESS_TOKENS.read();
|
||||
let (token, expires_at) = match access_tokens.get(client_name) {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
!token.is_empty() && Utc::now().timestamp() < *expires_at
|
||||
}
|
||||
|
||||
pub fn set_access_token(client_name: &str, token: String, expires_at: i64) {
|
||||
let mut access_tokens = ACCESS_TOKENS.write();
|
||||
let entry = access_tokens.entry(client_name.to_string()).or_default();
|
||||
entry.0 = token;
|
||||
entry.1 = expires_at;
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::openai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::Deserialize;
|
||||
|
||||
/// User-supplied configuration for the Azure OpenAI client,
/// deserialized from the application config file.
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
    /// Optional display name for this client instance.
    pub name: Option<String>,
    /// Resource endpoint, e.g. `https://{RESOURCE}.openai.azure.com`.
    pub api_base: Option<String>,
    /// API key sent via the `api-key` header.
    pub api_key: Option<String>,
    /// Models exposed by this client; defaults to empty.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patching rules applied before sending.
    pub patch: Option<RequestPatch>,
    /// Extra, client-agnostic configuration.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl AzureOpenAIClient {
    // Generated getters that resolve each value from config (or
    // environment — presumably; confirm against `config_get_fn!`).
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);

    /// Interactive setup prompts: (config key, label, optional hint).
    pub const PROMPTS: [PromptAction<'static>; 2] = [
        (
            "api_base",
            "API Base",
            Some("e.g. https://{RESOURCE}.openai.azure.com"),
        ),
        ("api_key", "API Key", None),
    ];
}
|
||||
|
||||
// Wire AzureOpenAIClient into the shared client trait: requests are
// prepared locally, then executed by the generic OpenAI helpers.
// Rerank is not supported (noop handlers).
impl_client_trait!(
    AzureOpenAIClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &AzureOpenAIClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_base = self_.get_api_base()?;
|
||||
let api_key = self_.get_api_key()?;
|
||||
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version=2024-12-01-preview",
|
||||
&api_base,
|
||||
self_.model.real_name()
|
||||
);
|
||||
|
||||
let body = openai_build_chat_completions_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(self_: &AzureOpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_base = self_.get_api_base()?;
|
||||
let api_key = self_.get_api_key()?;
|
||||
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/embeddings?api-version=2024-10-21",
|
||||
&api_base,
|
||||
self_.model.real_name()
|
||||
);
|
||||
|
||||
let body = openai_build_embeddings_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
@@ -0,0 +1,643 @@
|
||||
use super::*;
|
||||
|
||||
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256, strip_think_tag};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
||||
use aws_smithy_eventstream::smithy::parse_response_headers;
|
||||
use bytes::BytesMut;
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures_util::StreamExt;
|
||||
use indexmap::IndexMap;
|
||||
use reqwest::{Client as ReqwestClient, Method, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// User-supplied configuration for the AWS Bedrock client,
/// deserialized from the application config file.
#[derive(Debug, Clone, Deserialize)]
pub struct BedrockConfig {
    /// Optional display name for this client instance.
    pub name: Option<String>,
    /// AWS access key id for SigV4 signing.
    pub access_key_id: Option<String>,
    /// AWS secret access key for SigV4 signing.
    pub secret_access_key: Option<String>,
    /// AWS region, used to build the `bedrock-runtime` host.
    pub region: Option<String>,
    /// Optional STS session token (`x-amz-security-token`).
    pub session_token: Option<String>,
    /// Models exposed by this client; defaults to empty.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patching rules applied before signing.
    pub patch: Option<RequestPatch>,
    /// Extra, client-agnostic configuration.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl BedrockClient {
    // Generated getters resolving each credential from configuration.
    config_get_fn!(access_key_id, get_access_key_id);
    config_get_fn!(secret_access_key, get_secret_access_key);
    config_get_fn!(region, get_region);
    config_get_fn!(session_token, get_session_token);

    /// Interactive setup prompts: (config key, label, optional hint).
    /// The session token is optional and therefore not prompted for.
    pub const PROMPTS: [PromptAction<'static>; 3] = [
        ("access_key_id", "AWS Access Key ID", None),
        ("secret_access_key", "AWS Secret Access Key", None),
        ("region", "AWS Region", None),
    ];

    /// Build a signed request for the Bedrock Converse API
    /// (`/converse` or `/converse-stream` depending on `data.stream`).
    fn chat_completions_builder(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<RequestBuilder> {
        let access_key_id = self.get_access_key_id()?;
        let secret_access_key = self.get_secret_access_key()?;
        let region = self.get_region()?;
        // Session token is optional: a failed lookup just means "none".
        let session_token = self.get_session_token().ok();
        let host = format!("bedrock-runtime.{region}.amazonaws.com");

        let model_name = &self.model.real_name();

        let uri = if data.stream {
            format!("/model/{model_name}/converse-stream")
        } else {
            format!("/model/{model_name}/converse")
        };

        let body = build_chat_completions_body(data, &self.model)?;

        // The URL is assembled later by aws_fetch from host + uri, so the
        // RequestData url is intentionally empty; patching may still edit
        // headers/body before signing.
        let mut request_data = RequestData::new("", body);
        self.patch_request_data(&mut request_data);
        let RequestData {
            url: _,
            headers,
            body,
        } = request_data;

        let builder = aws_fetch(
            client,
            &AwsCredentials {
                access_key_id,
                secret_access_key,
                region,
                session_token,
            },
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers,
                body: body.to_string(),
            },
        )?;

        Ok(builder)
    }

    /// Build a signed request for the Bedrock `invoke` endpoint carrying
    /// an embeddings payload (Cohere-style `texts` / `input_type` body).
    fn embeddings_builder(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<RequestBuilder> {
        let access_key_id = self.get_access_key_id()?;
        let secret_access_key = self.get_secret_access_key()?;
        let region = self.get_region()?;
        let session_token = self.get_session_token().ok();
        let host = format!("bedrock-runtime.{region}.amazonaws.com");

        let uri = format!("/model/{}/invoke", self.model.real_name());

        // Embedding models distinguish query vs. document embeddings.
        let input_type = match data.query {
            true => "search_query",
            false => "search_document",
        };

        let body = json!({
            "texts": data.texts,
            "input_type": input_type,
        });

        // Same pattern as chat: empty url, patch, then sign via aws_fetch.
        let mut request_data = RequestData::new("", body);
        self.patch_request_data(&mut request_data);
        let RequestData {
            url: _,
            headers,
            body,
        } = request_data;

        let builder = aws_fetch(
            client,
            &AwsCredentials {
                access_key_id,
                secret_access_key,
                region,
                session_token,
            },
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers,
                body: body.to_string(),
            },
        )?;

        Ok(builder)
    }
}
|
||||
|
||||
// Bedrock implements the client trait by hand (rather than via
// impl_client_trait!) because its builders must sign requests with
// AWS SigV4 instead of attaching a simple auth header.
#[async_trait::async_trait]
impl Client for BedrockClient {
    client_common_fns!();

    /// Non-streaming chat completion: build a signed request, send it,
    /// and parse the Converse response.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        let builder = self.chat_completions_builder(client, data)?;
        chat_completions(builder).await
    }

    /// Streaming chat completion: events are decoded from the AWS
    /// event-stream framing and forwarded to `handler`.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        let builder = self.chat_completions_builder(client, data)?;
        chat_completions_streaming(builder, handler).await
    }

    /// Embeddings via the `invoke` endpoint.
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<EmbeddingsOutput> {
        let builder = self.embeddings_builder(client, data)?;
        embeddings(builder).await
    }
}
|
||||
|
||||
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
debug!("non-stream-data: {data}");
|
||||
extract_chat_completions(&data)
|
||||
}
|
||||
|
||||
/// Consume a Converse streaming response.
///
/// The HTTP body is AWS event-stream framed: bytes are accumulated in
/// `buffer` and decoded frame-by-frame with `MessageFrameDecoder`.
/// Text and reasoning deltas are forwarded to `handler` as they arrive;
/// tool-call arguments are accumulated across deltas and emitted as a
/// single `ToolCall` when the content block ends (or when a new tool
/// block starts before the previous one was flushed).
///
/// `reasoning_state`: 0 = outside a reasoning span, 1 = inside one
/// (a synthetic `<think>…</think>` wrapper is emitted around it).
async fn chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if !status.is_success() {
        let data: Value = res.json().await?;
        // catch_error returns Err for recognized shapes; bail otherwise.
        catch_error(&data, status.as_u16())?;
        bail!("Invalid response data: {data}");
    }

    // Pending tool-call state, carried across frames.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    let mut reasoning_state = 0;

    let mut stream = res.bytes_stream();
    let mut buffer = BytesMut::new();
    let mut decoder = MessageFrameDecoder::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        buffer.extend_from_slice(&chunk);
        // A single network chunk may contain several complete frames.
        while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? {
            let response_headers = parse_response_headers(&message)?;
            let message_type = response_headers.message_type.as_str();
            let smithy_type = response_headers.smithy_type.as_str();
            match (message_type, smithy_type) {
                ("event", _) => {
                    let data: Value = serde_json::from_slice(message.payload())?;
                    debug!("stream-data: {smithy_type} {data}");
                    match smithy_type {
                        "contentBlockStart" => {
                            if let Some(tool_use) = data["start"]["toolUse"].as_object() {
                                if let (Some(id), Some(name)) = (
                                    json_str_from_map(tool_use, "toolUseId"),
                                    json_str_from_map(tool_use, "name"),
                                ) {
                                    // Flush any previous tool call that was
                                    // never closed by a contentBlockStop.
                                    if !function_name.is_empty() {
                                        if function_arguments.is_empty() {
                                            // No argument deltas => empty object.
                                            function_arguments = String::from("{}");
                                        }
                                        let arguments: Value =
                                            function_arguments.parse().with_context(|| {
                                                format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                            })?;
                                        handler.tool_call(ToolCall::new(
                                            function_name.clone(),
                                            arguments,
                                            Some(function_id.clone()),
                                        ))?;
                                    }
                                    function_arguments.clear();
                                    function_name = name.into();
                                    function_id = id.into();
                                }
                            }
                        }
                        "contentBlockDelta" => {
                            if let Some(text) = data["delta"]["text"].as_str() {
                                handler.text(text)?;
                            } else if let Some(text) =
                                data["delta"]["reasoningContent"]["text"].as_str()
                            {
                                // Open the synthetic think-tag on first delta.
                                if reasoning_state == 0 {
                                    handler.text("<think>\n")?;
                                    reasoning_state = 1;
                                }
                                handler.text(text)?;
                            } else if let Some(input) = data["delta"]["toolUse"]["input"].as_str() {
                                // Arguments stream as raw JSON fragments.
                                function_arguments.push_str(input);
                            }
                        }
                        "contentBlockStop" => {
                            if reasoning_state == 1 {
                                handler.text("\n</think>\n\n")?;
                                reasoning_state = 0;
                            }
                            if !function_name.is_empty() {
                                if function_arguments.is_empty() {
                                    function_arguments = String::from("{}");
                                }
                                let arguments: Value = function_arguments.parse().with_context(|| {
                                    format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                })?;
                                // NOTE(review): function_name is not cleared
                                // after emitting, so a later contentBlockStop
                                // could re-emit this call with "{}" args if a
                                // non-tool block follows — verify against the
                                // Converse stream event ordering.
                                handler.tool_call(ToolCall::new(
                                    function_name.clone(),
                                    arguments,
                                    Some(function_id.clone()),
                                ))?;
                            }
                        }
                        _ => {}
                    }
                }
                ("exception", _) => {
                    // Exception payloads arrive base64-encoded.
                    let payload = base64_decode(message.payload())?;
                    let data = String::from_utf8_lossy(&payload);

                    bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
                }
                _ => {
                    bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
                }
            }
        }
    }
    Ok(())
}
|
||||
|
||||
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
let res_body: EmbeddingsResBody =
|
||||
serde_json::from_value(data).context("Invalid embeddings data")?;
|
||||
Ok(res_body.embeddings)
|
||||
}
|
||||
|
||||
/// Minimal shape of the Bedrock embeddings response; only the
/// `embeddings` field is consumed.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    embeddings: Vec<Vec<f32>>,
}
|
||||
|
||||
/// Translate generic chat data into a Bedrock Converse request body.
///
/// - The system message is lifted out of `messages` into `system`.
/// - Intermediate assistant turns have their `<think>` spans stripped.
/// - Tool-call turns expand into an assistant `toolUse` message followed
///   by a user `toolResult` message, as the Converse schema requires.
/// - Only data-URL images are supported; network URLs are collected and
///   rejected with an error after the walk.
///
/// # Errors
/// Fails when any message references a network image URL.
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream: _,
    } = data;

    let system_message = extract_system_message(&mut messages);

    let mut network_image_urls = vec![];

    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                // Historical assistant turns: drop reasoning markup so it
                // is not replayed to the model.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": [ { "text": strip_think_tag(&text) } ] })]
                }
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "content": [
                        {
                            "text": text,
                        }
                    ],
                })],
                MessageContent::Array(list) => {
                    let content: Vec<_> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => {
                                json!({"text": text})
                            }
                            MessageContentPart::ImageUrl {
                                image_url: ImageUrl { url },
                            } => {
                                // Inline data URL: split into mime type and
                                // base64 payload for the Converse image part.
                                if let Some((mime_type, data)) = url
                                    .strip_prefix("data:")
                                    .and_then(|v| v.split_once(";base64,"))
                                {
                                    json!({
                                        "image": {
                                            // Converse wants the bare format
                                            // name ("png"), not "image/png".
                                            "format": mime_type.replace("image/", ""),
                                            "source": {
                                                "bytes": data,
                                            }
                                        }
                                    })
                                } else {
                                    // Recorded for the post-walk bail; the
                                    // placeholder value is never sent.
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            }
                        })
                        .collect();
                    vec![json!({
                        "role": role,
                        "content": content,
                    })]
                }
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    // One assistant message carrying the toolUse parts, then
                    // one user message carrying the matching toolResults.
                    let mut assistant_parts = vec![];
                    let mut user_parts = vec![];
                    if !text.is_empty() {
                        assistant_parts.push(json!({
                            "text": text,
                        }))
                    }
                    for tool_result in tool_results {
                        assistant_parts.push(json!({
                            "toolUse": {
                                "toolUseId": tool_result.call.id,
                                "name": tool_result.call.name,
                                "input": tool_result.call.arguments,
                            }
                        }));
                        user_parts.push(json!({
                            "toolResult": {
                                "toolUseId": tool_result.call.id,
                                "content": [
                                    {
                                        "json": tool_result.output,
                                    }
                                ]
                            }
                        }));
                    }
                    vec![
                        json!({
                            "role": "assistant",
                            "content": assistant_parts,
                        }),
                        json!({
                            "role": "user",
                            "content": user_parts,
                        }),
                    ]
                }
            }
        })
        .collect();

    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }

    let mut body = json!({
        "inferenceConfig": {},
        "messages": messages,
    });
    if let Some(v) = system_message {
        body["system"] = json!([
            {
                "text": v,
            }
        ])
    }

    // Optional knobs all live under inferenceConfig.
    if let Some(v) = model.max_tokens_param() {
        body["inferenceConfig"]["maxTokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["inferenceConfig"]["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["inferenceConfig"]["topP"] = v.into();
    }
    if let Some(functions) = functions {
        // Function declarations map to Converse toolSpec entries.
        let tools: Vec<_> = functions
            .iter()
            .map(|v| {
                json!({
                    "toolSpec": {
                        "name": v.name,
                        "description": v.description,
                        "inputSchema": {
                            "json": v.parameters,
                        },
                    }
                })
            })
            .collect();
        body["toolConfig"] = json!({
            "tools": tools,
        })
    }
    Ok(body)
}
|
||||
|
||||
/// Parse a non-streaming Converse response into `ChatCompletionsOutput`.
///
/// Text parts are concatenated with blank lines; a reasoning part (if
/// present) is prepended wrapped in `<think>…</think>`; toolUse parts
/// become `ToolCall`s. Fails when the response yields neither text nor
/// tool calls.
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
    let mut text = String::new();
    let mut reasoning = None;
    let mut tool_calls = vec![];
    if let Some(array) = data["output"]["message"]["content"].as_array() {
        for item in array {
            if let Some(v) = item["text"].as_str() {
                if !text.is_empty() {
                    // Separate multiple text parts with a blank line.
                    text.push_str("\n\n");
                }
                text.push_str(v);
            } else if let Some(reasoning_text) =
                item["reasoningContent"]["reasoningText"].as_object()
            {
                // Only the last reasoning part survives (later parts
                // overwrite `reasoning`).
                if let Some(text) = json_str_from_map(reasoning_text, "text") {
                    reasoning = Some(text.to_string());
                }
            } else if let Some(tool_use) = item["toolUse"].as_object() {
                // A tool call is only accepted when all three fields exist.
                if let (Some(id), Some(name), Some(input)) = (
                    json_str_from_map(tool_use, "toolUseId"),
                    json_str_from_map(tool_use, "name"),
                    tool_use.get("input"),
                ) {
                    tool_calls.push(ToolCall::new(
                        name.to_string(),
                        input.clone(),
                        Some(id.to_string()),
                    ))
                }
            }
        }
    }

    if let Some(reasoning) = reasoning {
        text = format!("<think>\n{reasoning}\n</think>\n\n{text}")
    }

    if text.is_empty() && tool_calls.is_empty() {
        bail!("Invalid response data: {data}");
    }

    let output = ChatCompletionsOutput {
        text,
        tool_calls,
        id: None,
        input_tokens: data["usage"]["inputTokens"].as_u64(),
        output_tokens: data["usage"]["outputTokens"].as_u64(),
    };
    Ok(output)
}
|
||||
|
||||
/// Credentials used for AWS SigV4 request signing.
#[derive(Debug)]
struct AwsCredentials {
    access_key_id: String,
    secret_access_key: String,
    /// Region included in the credential scope.
    region: String,
    /// Optional STS session token (`x-amz-security-token` header).
    session_token: Option<String>,
}
|
||||
|
||||
/// One request to be signed and sent by `aws_fetch`.
#[derive(Debug)]
struct AwsRequest {
    method: Method,
    /// Target host, e.g. `bedrock-runtime.{region}.amazonaws.com`.
    host: String,
    /// Service name for the credential scope, e.g. `bedrock`.
    service: String,
    /// Request path starting with `/`.
    uri: String,
    /// Canonical query string (may be empty).
    querystring: String,
    /// Extra headers; they participate in the signature.
    headers: IndexMap<String, String>,
    /// Raw request body; hashed into the signature.
    body: String,
}
|
||||
|
||||
/// Sign `request` with AWS Signature Version 4 and return a ready-to-send
/// reqwest builder.
///
/// Implements the standard SigV4 flow: canonical request -> string to
/// sign -> derived signing key -> HMAC signature -> `Authorization`
/// header. The signature covers every header in `request.headers` plus
/// `host`, `x-amz-date` and (when present) `x-amz-security-token`.
///
/// NOTE(review): SigV4 requires canonical headers to be sorted by
/// lowercased name; here the IndexMap insertion order is used as-is —
/// presumably callers only pass headers that already sort correctly.
/// Verify against the SigV4 specification.
fn aws_fetch(
    client: &ReqwestClient,
    credentials: &AwsCredentials,
    request: AwsRequest,
) -> Result<RequestBuilder> {
    let AwsRequest {
        method,
        host,
        service,
        uri,
        querystring,
        mut headers,
        body,
    } = request;
    let region = &credentials.region;

    let endpoint = format!("https://{host}{uri}");

    let now: DateTime<Utc> = Utc::now();
    // SigV4 timestamp (ISO8601 basic) and its date-only prefix.
    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
    let date_stamp = amz_date[0..8].to_string();
    headers.insert("host".into(), host.clone());
    headers.insert("x-amz-date".into(), amz_date.clone());
    if let Some(token) = credentials.session_token.clone() {
        headers.insert("x-amz-security-token".into(), token);
    }

    // "name:value\n" per header, in map order.
    let canonical_headers = headers
        .iter()
        .map(|(key, value)| format!("{key}:{value}\n"))
        .collect::<Vec<_>>()
        .join("");

    // Semicolon-joined list of the signed header names.
    let signed_headers = headers
        .iter()
        .map(|(key, _)| key.as_str())
        .collect::<Vec<_>>()
        .join(";");

    let payload_hash = sha256(&body);

    let canonical_request = format!(
        "{}\n{}\n{}\n{}\n{}\n{}",
        method,
        encode_uri(&uri),
        querystring,
        canonical_headers,
        signed_headers,
        payload_hash
    );

    let algorithm = "AWS4-HMAC-SHA256";
    let credential_scope = format!("{date_stamp}/{region}/{service}/aws4_request");
    let string_to_sign = format!(
        "{}\n{}\n{}\n{}",
        algorithm,
        amz_date,
        credential_scope,
        sha256(&canonical_request)
    );

    let signing_key = gen_signing_key(
        &credentials.secret_access_key,
        &date_stamp,
        region,
        &service,
    );
    let signature = hmac_sha256(&signing_key, &string_to_sign);
    let signature = hex_encode(&signature);

    let authorization_header = format!(
        "{} Credential={}/{}, SignedHeaders={}, Signature={}",
        algorithm, credentials.access_key_id, credential_scope, signed_headers, signature
    );

    headers.insert("authorization".into(), authorization_header);

    debug!("Request {endpoint} {body}");

    let mut request_builder = client.request(method, endpoint).body(body);

    for (key, value) in &headers {
        request_builder = request_builder.header(key, value);
    }

    Ok(request_builder)
}
|
||||
|
||||
fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
|
||||
let k_date = hmac_sha256(format!("AWS4{key}").as_bytes(), date_stamp);
|
||||
let k_region = hmac_sha256(&k_date, region);
|
||||
let k_service = hmac_sha256(&k_region, service);
|
||||
hmac_sha256(&k_service, "aws4_request")
|
||||
}
|
||||
@@ -0,0 +1,353 @@
|
||||
use super::*;
|
||||
|
||||
use crate::utils::strip_think_tag;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Default Anthropic endpoint, used when no `api_base` is configured.
const API_BASE: &str = "https://api.anthropic.com/v1";
|
||||
|
||||
/// User-supplied configuration for the Claude client,
/// deserialized from the application config file.
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeConfig {
    /// Optional display name for this client instance.
    pub name: Option<String>,
    /// API key sent via the `x-api-key` header.
    pub api_key: Option<String>,
    /// Optional endpoint override; defaults to `API_BASE`.
    pub api_base: Option<String>,
    /// Models exposed by this client; defaults to empty.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patching rules applied before sending.
    pub patch: Option<RequestPatch>,
    /// Extra, client-agnostic configuration.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl ClaudeClient {
    // Generated getters resolving each value from configuration.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts; only the API key is required.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
|
||||
|
||||
// Wire ClaudeClient into the shared client trait. Embeddings and
// rerank are unsupported (noop handlers).
impl_client_trait!(
    ClaudeClient,
    (
        prepare_chat_completions,
        claude_chat_completions,
        claude_chat_completions_streaming
    ),
    (noop_prepare_embeddings, noop_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &ClaudeClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/messages", api_base.trim_end_matches('/'));
|
||||
let body = claude_build_chat_completions_body(data, &self_.model)?;
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("anthropic-version", "2023-06-01");
|
||||
request_data.header("x-api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
pub async fn claude_chat_completions(
|
||||
builder: RequestBuilder,
|
||||
_model: &Model,
|
||||
) -> Result<ChatCompletionsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
debug!("non-stream-data: {data}");
|
||||
claude_extract_chat_completions(&data)
|
||||
}
|
||||
|
||||
/// Consume an Anthropic Messages SSE stream.
///
/// Text and thinking deltas are forwarded to `handler` as they arrive;
/// tool-call argument fragments (`partial_json`) are accumulated and
/// emitted as a single `ToolCall` when the content block stops (or when
/// a new tool_use block starts before the previous one was flushed).
///
/// `reasoning_state`: 0 = outside a thinking span, 1 = inside one
/// (a synthetic `<think>…</think>` wrapper is emitted around it).
pub async fn claude_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    // Pending tool-call state, carried across SSE events.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    let mut reasoning_state = 0;
    let handle = |message: SseMessage| -> Result<bool> {
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(typ) = data["type"].as_str() {
            match typ {
                "content_block_start" => {
                    if let (Some("tool_use"), Some(name), Some(id)) = (
                        data["content_block"]["type"].as_str(),
                        data["content_block"]["name"].as_str(),
                        data["content_block"]["id"].as_str(),
                    ) {
                        // Flush any previous tool call that was never
                        // closed by a content_block_stop.
                        if !function_name.is_empty() {
                            let arguments: Value =
                                function_arguments.parse().with_context(|| {
                                    format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                })?;
                            handler.tool_call(ToolCall::new(
                                function_name.clone(),
                                arguments,
                                Some(function_id.clone()),
                            ))?;
                        }
                        function_name = name.into();
                        function_arguments.clear();
                        function_id = id.into();
                    }
                }
                "content_block_delta" => {
                    if let Some(text) = data["delta"]["text"].as_str() {
                        handler.text(text)?;
                    } else if let Some(text) = data["delta"]["thinking"].as_str() {
                        // Open the synthetic think-tag on first delta.
                        if reasoning_state == 0 {
                            handler.text("<think>\n")?;
                            reasoning_state = 1;
                        }
                        handler.text(text)?;
                    } else if let (true, Some(partial_json)) = (
                        !function_name.is_empty(),
                        data["delta"]["partial_json"].as_str(),
                    ) {
                        // Arguments stream as raw JSON fragments.
                        function_arguments.push_str(partial_json);
                    }
                }
                "content_block_stop" => {
                    if reasoning_state == 1 {
                        handler.text("\n</think>\n\n")?;
                        reasoning_state = 0;
                    }
                    if !function_name.is_empty() {
                        // No argument deltas => empty object.
                        let arguments: Value = if function_arguments.is_empty() {
                            json!({})
                        } else {
                            function_arguments.parse().with_context(|| {
                                format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                            })?
                        };
                        // NOTE(review): function_name is not cleared after
                        // emitting, so a later content_block_stop could
                        // re-emit this call — verify against the Messages
                        // stream event ordering.
                        handler.tool_call(ToolCall::new(
                            function_name.clone(),
                            arguments,
                            Some(function_id.clone()),
                        ))?;
                    }
                }
                _ => {}
            }
        }
        // Returning false keeps the SSE stream open.
        Ok(false)
    };

    sse_stream(builder, handle).await
}
|
||||
|
||||
/// Translate generic chat data into an Anthropic Messages request body.
///
/// - The system message is lifted out of `messages` into `system`.
/// - Intermediate assistant turns have their `<think>` spans stripped.
/// - Tool-call turns expand into an assistant `tool_use` message followed
///   by a user `tool_result` message, as the Messages schema requires.
/// - Only data-URL images are supported; network URLs are collected and
///   rejected with an error after the walk.
///
/// # Errors
/// Fails when any message references a network image URL.
pub fn claude_build_chat_completions_body(
    data: ChatCompletionsData,
    model: &Model,
) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream,
    } = data;

    let system_message = extract_system_message(&mut messages);

    let mut network_image_urls = vec![];

    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                // Historical assistant turns: drop reasoning markup so it
                // is not replayed to the model.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": strip_think_tag(&text) })]
                }
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "content": text,
                })],
                MessageContent::Array(list) => {
                    let content: Vec<_> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => {
                                json!({"type": "text", "text": text})
                            }
                            MessageContentPart::ImageUrl {
                                image_url: ImageUrl { url },
                            } => {
                                // Inline data URL: split into media type and
                                // base64 payload for the image source part.
                                if let Some((mime_type, data)) = url
                                    .strip_prefix("data:")
                                    .and_then(|v| v.split_once(";base64,"))
                                {
                                    json!({
                                        "type": "image",
                                        "source": {
                                            "type": "base64",
                                            "media_type": mime_type,
                                            "data": data,
                                        }
                                    })
                                } else {
                                    // Recorded for the post-walk bail; the
                                    // placeholder value is never sent.
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            }
                        })
                        .collect();
                    vec![json!({
                        "role": role,
                        "content": content,
                    })]
                }
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    // One assistant message carrying the tool_use parts, then
                    // one user message carrying the matching tool_results.
                    let mut assistant_parts = vec![];
                    let mut user_parts = vec![];
                    if !text.is_empty() {
                        assistant_parts.push(json!({
                            "type": "text",
                            "text": text,
                        }))
                    }
                    for tool_result in tool_results {
                        assistant_parts.push(json!({
                            "type": "tool_use",
                            "id": tool_result.call.id,
                            "name": tool_result.call.name,
                            "input": tool_result.call.arguments,
                        }));
                        user_parts.push(json!({
                            "type": "tool_result",
                            "tool_use_id": tool_result.call.id,
                            "content": tool_result.output.to_string(),
                        }));
                    }
                    vec![
                        json!({
                            "role": "assistant",
                            "content": assistant_parts,
                        }),
                        json!({
                            "role": "user",
                            "content": user_parts,
                        }),
                    ]
                }
            }
        })
        .collect();

    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }

    let mut body = json!({
        "model": model.real_name(),
        "messages": messages,
    });
    if let Some(v) = system_message {
        body["system"] = v.into();
    }
    if let Some(v) = model.max_tokens_param() {
        body["max_tokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["top_p"] = v.into();
    }
    if stream {
        body["stream"] = true.into();
    }
    if let Some(functions) = functions {
        // Function declarations map to Anthropic tool definitions.
        body["tools"] = functions
            .iter()
            .map(|v| {
                json!({
                    "name": v.name,
                    "description": v.description,
                    "input_schema": v.parameters,
                })
            })
            .collect();
    }
    Ok(body)
}
|
||||
|
||||
pub fn claude_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
let mut text = String::new();
|
||||
let mut reasoning = None;
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(list) = data["content"].as_array() {
|
||||
for item in list {
|
||||
match item["type"].as_str() {
|
||||
Some("thinking") => {
|
||||
if let Some(v) = item["thinking"].as_str() {
|
||||
reasoning = Some(v.to_string());
|
||||
}
|
||||
}
|
||||
Some("text") => {
|
||||
if let Some(v) = item["text"].as_str() {
|
||||
if !text.is_empty() {
|
||||
text.push_str("\n\n");
|
||||
}
|
||||
text.push_str(v);
|
||||
}
|
||||
}
|
||||
Some("tool_use") => {
|
||||
if let (Some(name), Some(input), Some(id)) = (
|
||||
item["name"].as_str(),
|
||||
item.get("input"),
|
||||
item["id"].as_str(),
|
||||
) {
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
input.clone(),
|
||||
Some(id.to_string()),
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(reasoning) = reasoning {
|
||||
text = format!("<think>\n{reasoning}\n</think>\n\n{text}")
|
||||
}
|
||||
|
||||
if text.is_empty() && tool_calls.is_empty() {
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
|
||||
let output = ChatCompletionsOutput {
|
||||
text: text.to_string(),
|
||||
tool_calls,
|
||||
id: data["id"].as_str().map(|v| v.to_string()),
|
||||
input_tokens: data["usage"]["input_tokens"].as_u64(),
|
||||
output_tokens: data["usage"]["output_tokens"].as_u64(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
use super::openai::*;
|
||||
use super::openai_compatible::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
// Default endpoint for Cohere's v2 API; overridable via `api_base` config.
const API_BASE: &str = "https://api.cohere.ai/v2";
|
||||
|
||||
/// User-facing configuration for the Cohere client, deserialized from the
/// tool's config file.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct CohereConfig {
    // Optional display/client name override.
    pub name: Option<String>,
    // API key; may instead come from the environment (see config_get_fn!).
    pub api_key: Option<String>,
    // Optional endpoint override; falls back to API_BASE.
    pub api_base: Option<String>,
    // Models declared for this provider in the config file.
    #[serde(default)]
    pub models: Vec<ModelData>,
    // Optional per-API request patches (url/headers/body overrides).
    pub patch: Option<RequestPatch>,
    // Proxy / timeout settings.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl CohereClient {
    // Generated getters that read the value from config or environment.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts: only the API key is required.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
|
||||
|
||||
// Wire up the Client trait: chat (plain + streaming), embeddings, and rerank
// are all supported by Cohere.
impl_client_trait!(
    CohereClient,
    (
        prepare_chat_completions,
        chat_completions,
        chat_completions_streaming
    ),
    (prepare_embeddings, embeddings),
    (prepare_rerank, generic_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &CohereClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/chat", api_base.trim_end_matches('/'));
|
||||
let mut body = openai_build_chat_completions_body(data, &self_.model);
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
if let Some(top_p) = obj.remove("top_p") {
|
||||
obj.insert("p".to_string(), top_p);
|
||||
}
|
||||
}
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(self_: &CohereClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/embed", api_base.trim_end_matches('/'));
|
||||
|
||||
let input_type = match data.query {
|
||||
true => "search_query",
|
||||
false => "search_document",
|
||||
};
|
||||
|
||||
let body = json!({
|
||||
"model": self_.model.real_name(),
|
||||
"texts": data.texts,
|
||||
"input_type": input_type,
|
||||
"embedding_types": ["float"],
|
||||
});
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_rerank(self_: &CohereClient, data: &RerankData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/rerank", api_base.trim_end_matches('/'));
|
||||
let body = generic_build_rerank_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
async fn chat_completions(
|
||||
builder: RequestBuilder,
|
||||
_model: &Model,
|
||||
) -> Result<ChatCompletionsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
debug!("non-stream-data: {data}");
|
||||
extract_chat_completions(&data)
|
||||
}
|
||||
|
||||
/// Execute a streaming chat request, forwarding SSE events to `handler`.
///
/// Tool-call fragments arrive over several events (`tool-call-start`,
/// `tool-call-delta`, `tool-call-end`), so name/arguments/id are accumulated
/// in local buffers and only emitted on `tool-call-end`.
async fn chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    // Accumulators for the tool call currently being streamed.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    // Returns Ok(true) to signal end-of-stream.
    let handle = |message: SseMessage| -> Result<bool> {
        if message.data == "[DONE]" {
            return Ok(true);
        }
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(typ) = data["type"].as_str() {
            match typ {
                // Plain assistant text.
                "content-delta" => {
                    if let Some(text) = data["delta"]["message"]["content"]["text"].as_str() {
                        handler.text(text)?;
                    }
                }
                // The model's plan text is surfaced to the user as regular text.
                "tool-plan-delta" => {
                    if let Some(text) = data["delta"]["message"]["tool_plan"].as_str() {
                        handler.text(text)?;
                    }
                }
                // Start of a tool call: capture its name and id.
                "tool-call-start" => {
                    if let (Some(function), Some(id)) = (
                        data["delta"]["message"]["tool_calls"]["function"].as_object(),
                        data["delta"]["message"]["tool_calls"]["id"].as_str(),
                    ) {
                        if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
                            function_name = name.to_string();
                        }
                        function_id = id.to_string();
                    }
                }
                // Argument JSON arrives in chunks; append them.
                "tool-call-delta" => {
                    if let Some(text) =
                        data["delta"]["message"]["tool_calls"]["function"]["arguments"].as_str()
                    {
                        function_arguments.push_str(text);
                    }
                }
                // Tool call complete: parse arguments and emit, then reset buffers.
                "tool-call-end" => {
                    if !function_name.is_empty() {
                        let arguments: Value = function_arguments.parse().with_context(|| {
                            format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                        })?;
                        handler.tool_call(ToolCall::new(
                            function_name.clone(),
                            arguments,
                            Some(function_id.clone()),
                        ))?;
                    }
                    function_name.clear();
                    function_arguments.clear();
                    function_id.clear();
                }
                _ => {}
            }
        }
        Ok(false)
    };

    sse_stream(builder, handle).await
}
|
||||
|
||||
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
let res_body: EmbeddingsResBody =
|
||||
serde_json::from_value(data).context("Invalid embeddings data")?;
|
||||
Ok(res_body.embeddings.float)
|
||||
}
|
||||
|
||||
// Minimal shape of Cohere's /embed response; only the float embeddings are used.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    embeddings: EmbeddingsResBodyEmbeddings,
}
|
||||
|
||||
// The request asks for `"embedding_types": ["float"]`, so only `float` is read.
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbeddings {
    float: Vec<Vec<f32>>,
}
|
||||
|
||||
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
let mut text = data["message"]["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(calls) = data["message"]["tool_calls"].as_array() {
|
||||
if text.is_empty() {
|
||||
if let Some(tool_plain) = data["message"]["tool_plan"].as_str() {
|
||||
text = tool_plain.to_string();
|
||||
}
|
||||
}
|
||||
for call in calls {
|
||||
if let (Some(name), Some(arguments), Some(id)) = (
|
||||
call["function"]["name"].as_str(),
|
||||
call["function"]["arguments"].as_str(),
|
||||
call["id"].as_str(),
|
||||
) {
|
||||
let arguments: Value = arguments.parse().with_context(|| {
|
||||
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
|
||||
})?;
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
arguments,
|
||||
Some(id.to_string()),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if text.is_empty() && tool_calls.is_empty() {
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
let output = ChatCompletionsOutput {
|
||||
text,
|
||||
tool_calls,
|
||||
id: data["id"].as_str().map(|v| v.to_string()),
|
||||
input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(),
|
||||
output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
@@ -0,0 +1,678 @@
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
config::{Config, GlobalConfig, Input},
|
||||
function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult},
|
||||
render::render_stream,
|
||||
utils::*,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use fancy_regex::Regex;
|
||||
use indexmap::IndexMap;
|
||||
use inquire::{
|
||||
list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text,
|
||||
};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::unbounded_channel;
|
||||
|
||||
// Built-in model catalog, embedded at compile time.
const MODELS_YAML: &str = include_str!("../../models.yaml");

/// All known provider/model definitions; a local override file (if present)
/// takes precedence over the embedded catalog.
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
    Config::local_models_override()
        .ok()
        .unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
});

// Heuristic for classifying a model name as an embedding model.
static EMBEDDING_MODEL_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"((^|/)(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap()
});

// Matches unescaped `/` so patch-map keys can be used inside a regex safely.
static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/").unwrap());
|
||||
|
||||
/// Common interface implemented by every LLM provider client.
///
/// Providers implement the `*_inner` methods; the public `chat_completions`,
/// `embeddings` and `rerank` wrappers add HTTP-client construction, dry-run
/// handling, abort handling, and error context.
#[async_trait::async_trait]
pub trait Client: Sync + Send {
    fn global_config(&self) -> &GlobalConfig;

    fn extra_config(&self) -> Option<&ExtraConfig>;

    fn patch_config(&self) -> Option<&RequestPatch>;

    fn name(&self) -> &str;

    fn model(&self) -> &Model;

    fn model_mut(&mut self) -> &mut Model;

    /// Build a reqwest client honoring proxy, connect-timeout (default 10s)
    /// and user-agent settings.
    fn build_client(&self) -> Result<ReqwestClient> {
        let mut builder = ReqwestClient::builder();
        let extra = self.extra_config();
        let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
        if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
            builder = set_proxy(builder, proxy)?;
        }
        if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
            builder = builder.user_agent(user_agent);
        }
        let client = builder
            .connect_timeout(Duration::from_secs(timeout))
            .build()
            .with_context(|| "Failed to build client")?;
        Ok(client)
    }

    /// Non-streaming chat completion. In dry-run mode the input is echoed
    /// back without any network call.
    async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
        if self.global_config().read().dry_run {
            let content = input.echo_messages();
            return Ok(ChatCompletionsOutput::new(&content));
        }
        let client = self.build_client()?;
        let data = input.prepare_completion_data(self.model(), false)?;
        self.chat_completions_inner(&client, data)
            .await
            .with_context(|| "Failed to call chat-completions api")
    }

    /// Streaming chat completion. Races the request against the handler's
    /// abort signal; `handler.done()` is called on both paths.
    async fn chat_completions_streaming(
        &self,
        input: &Input,
        handler: &mut SseHandler,
    ) -> Result<()> {
        let abort_signal = handler.abort();
        let input = input.clone();
        tokio::select! {
            ret = async {
                if self.global_config().read().dry_run {
                    let content = input.echo_messages();
                    handler.text(&content)?;
                    return Ok(());
                }
                let client = self.build_client()?;
                let data = input.prepare_completion_data(self.model(), true)?;
                self.chat_completions_streaming_inner(&client, handler, data).await
            } => {
                handler.done();
                ret.with_context(|| "Failed to call chat-completions api")
            }
            _ = wait_abort_signal(&abort_signal) => {
                handler.done();
                Ok(())
            },
        }
    }

    /// Compute embeddings for the given texts.
    async fn embeddings(&self, data: &EmbeddingsData) -> Result<Vec<Vec<f32>>> {
        let client = self.build_client()?;
        self.embeddings_inner(&client, data)
            .await
            .context("Failed to call embeddings api")
    }

    /// Rerank documents against a query.
    async fn rerank(&self, data: &RerankData) -> Result<RerankOutput> {
        let client = self.build_client()?;
        self.rerank_inner(&client, data)
            .await
            .context("Failed to call rerank api")
    }

    // Provider-specific implementations; chat is mandatory.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput>;

    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()>;

    // Embeddings/rerank are optional; the defaults report "unsupported".
    async fn embeddings_inner(
        &self,
        _client: &ReqwestClient,
        _data: &EmbeddingsData,
    ) -> Result<EmbeddingsOutput> {
        bail!("The client doesn't support embeddings api")
    }

    async fn rerank_inner(
        &self,
        _client: &ReqwestClient,
        _data: &RerankData,
    ) -> Result<RerankOutput> {
        bail!("The client doesn't support rerank api")
    }

    /// Apply configured patches, then turn the request data into a builder.
    fn request_builder(
        &self,
        client: &reqwest::Client,
        mut request_data: RequestData,
    ) -> RequestBuilder {
        self.patch_request_data(&mut request_data);
        request_data.into_builder(client)
    }

    /// Apply request patches in priority order: model-level patch first, then
    /// the first matching entry from either an env-var patch map
    /// (PATCH_<CLIENT>_<API>) or the client's patch config. Patch-map keys are
    /// regexes matched against the full model name; only the first match applies.
    fn patch_request_data(&self, request_data: &mut RequestData) {
        let model_type = self.model().model_type();
        if let Some(patch) = self.model().patch() {
            request_data.apply_patch(patch.clone());
        }

        let patch_map = std::env::var(get_env_name(&format!(
            "patch_{}_{}",
            self.model().client_name(),
            model_type.api_name(),
        )))
        .ok()
        .and_then(|v| serde_json::from_str(&v).ok())
        .or_else(|| {
            self.patch_config()
                .and_then(|v| model_type.extract_patch(v))
                .cloned()
        });
        let patch_map = match patch_map {
            Some(v) => v,
            _ => return,
        };
        for (key, patch) in patch_map {
            // Escape unescaped slashes so keys like `a/b` stay literal in the regex.
            let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
            if let Ok(regex) = Regex::new(&format!("^({key})$")) {
                if let Ok(true) = regex.is_match(self.model().name()) {
                    request_data.apply_patch(patch);
                    return;
                }
            }
        }
    }
}
|
||||
|
||||
// OpenAI is the default provider when no client config is given.
impl Default for ClientConfig {
    fn default() -> Self {
        Self::OpenAIConfig(OpenAIConfig::default())
    }
}
|
||||
|
||||
/// Network-level options shared by all clients.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
    // Proxy URL; interpreted by `set_proxy`.
    pub proxy: Option<String>,
    // Connect timeout in seconds (defaults to 10 in `build_client`).
    pub connect_timeout: Option<u64>,
}
|
||||
|
||||
/// Per-API request patches (url/headers/body overrides), one slot per API kind.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RequestPatch {
    pub chat_completions: Option<ApiPatch>,
    pub embeddings: Option<ApiPatch>,
    pub rerank: Option<ApiPatch>,
}

// Maps a model-name regex to a patch value; see `Client::patch_request_data`.
pub type ApiPatch = IndexMap<String, Value>;
|
||||
|
||||
/// Everything needed to issue one HTTP POST: target URL, headers and JSON body.
pub struct RequestData {
    pub url: String,
    // Insertion-ordered header map; keys are stored as given.
    pub headers: IndexMap<String, String>,
    pub body: Value,
}
|
||||
|
||||
impl RequestData {
    /// Create request data with no headers.
    pub fn new<T>(url: T, body: Value) -> Self
    where
        T: std::fmt::Display,
    {
        Self {
            url: url.to_string(),
            headers: Default::default(),
            body,
        }
    }

    /// Set the `authorization: Bearer …` header.
    pub fn bearer_auth<T>(&mut self, auth: T)
    where
        T: std::fmt::Display,
    {
        self.headers
            .insert("authorization".into(), format!("Bearer {auth}"));
    }

    /// Set (or overwrite) an arbitrary header.
    pub fn header<K, V>(&mut self, key: K, value: V)
    where
        K: std::fmt::Display,
        V: std::fmt::Display,
    {
        self.headers.insert(key.to_string(), value.to_string());
    }

    /// Consume self into a reqwest POST builder.
    /// NOTE: logs the full URL and body at debug level.
    pub fn into_builder(self, client: &ReqwestClient) -> RequestBuilder {
        let RequestData { url, headers, body } = self;
        debug!("Request {url} {body}");

        let mut builder = client.post(url);
        for (key, value) in headers {
            builder = builder.header(key, value);
        }
        builder = builder.json(&body);
        builder
    }

    /// Apply a patch object: `url` replaces the URL, `body` is deep-merged,
    /// and `headers` entries are set (string value) or removed (null value).
    pub fn apply_patch(&mut self, patch: Value) {
        if let Some(patch_url) = patch["url"].as_str() {
            self.url = patch_url.into();
        }
        if let Some(patch_body) = patch.get("body") {
            json_patch::merge(&mut self.body, patch_body)
        }
        if let Some(patch_headers) = patch["headers"].as_object() {
            for (key, value) in patch_headers {
                if let Some(value) = value.as_str() {
                    self.header(key, value)
                } else if value.is_null() {
                    // null removes the header entirely.
                    self.headers.swap_remove(key);
                }
            }
        }
    }
}
|
||||
|
||||
/// Provider-agnostic input for a chat-completions call.
#[derive(Debug)]
pub struct ChatCompletionsData {
    pub messages: Vec<Message>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    // Tool/function declarations offered to the model, if any.
    pub functions: Option<Vec<FunctionDeclaration>>,
    pub stream: bool,
}
|
||||
|
||||
/// Provider-agnostic result of a chat-completions call.
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionsOutput {
    pub text: String,
    pub tool_calls: Vec<ToolCall>,
    // Provider-assigned response id, when available.
    pub id: Option<String>,
    pub input_tokens: Option<u64>,
    pub output_tokens: Option<u64>,
}

impl ChatCompletionsOutput {
    /// Text-only output with all other fields defaulted (used for dry runs).
    pub fn new(text: &str) -> Self {
        Self {
            text: text.to_string(),
            ..Default::default()
        }
    }
}
|
||||
|
||||
/// Input for an embeddings call.
#[derive(Debug)]
pub struct EmbeddingsData {
    pub texts: Vec<String>,
    // true = embedding a search query, false = embedding documents.
    pub query: bool,
}

impl EmbeddingsData {
    pub fn new(texts: Vec<String>, query: bool) -> Self {
        Self { texts, query }
    }
}

// One float vector per input text.
pub type EmbeddingsOutput = Vec<Vec<f32>>;
|
||||
|
||||
/// Input for a rerank call.
#[derive(Debug)]
pub struct RerankData {
    pub query: String,
    pub documents: Vec<String>,
    // Number of top-scoring documents requested.
    pub top_n: usize,
}

impl RerankData {
    pub fn new(query: String, documents: Vec<String>, top_n: usize) -> Self {
        Self {
            query,
            documents,
            top_n,
        }
    }
}

pub type RerankOutput = Vec<RerankResult>;

/// One rerank hit: index into the input documents plus its relevance score.
#[derive(Debug, Deserialize)]
pub struct RerankResult {
    pub index: usize,
    pub relevance_score: f64,
}

// (config key, human-readable prompt, optional help message).
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>);
|
||||
|
||||
/// Interactively build a client config entry from the given prompts.
///
/// Returns the chosen default model id (`client:model`) and the `clients`
/// JSON array containing the single new config entry. A prompt becomes
/// optional when its value is already provided via `<CLIENT>_<KEY>` env var.
pub async fn create_config(
    prompts: &[PromptAction<'static>],
    client: &str,
) -> Result<(String, Value)> {
    let mut config = json!({
        "type": client,
    });
    for (key, desc, help_message) in prompts {
        let env_name = format!("{client}_{key}").to_ascii_uppercase();
        // If the env var exists, the user may leave this field empty.
        let required = std::env::var(&env_name).is_err();
        let value = prompt_input_string(desc, required, *help_message)?;
        if !value.is_empty() {
            config[key] = value.into();
        }
    }
    let model = set_client_models_config(&mut config, client).await?;
    let clients = json!(vec![config]);
    Ok((model, clients))
}
|
||||
|
||||
/// Interactively build a config entry for an OpenAI-compatible provider.
///
/// Known providers supply their api_base; the generic provider asks for a
/// name and an API base. Returns the default model id and the `clients` array.
pub async fn create_openai_compatible_client_config(
    client: &str,
) -> Result<Option<(String, Value)>> {
    let api_base = OPENAI_COMPATIBLE_PROVIDERS
        .into_iter()
        .find(|(name, _)| client == *name)
        .map(|(_, api_base)| api_base)
        .unwrap_or("http(s)://{API_ADDR}/v1");

    let name = if client == OpenAICompatibleClient::NAME {
        // Generic provider: ask for a name; spaces would break the model id.
        let value = prompt_input_string("Provider Name", true, None)?;
        value.replace(' ', "-")
    } else {
        client.to_string()
    };

    let mut config = json!({
        "type": OpenAICompatibleClient::NAME,
        "name": &name,
    });

    // A `{`-placeholder in the template means the user must fill in the base URL.
    let api_base = if api_base.contains('{') {
        prompt_input_string("API Base", true, Some(&format!("e.g. {api_base}")))?
    } else {
        api_base.to_string()
    };
    config["api_base"] = api_base.into();

    let api_key = prompt_input_string("API Key", false, None)?;
    if !api_key.is_empty() {
        config["api_key"] = api_key.into();
    }

    let model = set_client_models_config(&mut config, &name).await?;
    let clients = json!(vec![config]);
    Ok(Some((model, clients)))
}
|
||||
|
||||
/// Run a non-streaming chat completion with a spinner and optional
/// post-processing (code extraction, markdown printing), then evaluate any
/// tool calls. Returns the final text and the tool results.
pub async fn call_chat_completions(
    input: &Input,
    print: bool,
    extract_code: bool,
    client: &dyn Client,
    abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
    let ret = abortable_run_with_spinner(
        client.chat_completions(input.clone()),
        "Generating",
        abort_signal,
    )
    .await;

    match ret {
        Ok(ret) => {
            let ChatCompletionsOutput {
                mut text,
                tool_calls,
                ..
            } = ret;
            if !text.is_empty() {
                if extract_code {
                    // Drop <think> sections before pulling out the code block.
                    text = extract_code_block(&strip_think_tag(&text)).to_string();
                }
                if print {
                    client.global_config().read().print_markdown(&text)?;
                }
            }
            Ok((
                text,
                eval_tool_calls(client.global_config(), tool_calls).await?,
            ))
        }
        Err(err) => Err(err),
    }
}
|
||||
|
||||
/// Run a streaming chat completion: drives the SSE request and the terminal
/// renderer concurrently, then evaluates any tool calls.
///
/// Partial text is kept even when the request errors; a trailing newline is
/// printed so the shell prompt doesn't end up mid-line.
pub async fn call_chat_completions_streaming(
    input: &Input,
    client: &dyn Client,
    abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
    let (tx, rx) = unbounded_channel();
    let mut handler = SseHandler::new(tx, abort_signal.clone());

    // Producer (API stream) and consumer (renderer) run together.
    let (send_ret, render_ret) = tokio::join!(
        client.chat_completions_streaming(input, &mut handler),
        render_stream(rx, client.global_config(), abort_signal.clone()),
    );

    if handler.abort().aborted() {
        bail!("Aborted.");
    }

    render_ret?;

    let (text, tool_calls) = handler.take();
    match send_ret {
        Ok(_) => {
            if !text.is_empty() && !text.ends_with('\n') {
                println!();
            }
            Ok((
                text,
                eval_tool_calls(client.global_config(), tool_calls).await?,
            ))
        }
        Err(err) => {
            if !text.is_empty() {
                println!();
            }
            Err(err)
        }
    }
}
|
||||
|
||||
// No-op stubs plugged into `impl_client_trait!` by providers that do not
// support the embeddings or rerank APIs.

pub fn noop_prepare_embeddings<T>(_client: &T, _data: &EmbeddingsData) -> Result<RequestData> {
    bail!("The client doesn't support embeddings api")
}

pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    bail!("The client doesn't support embeddings api")
}

pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
    bail!("The client doesn't support rerank api")
}

pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
    bail!("The client doesn't support rerank api")
}
|
||||
|
||||
/// Turn a non-2xx API response into an error, probing the several error
/// shapes different providers use. Branch order matters: the first matching
/// shape wins; the fallback dumps the raw payload.
pub fn catch_error(data: &Value, status: u16) -> Result<()> {
    if (200..300).contains(&status) {
        return Ok(());
    }
    debug!("Invalid response, status: {status}, data: {data}");
    // Shape: { "error": { "type"/"code", "message" } }
    if let Some(error) = data["error"].as_object() {
        if let (Some(typ), Some(message)) = (
            json_str_from_map(error, "type"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (type: {typ})");
        } else if let (Some(typ), Some(message)) = (
            json_str_from_map(error, "code"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (code: {typ})");
        }
    // Shape: { "errors": [ { "code": <number>, "message" } ] }
    } else if let Some(error) = data["errors"][0].as_object() {
        if let (Some(code), Some(message)) = (
            error.get("code").and_then(|v| v.as_u64()),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (status: {code})")
        }
    // Shape: [ { "error": { "status", "message" } } ]
    } else if let Some(error) = data[0]["error"].as_object() {
        if let (Some(status), Some(message)) = (
            json_str_from_map(error, "status"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (status: {status})")
        }
    // Shape: { "detail": "...", "status": <number> }
    } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
    {
        bail!("{detail} (status: {status})");
    // Shape: { "error": "..." }
    } else if let Some(error) = data["error"].as_str() {
        bail!("{error}");
    // Shape: { "message": "..." }
    } else if let Some(message) = data["message"].as_str() {
        bail!("{message}");
    }
    bail!("Invalid response data: {data} (status: {status})");
}
|
||||
|
||||
pub fn json_str_from_map<'a>(
|
||||
map: &'a serde_json::Map<String, Value>,
|
||||
field_name: &str,
|
||||
) -> Option<&'a str> {
|
||||
map.get(field_name).and_then(|v| v.as_str())
|
||||
}
|
||||
|
||||
/// Fill in the `models` array of a client config and pick the default model.
///
/// For known providers, the user picks from the built-in catalog. Otherwise,
/// models are fetched from an OpenAI-compatible `/models` endpoint when
/// possible (multi-select), falling back to manual comma-separated entry.
/// Each model is classified as reranker / embedding / vision / chat by
/// name heuristics. Returns `client:model` for the chosen default.
async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
    // Known provider: choose among its cataloged chat models.
    if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
        let models: Vec<String> = provider
            .models
            .iter()
            .filter(|v| v.model_type == "chat")
            .map(|v| v.name.clone())
            .collect();
        let model_name = select_model(models)?;
        return Ok(format!("{client}:{model_name}"));
    }
    let mut model_names = vec![];
    // Only OpenAI-compatible clients with a known api_base can list models;
    // the API key may come from config or from <CLIENT>_API_KEY.
    if let (Some(true), Some(api_base), api_key) = (
        client_config["type"]
            .as_str()
            .map(|v| v == OpenAICompatibleClient::NAME),
        client_config["api_base"].as_str(),
        client_config["api_key"]
            .as_str()
            .map(|v| v.to_string())
            .or_else(|| {
                let env_name = format!("{client}_api_key").to_ascii_uppercase();
                std::env::var(&env_name).ok()
            }),
    ) {
        match abortable_run_with_spinner(
            fetch_models(api_base, api_key.as_deref()),
            "Fetching models",
            create_abort_signal(),
        )
        .await
        {
            Ok(fetched_models) => {
                model_names = MultiSelect::new("LLMs to include (required):", fetched_models)
                    .with_validator(|list: &[ListOption<&String>]| {
                        if list.is_empty() {
                            Ok(Validation::Invalid(
                                "At least one item must be selected".into(),
                            ))
                        } else {
                            Ok(Validation::Valid)
                        }
                    })
                    .prompt()?;
            }
            Err(err) => {
                // Fetch failure is non-fatal; fall through to manual entry.
                eprintln!("✗ Fetch models failed: {err}");
            }
        }
    }
    // Manual fallback: comma-separated model names.
    if model_names.is_empty() {
        model_names = prompt_input_string(
            "LLMs to add",
            true,
            Some("Separated by commas, e.g. llama3.3,qwen2.5"),
        )?
        .split(',')
        .filter_map(|v| {
            let v = v.trim();
            if v.is_empty() {
                None
            } else {
                Some(v.to_string())
            }
        })
        .collect::<Vec<_>>();
    }
    if model_names.is_empty() {
        bail!("No models");
    }
    // Classify each model by name heuristics.
    let models: Vec<Value> = model_names
        .iter()
        .map(|v| {
            let l = v.to_lowercase();
            if l.contains("rank") {
                json!({
                    "name": v,
                    "type": "reranker",
                })
            } else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) {
                json!({
                    "name": v,
                    "type": "embedding",
                    "default_chunk_size": 1000,
                    "max_batch_size": 100
                })
            } else if v.contains("vision") {
                json!({
                    "name": v,
                    "supports_vision": true
                })
            } else {
                json!({
                    "name": v,
                })
            }
        })
        .collect();
    client_config["models"] = models.into();
    let model_name = select_model(model_names)?;
    Ok(format!("{client}:{model_name}"))
}
|
||||
|
||||
fn select_model(model_names: Vec<String>) -> Result<String> {
|
||||
if model_names.is_empty() {
|
||||
bail!("No models");
|
||||
}
|
||||
let model = if model_names.len() == 1 {
|
||||
model_names[0].clone()
|
||||
} else {
|
||||
Select::new("Default Model (required):", model_names).prompt()?
|
||||
};
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
fn prompt_input_string(desc: &str, required: bool, help_message: Option<&str>) -> Result<String> {
|
||||
let desc = if required {
|
||||
format!("{desc} (required):")
|
||||
} else {
|
||||
format!("{desc} (optional):")
|
||||
};
|
||||
let mut text = Text::new(&desc);
|
||||
if required {
|
||||
text = text.with_validator(required!("This field is required"))
|
||||
}
|
||||
if let Some(help_message) = help_message {
|
||||
text = text.with_help_message(help_message);
|
||||
}
|
||||
let text = text.prompt()?;
|
||||
Ok(text)
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
use super::vertexai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
// Default endpoint for the Gemini (Generative Language) API.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";

/// User-facing configuration for the Gemini client, deserialized from the
/// tool's config file.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiConfig {
    // Optional display/client name override.
    pub name: Option<String>,
    // API key; may instead come from the environment (see config_get_fn!).
    pub api_key: Option<String>,
    // Optional endpoint override; falls back to API_BASE.
    pub api_base: Option<String>,
    // Models declared for this provider in the config file.
    #[serde(default)]
    pub models: Vec<ModelData>,
    // Optional per-API request patches.
    pub patch: Option<RequestPatch>,
    // Proxy / timeout settings.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl GeminiClient {
    // Generated getters that read the value from config or environment.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts: only the API key is required.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
|
||||
|
||||
// Wire up the Client trait: chat (shared VertexAI/Gemini implementations) and
// embeddings are supported; rerank is not.
impl_client_trait!(
    GeminiClient,
    (
        prepare_chat_completions,
        gemini_chat_completions,
        gemini_chat_completions_streaming
    ),
    (prepare_embeddings, embeddings),
    (noop_prepare_rerank, noop_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &GeminiClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let func = match data.stream {
|
||||
true => "streamGenerateContent",
|
||||
false => "generateContent",
|
||||
};
|
||||
|
||||
let url = format!(
|
||||
"{}/models/{}:{}",
|
||||
api_base.trim_end_matches('/'),
|
||||
self_.model.real_name(),
|
||||
func
|
||||
);
|
||||
|
||||
let body = gemini_build_chat_completions_body(data, &self_.model)?;
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("x-goog-api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!(
|
||||
"{}/models/{}:batchEmbedContents?key={}",
|
||||
api_base.trim_end_matches('/'),
|
||||
self_.model.real_name(),
|
||||
api_key
|
||||
);
|
||||
|
||||
let model_id = format!("models/{}", self_.model.real_name());
|
||||
|
||||
let requests: Vec<_> = data
|
||||
.texts
|
||||
.iter()
|
||||
.map(|text| {
|
||||
json!({
|
||||
"model": model_id,
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": text
|
||||
}
|
||||
]
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let body = json!({
|
||||
"requests": requests,
|
||||
});
|
||||
|
||||
let request_data = RequestData::new(url, body);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
let res_body: EmbeddingsResBody =
|
||||
serde_json::from_value(data).context("Invalid embeddings data")?;
|
||||
let output = res_body
|
||||
.embeddings
|
||||
.into_iter()
|
||||
.map(|embedding| embedding.values)
|
||||
.collect();
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingsResBody {
|
||||
embeddings: Vec<EmbeddingsResBodyEmbedding>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingsResBodyEmbedding {
|
||||
values: Vec<f32>,
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
#[macro_export]
|
||||
macro_rules! register_client {
|
||||
(
|
||||
$(($module:ident, $name:literal, $config:ident, $client:ident),)+
|
||||
) => {
|
||||
$(
|
||||
mod $module;
|
||||
)+
|
||||
$(
|
||||
use self::$module::$config;
|
||||
)+
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClientConfig {
|
||||
$(
|
||||
#[serde(rename = $name)]
|
||||
$config($config),
|
||||
)+
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
$(
|
||||
#[derive(Debug)]
|
||||
pub struct $client {
|
||||
global_config: $crate::config::GlobalConfig,
|
||||
config: $config,
|
||||
model: $crate::client::Model,
|
||||
}
|
||||
|
||||
impl $client {
|
||||
pub const NAME: &'static str = $name;
|
||||
|
||||
pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
|
||||
let config = global_config.read().clients.iter().find_map(|client_config| {
|
||||
if let ClientConfig::$config(c) = client_config {
|
||||
if Self::name(c) == model.client_name() {
|
||||
return Some(c.clone())
|
||||
}
|
||||
}
|
||||
None
|
||||
})?;
|
||||
|
||||
Some(Box::new(Self {
|
||||
global_config: global_config.clone(),
|
||||
config,
|
||||
model: model.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn list_models(local_config: &$config) -> Vec<Model> {
|
||||
let client_name = Self::name(local_config);
|
||||
if local_config.models.is_empty() {
|
||||
if let Some(v) = $crate::client::ALL_PROVIDER_MODELS.iter().find(|v| {
|
||||
v.provider == $name ||
|
||||
($name == OpenAICompatibleClient::NAME
|
||||
&& local_config.name.as_ref().map(|name| name.starts_with(&v.provider)).unwrap_or_default())
|
||||
}) {
|
||||
return Model::from_config(client_name, &v.models);
|
||||
}
|
||||
vec![]
|
||||
} else {
|
||||
Model::from_config(client_name, &local_config.models)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(local_config: &$config) -> &str {
|
||||
local_config.name.as_deref().unwrap_or(Self::NAME)
|
||||
}
|
||||
}
|
||||
|
||||
)+
|
||||
|
||||
pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
|
||||
let model = model.unwrap_or_else(|| config.read().model.clone());
|
||||
None
|
||||
$(.or_else(|| $client::init(config, &model)))+
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("Invalid model '{}'", model.id())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_client_types() -> Vec<&'static str> {
|
||||
let mut client_types: Vec<_> = vec![$($client::NAME,)+];
|
||||
client_types.extend($crate::client::OPENAI_COMPATIBLE_PROVIDERS.iter().map(|(name, _)| *name));
|
||||
client_types
|
||||
}
|
||||
|
||||
pub async fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
|
||||
$(
|
||||
if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
|
||||
return create_config(&$client::PROMPTS, $client::NAME).await
|
||||
}
|
||||
)+
|
||||
if let Some(ret) = create_openai_compatible_client_config(client).await? {
|
||||
return Ok(ret);
|
||||
}
|
||||
anyhow::bail!("Unknown client '{}'", client)
|
||||
}
|
||||
|
||||
static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
|
||||
let names = ALL_CLIENT_NAMES.get_or_init(|| {
|
||||
config
|
||||
.clients
|
||||
.iter()
|
||||
.flat_map(|v| match v {
|
||||
$(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+
|
||||
ClientConfig::Unknown => vec![],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
names.iter().collect()
|
||||
}
|
||||
|
||||
static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
|
||||
let models = ALL_MODELS.get_or_init(|| {
|
||||
config
|
||||
.clients
|
||||
.iter()
|
||||
.flat_map(|v| match v {
|
||||
$(ClientConfig::$config(c) => $client::list_models(c),)+
|
||||
ClientConfig::Unknown => vec![],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
models.iter().collect()
|
||||
}
|
||||
|
||||
pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
|
||||
list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! client_common_fns {
|
||||
() => {
|
||||
fn global_config(&self) -> &$crate::config::GlobalConfig {
|
||||
&self.global_config
|
||||
}
|
||||
|
||||
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
|
||||
self.config.extra.as_ref()
|
||||
}
|
||||
|
||||
fn patch_config(&self) -> Option<&$crate::client::RequestPatch> {
|
||||
self.config.patch.as_ref()
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
Self::name(&self.config)
|
||||
}
|
||||
|
||||
fn model(&self) -> &Model {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn model_mut(&mut self) -> &mut Model {
|
||||
&mut self.model
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_client_trait {
|
||||
(
|
||||
$client:ident,
|
||||
($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path),
|
||||
($prepare_embeddings:path, $embeddings:path),
|
||||
($prepare_rerank:path, $rerank:path),
|
||||
) => {
|
||||
#[async_trait::async_trait]
|
||||
impl $crate::client::Client for $crate::client::$client {
|
||||
client_common_fns!();
|
||||
|
||||
async fn chat_completions_inner(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
data: $crate::client::ChatCompletionsData,
|
||||
) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
|
||||
let request_data = $prepare_chat_completions(self, data)?;
|
||||
let builder = self.request_builder(client, request_data);
|
||||
$chat_completions(builder, self.model()).await
|
||||
}
|
||||
|
||||
async fn chat_completions_streaming_inner(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
handler: &mut $crate::client::SseHandler,
|
||||
data: $crate::client::ChatCompletionsData,
|
||||
) -> Result<()> {
|
||||
let request_data = $prepare_chat_completions(self, data)?;
|
||||
let builder = self.request_builder(client, request_data);
|
||||
$chat_completions_streaming(builder, handler, self.model()).await
|
||||
}
|
||||
|
||||
async fn embeddings_inner(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
data: &$crate::client::EmbeddingsData,
|
||||
) -> Result<$crate::client::EmbeddingsOutput> {
|
||||
let request_data = $prepare_embeddings(self, data)?;
|
||||
let builder = self.request_builder(client, request_data);
|
||||
$embeddings(builder, self.model()).await
|
||||
}
|
||||
|
||||
async fn rerank_inner(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
data: &$crate::client::RerankData,
|
||||
) -> Result<$crate::client::RerankOutput> {
|
||||
let request_data = $prepare_rerank(self, data)?;
|
||||
let builder = self.request_builder(client, request_data);
|
||||
$rerank(builder, self.model()).await
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! config_get_fn {
|
||||
($field_name:ident, $fn_name:ident) => {
|
||||
fn $fn_name(&self) -> anyhow::Result<String> {
|
||||
let env_prefix = Self::name(&self.config);
|
||||
let env_name =
|
||||
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
|
||||
std::env::var(&env_name)
|
||||
.ok()
|
||||
.or_else(|| self.config.$field_name.clone())
|
||||
.ok_or_else(|| anyhow::anyhow!("Miss '{}'", stringify!($field_name)))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! unsupported_model {
|
||||
($name:expr) => {
|
||||
anyhow::bail!("Unsupported model '{}'", $name)
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
use super::Model;
|
||||
|
||||
use crate::{function::ToolResult, multiline_text, utils::dimmed_text};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Message {
|
||||
pub role: MessageRole,
|
||||
pub content: MessageContent,
|
||||
}
|
||||
|
||||
impl Default for Message {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Text(String::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: MessageRole, content: MessageContent) -> Self {
|
||||
Self { role, content }
|
||||
}
|
||||
|
||||
pub fn merge_system(&mut self, system: MessageContent) {
|
||||
match (&mut self.content, system) {
|
||||
(MessageContent::Text(text), MessageContent::Text(system_text)) => {
|
||||
self.content = MessageContent::Array(vec![
|
||||
MessageContentPart::Text { text: system_text },
|
||||
MessageContentPart::Text {
|
||||
text: text.to_string(),
|
||||
},
|
||||
])
|
||||
}
|
||||
(MessageContent::Array(list), MessageContent::Text(system_text)) => {
|
||||
list.insert(0, MessageContentPart::Text { text: system_text })
|
||||
}
|
||||
(MessageContent::Text(text), MessageContent::Array(mut system_list)) => {
|
||||
system_list.push(MessageContentPart::Text {
|
||||
text: text.to_string(),
|
||||
});
|
||||
self.content = MessageContent::Array(system_list);
|
||||
}
|
||||
(MessageContent::Array(list), MessageContent::Array(mut system_list)) => {
|
||||
system_list.append(list);
|
||||
self.content = MessageContent::Array(system_list);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessageRole {
|
||||
System,
|
||||
Assistant,
|
||||
User,
|
||||
Tool,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl MessageRole {
|
||||
pub fn is_system(&self) -> bool {
|
||||
matches!(self, MessageRole::System)
|
||||
}
|
||||
|
||||
pub fn is_user(&self) -> bool {
|
||||
matches!(self, MessageRole::User)
|
||||
}
|
||||
|
||||
pub fn is_assistant(&self) -> bool {
|
||||
matches!(self, MessageRole::Assistant)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
Text(String),
|
||||
Array(Vec<MessageContentPart>),
|
||||
// Note: This type is primarily for convenience and does not exist in OpenAI's API.
|
||||
ToolCalls(MessageContentToolCalls),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn render_input(
|
||||
&self,
|
||||
resolve_url_fn: impl Fn(&str) -> String,
|
||||
agent_info: &Option<(String, Vec<String>)>,
|
||||
) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => multiline_text(text),
|
||||
MessageContent::Array(list) => {
|
||||
let (mut concated_text, mut files) = (String::new(), vec![]);
|
||||
for item in list {
|
||||
match item {
|
||||
MessageContentPart::Text { text } => {
|
||||
concated_text = format!("{concated_text} {text}")
|
||||
}
|
||||
MessageContentPart::ImageUrl { image_url } => {
|
||||
files.push(resolve_url_fn(&image_url.url))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !concated_text.is_empty() {
|
||||
concated_text = format!(" -- {}", multiline_text(&concated_text))
|
||||
}
|
||||
format!(".file {}{}", files.join(" "), concated_text)
|
||||
}
|
||||
MessageContent::ToolCalls(MessageContentToolCalls {
|
||||
tool_results, text, ..
|
||||
}) => {
|
||||
let mut lines = vec![];
|
||||
if !text.is_empty() {
|
||||
lines.push(text.clone())
|
||||
}
|
||||
for tool_result in tool_results {
|
||||
let mut parts = vec!["Call".to_string()];
|
||||
if let Some((agent_name, functions)) = agent_info {
|
||||
if functions.contains(&tool_result.call.name) {
|
||||
parts.push(agent_name.clone())
|
||||
}
|
||||
}
|
||||
parts.push(tool_result.call.name.clone());
|
||||
parts.push(tool_result.call.arguments.to_string());
|
||||
lines.push(dimmed_text(&parts.join(" ")));
|
||||
}
|
||||
lines.join("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) {
|
||||
match self {
|
||||
MessageContent::Text(text) => *text = replace_fn(text),
|
||||
MessageContent::Array(list) => {
|
||||
if list.is_empty() {
|
||||
list.push(MessageContentPart::Text {
|
||||
text: replace_fn(""),
|
||||
})
|
||||
} else if let Some(MessageContentPart::Text { text }) = list.get_mut(0) {
|
||||
*text = replace_fn(text)
|
||||
}
|
||||
}
|
||||
MessageContent::ToolCalls(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.to_string(),
|
||||
MessageContent::Array(list) => {
|
||||
let mut parts = vec![];
|
||||
for item in list {
|
||||
if let MessageContentPart::Text { text } = item {
|
||||
parts.push(text.clone())
|
||||
}
|
||||
}
|
||||
parts.join("\n\n")
|
||||
}
|
||||
MessageContent::ToolCalls(_) => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessageContentPart {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: ImageUrl },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageContentToolCalls {
|
||||
pub tool_results: Vec<ToolResult>,
|
||||
pub text: String,
|
||||
pub sequence: bool,
|
||||
}
|
||||
|
||||
impl MessageContentToolCalls {
|
||||
pub fn new(tool_results: Vec<ToolResult>, text: String) -> Self {
|
||||
Self {
|
||||
tool_results,
|
||||
text,
|
||||
sequence: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge(&mut self, tool_results: Vec<ToolResult>, _text: String) {
|
||||
self.tool_results.extend(tool_results);
|
||||
self.text.clear();
|
||||
self.sequence = true;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn patch_messages(messages: &mut Vec<Message>, model: &Model) {
|
||||
if messages.is_empty() {
|
||||
return;
|
||||
}
|
||||
if let Some(prefix) = model.system_prompt_prefix() {
|
||||
if messages[0].role.is_system() {
|
||||
messages[0].merge_system(MessageContent::Text(prefix.to_string()));
|
||||
} else {
|
||||
messages.insert(
|
||||
0,
|
||||
Message {
|
||||
role: MessageRole::System,
|
||||
content: MessageContent::Text(prefix.to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
if model.no_system_message() && messages[0].role.is_system() {
|
||||
let system_message = messages.remove(0);
|
||||
if let (Some(message), system) = (messages.get_mut(0), system_message.content) {
|
||||
message.merge_system(system);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
|
||||
if messages[0].role.is_system() {
|
||||
let system_message = messages.remove(0);
|
||||
return Some(system_message.content.to_text());
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
mod access_token;
|
||||
mod common;
|
||||
mod message;
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
mod model;
|
||||
mod stream;
|
||||
|
||||
pub use crate::function::ToolCall;
|
||||
pub use common::*;
|
||||
pub use message::*;
|
||||
pub use model::*;
|
||||
pub use stream::*;
|
||||
|
||||
register_client!(
|
||||
(openai, "openai", OpenAIConfig, OpenAIClient),
|
||||
(
|
||||
openai_compatible,
|
||||
"openai-compatible",
|
||||
OpenAICompatibleConfig,
|
||||
OpenAICompatibleClient
|
||||
),
|
||||
(gemini, "gemini", GeminiConfig, GeminiClient),
|
||||
(claude, "claude", ClaudeConfig, ClaudeClient),
|
||||
(cohere, "cohere", CohereConfig, CohereClient),
|
||||
(
|
||||
azure_openai,
|
||||
"azure-openai",
|
||||
AzureOpenAIConfig,
|
||||
AzureOpenAIClient
|
||||
),
|
||||
(vertexai, "vertexai", VertexAIConfig, VertexAIClient),
|
||||
(bedrock, "bedrock", BedrockConfig, BedrockClient),
|
||||
);
|
||||
|
||||
pub const OPENAI_COMPATIBLE_PROVIDERS: [(&str, &str); 18] = [
|
||||
("ai21", "https://api.ai21.com/studio/v1"),
|
||||
(
|
||||
"cloudflare",
|
||||
"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/v1",
|
||||
),
|
||||
("deepinfra", "https://api.deepinfra.com/v1/openai"),
|
||||
("deepseek", "https://api.deepseek.com"),
|
||||
("ernie", "https://qianfan.baidubce.com/v2"),
|
||||
("github", "https://models.inference.ai.azure.com"),
|
||||
("groq", "https://api.groq.com/openai/v1"),
|
||||
("hunyuan", "https://api.hunyuan.cloud.tencent.com/v1"),
|
||||
("minimax", "https://api.minimax.chat/v1"),
|
||||
("mistral", "https://api.mistral.ai/v1"),
|
||||
("moonshot", "https://api.moonshot.cn/v1"),
|
||||
("openrouter", "https://openrouter.ai/api/v1"),
|
||||
("perplexity", "https://api.perplexity.ai"),
|
||||
(
|
||||
"qianwen",
|
||||
"https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
),
|
||||
("xai", "https://api.x.ai/v1"),
|
||||
("zhipuai", "https://open.bigmodel.cn/api/paas/v4"),
|
||||
// RAG-dedicated
|
||||
("jina", "https://api.jina.ai/v1"),
|
||||
("voyageai", "https://api.voyageai.com/v1"),
|
||||
];
|
||||
@@ -0,0 +1,407 @@
|
||||
use super::{
|
||||
list_all_models, list_client_names,
|
||||
message::{Message, MessageContent, MessageContentPart},
|
||||
ApiPatch, MessageContentToolCalls, RequestPatch,
|
||||
};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::utils::{estimate_token_length, strip_think_tag};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::fmt::Display;
|
||||
|
||||
const PER_MESSAGES_TOKENS: usize = 5;
|
||||
const BASIS_TOKENS: usize = 2;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
client_name: String,
|
||||
data: ModelData,
|
||||
}
|
||||
|
||||
impl Default for Model {
|
||||
fn default() -> Self {
|
||||
Model::new("", "")
|
||||
}
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(client_name: &str, name: &str) -> Self {
|
||||
Self {
|
||||
client_name: client_name.into(),
|
||||
data: ModelData::new(name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_config(client_name: &str, models: &[ModelData]) -> Vec<Self> {
|
||||
models
|
||||
.iter()
|
||||
.map(|v| Model {
|
||||
client_name: client_name.to_string(),
|
||||
data: v.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result<Self> {
|
||||
let models = list_all_models(config);
|
||||
let (client_name, model_name) = match model_id.split_once(':') {
|
||||
Some((client_name, model_name)) => {
|
||||
if model_name.is_empty() {
|
||||
(client_name, None)
|
||||
} else {
|
||||
(client_name, Some(model_name))
|
||||
}
|
||||
}
|
||||
None => (model_id, None),
|
||||
};
|
||||
match model_name {
|
||||
Some(model_name) => {
|
||||
if let Some(model) = models.iter().find(|v| v.id() == model_id) {
|
||||
if model.model_type() == model_type {
|
||||
return Ok((*model).clone());
|
||||
} else {
|
||||
bail!("Model '{model_id}' is not a {model_type} model")
|
||||
}
|
||||
}
|
||||
if list_client_names(config)
|
||||
.into_iter()
|
||||
.any(|v| *v == client_name)
|
||||
&& model_type.can_create_from_name()
|
||||
{
|
||||
let mut new_model = Self::new(client_name, model_name);
|
||||
new_model.data.model_type = model_type.to_string();
|
||||
return Ok(new_model);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
if let Some(found) = models
|
||||
.iter()
|
||||
.find(|v| v.client_name == client_name && v.model_type() == model_type)
|
||||
{
|
||||
return Ok((*found).clone());
|
||||
}
|
||||
}
|
||||
};
|
||||
bail!("Unknown {model_type} model '{model_id}'")
|
||||
}
|
||||
|
||||
pub fn id(&self) -> String {
|
||||
if self.data.name.is_empty() {
|
||||
self.client_name.to_string()
|
||||
} else {
|
||||
format!("{}:{}", self.client_name, self.data.name)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn client_name(&self) -> &str {
|
||||
&self.client_name
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.data.name
|
||||
}
|
||||
|
||||
pub fn real_name(&self) -> &str {
|
||||
self.data.real_name.as_deref().unwrap_or(&self.data.name)
|
||||
}
|
||||
|
||||
pub fn model_type(&self) -> ModelType {
|
||||
if self.data.model_type.starts_with("embed") {
|
||||
ModelType::Embedding
|
||||
} else if self.data.model_type.starts_with("rerank") {
|
||||
ModelType::Reranker
|
||||
} else {
|
||||
ModelType::Chat
|
||||
}
|
||||
}
|
||||
|
||||
pub fn data(&self) -> &ModelData {
|
||||
&self.data
|
||||
}
|
||||
|
||||
pub fn data_mut(&mut self) -> &mut ModelData {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
pub fn description(&self) -> String {
|
||||
match self.model_type() {
|
||||
ModelType::Chat => {
|
||||
let ModelData {
|
||||
max_input_tokens,
|
||||
max_output_tokens,
|
||||
input_price,
|
||||
output_price,
|
||||
supports_vision,
|
||||
supports_function_calling,
|
||||
..
|
||||
} = &self.data;
|
||||
let max_input_tokens = stringify_option_value(max_input_tokens);
|
||||
let max_output_tokens = stringify_option_value(max_output_tokens);
|
||||
let input_price = stringify_option_value(input_price);
|
||||
let output_price = stringify_option_value(output_price);
|
||||
let mut capabilities = vec![];
|
||||
if *supports_vision {
|
||||
capabilities.push('👁');
|
||||
};
|
||||
if *supports_function_calling {
|
||||
capabilities.push('⚒');
|
||||
};
|
||||
let capabilities: String = capabilities
|
||||
.into_iter()
|
||||
.map(|v| format!("{v} "))
|
||||
.collect::<Vec<String>>()
|
||||
.join("");
|
||||
format!(
|
||||
"{max_input_tokens:>8} / {max_output_tokens:>8} | {input_price:>6} / {output_price:>6} {capabilities:>6}"
|
||||
)
|
||||
}
|
||||
ModelType::Embedding => {
|
||||
let ModelData {
|
||||
input_price,
|
||||
max_tokens_per_chunk,
|
||||
max_batch_size,
|
||||
..
|
||||
} = &self.data;
|
||||
let max_tokens = stringify_option_value(max_tokens_per_chunk);
|
||||
let max_batch = stringify_option_value(max_batch_size);
|
||||
let price = stringify_option_value(input_price);
|
||||
format!("max-tokens:{max_tokens};max-batch:{max_batch};price:{price}")
|
||||
}
|
||||
ModelType::Reranker => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn patch(&self) -> Option<&Value> {
|
||||
self.data.patch.as_ref()
|
||||
}
|
||||
|
||||
pub fn max_input_tokens(&self) -> Option<usize> {
|
||||
self.data.max_input_tokens
|
||||
}
|
||||
|
||||
pub fn max_output_tokens(&self) -> Option<isize> {
|
||||
self.data.max_output_tokens
|
||||
}
|
||||
|
||||
pub fn no_stream(&self) -> bool {
|
||||
self.data.no_stream
|
||||
}
|
||||
|
||||
pub fn no_system_message(&self) -> bool {
|
||||
self.data.no_system_message
|
||||
}
|
||||
|
||||
pub fn system_prompt_prefix(&self) -> Option<&str> {
|
||||
self.data.system_prompt_prefix.as_deref()
|
||||
}
|
||||
|
||||
pub fn max_tokens_per_chunk(&self) -> Option<usize> {
|
||||
self.data.max_tokens_per_chunk
|
||||
}
|
||||
|
||||
pub fn default_chunk_size(&self) -> usize {
|
||||
self.data.default_chunk_size.unwrap_or(1000)
|
||||
}
|
||||
|
||||
pub fn max_batch_size(&self) -> Option<usize> {
|
||||
self.data.max_batch_size
|
||||
}
|
||||
|
||||
pub fn max_tokens_param(&self) -> Option<isize> {
|
||||
if self.data.require_max_tokens {
|
||||
self.data.max_output_tokens
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_max_tokens(
|
||||
&mut self,
|
||||
max_output_tokens: Option<isize>,
|
||||
require_max_tokens: bool,
|
||||
) -> &mut Self {
|
||||
match max_output_tokens {
|
||||
None | Some(0) => self.data.max_output_tokens = None,
|
||||
_ => self.data.max_output_tokens = max_output_tokens,
|
||||
}
|
||||
self.data.require_max_tokens = require_max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn messages_tokens(&self, messages: &[Message]) -> usize {
|
||||
let messages_len = messages.len();
|
||||
messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, v)| match &v.content {
|
||||
MessageContent::Text(text) => {
|
||||
if v.role.is_assistant() && i != messages_len - 1 {
|
||||
estimate_token_length(&strip_think_tag(text))
|
||||
} else {
|
||||
estimate_token_length(text)
|
||||
}
|
||||
}
|
||||
MessageContent::Array(list) => list
|
||||
.iter()
|
||||
.map(|v| match v {
|
||||
MessageContentPart::Text { text } => estimate_token_length(text),
|
||||
MessageContentPart::ImageUrl { .. } => 0,
|
||||
})
|
||||
.sum(),
|
||||
MessageContent::ToolCalls(MessageContentToolCalls {
|
||||
tool_results, text, ..
|
||||
}) => {
|
||||
estimate_token_length(text)
|
||||
+ tool_results
|
||||
.iter()
|
||||
.map(|v| {
|
||||
serde_json::to_string(v)
|
||||
.map(|v| estimate_token_length(&v))
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.sum::<usize>()
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
pub fn total_tokens(&self, messages: &[Message]) -> usize {
|
||||
if messages.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let num_messages = messages.len();
|
||||
let message_tokens = self.messages_tokens(messages);
|
||||
if messages[num_messages - 1].role.is_user() {
|
||||
num_messages * PER_MESSAGES_TOKENS + message_tokens
|
||||
} else {
|
||||
(num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
|
||||
}
|
||||
}
|
||||
|
||||
pub fn guard_max_input_tokens(&self, messages: &[Message]) -> Result<()> {
|
||||
let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
|
||||
if let Some(max_input_tokens) = self.data.max_input_tokens {
|
||||
if total_tokens >= max_input_tokens {
|
||||
bail!("Exceed max_input_tokens limit")
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ModelData {
|
||||
pub name: String,
|
||||
#[serde(default = "default_model_type", rename = "type")]
|
||||
pub model_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub real_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_input_tokens: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub input_price: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_price: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub patch: Option<Value>,
|
||||
|
||||
// chat-only properties
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_output_tokens: Option<isize>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub require_max_tokens: bool,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub supports_vision: bool,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub supports_function_calling: bool,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
no_stream: bool,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
no_system_message: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system_prompt_prefix: Option<String>,
|
||||
|
||||
// embedding-only properties
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens_per_chunk: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default_chunk_size: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl ModelData {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
model_type: default_model_type(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderModels {
|
||||
pub provider: String,
|
||||
pub models: Vec<ModelData>,
|
||||
}
|
||||
|
||||
fn default_model_type() -> String {
|
||||
"chat".into()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum ModelType {
|
||||
Chat,
|
||||
Embedding,
|
||||
Reranker,
|
||||
}
|
||||
|
||||
impl Display for ModelType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ModelType::Chat => write!(f, "chat"),
|
||||
ModelType::Embedding => write!(f, "embedding"),
|
||||
ModelType::Reranker => write!(f, "reranker"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelType {
|
||||
pub fn can_create_from_name(self) -> bool {
|
||||
match self {
|
||||
ModelType::Chat => true,
|
||||
ModelType::Embedding => false,
|
||||
ModelType::Reranker => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn api_name(self) -> &'static str {
|
||||
match self {
|
||||
ModelType::Chat => "chat_completions",
|
||||
ModelType::Embedding => "embeddings",
|
||||
ModelType::Reranker => "rerank",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_patch(self, patch: &RequestPatch) -> Option<&ApiPatch> {
|
||||
match self {
|
||||
ModelType::Chat => patch.chat_completions.as_ref(),
|
||||
ModelType::Embedding => patch.embeddings.as_ref(),
|
||||
ModelType::Reranker => patch.rerank.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn stringify_option_value<T>(value: &Option<T>) -> String
|
||||
where
|
||||
T: Display,
|
||||
{
|
||||
match value {
|
||||
Some(value) => value.to_string(),
|
||||
None => "-".to_string(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,408 @@
|
||||
use super::*;
|
||||
|
||||
use crate::utils::strip_think_tag;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
const API_BASE: &str = "https://api.openai.com/v1";
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct OpenAIConfig {
|
||||
pub name: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub api_base: Option<String>,
|
||||
pub organization_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patch: Option<RequestPatch>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
config_get_fn!(api_key, get_api_key);
|
||||
config_get_fn!(api_base, get_api_base);
|
||||
|
||||
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
|
||||
}
|
||||
|
||||
impl_client_trait!(
|
||||
OpenAIClient,
|
||||
(
|
||||
prepare_chat_completions,
|
||||
openai_chat_completions,
|
||||
openai_chat_completions_streaming
|
||||
),
|
||||
(prepare_embeddings, openai_embeddings),
|
||||
(noop_prepare_rerank, noop_rerank),
|
||||
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &OpenAIClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/chat/completions", api_base.trim_end_matches('/'));
|
||||
|
||||
let body = openai_build_chat_completions_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
if let Some(organization_id) = &self_.config.organization_id {
|
||||
request_data.header("OpenAI-Organization", organization_id);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(self_: &OpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{api_base}/embeddings");
|
||||
|
||||
let body = openai_build_embeddings_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
if let Some(organization_id) = &self_.config.organization_id {
|
||||
request_data.header("OpenAI-Organization", organization_id);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
pub async fn openai_chat_completions(
|
||||
builder: RequestBuilder,
|
||||
_model: &Model,
|
||||
) -> Result<ChatCompletionsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
debug!("non-stream-data: {data}");
|
||||
openai_extract_chat_completions(&data)
|
||||
}
|
||||
|
||||
/// Streams an SSE chat-completions response, forwarding text deltas (and
/// reasoning deltas wrapped in `<think>` tags) to `handler`, and assembling
/// incrementally-delivered tool calls before emitting them.
pub async fn openai_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    // Identity of the tool call currently being assembled ("{id}/{index}").
    let mut call_id = String::new();
    // Accumulators for the in-progress tool call.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    // 0 = not inside reasoning output, 1 = inside (a `<think>` tag is open).
    let mut reasoning_state = 0;
    let handle = |message: SseMessage| -> Result<bool> {
        if message.data == "[DONE]" {
            // Flush any tool call still being assembled before ending.
            if !function_name.is_empty() {
                if function_arguments.is_empty() {
                    function_arguments = String::from("{}");
                }
                let arguments: Value = function_arguments.parse().with_context(|| {
                    format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
                })?;
                handler.tool_call(ToolCall::new(
                    function_name.clone(),
                    arguments,
                    normalize_function_id(&function_id),
                ))?;
            }
            return Ok(true); // true stops the SSE loop
        }
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(text) = data["choices"][0]["delta"]["content"]
            .as_str()
            .filter(|v| !v.is_empty())
        {
            // Regular content closes any open reasoning section.
            if reasoning_state == 1 {
                handler.text("\n</think>\n\n")?;
                reasoning_state = 0;
            }
            handler.text(text)?;
        } else if let Some(text) = data["choices"][0]["delta"]["reasoning_content"]
            .as_str()
            .or_else(|| data["choices"][0]["delta"]["reasoning"].as_str())
            .filter(|v| !v.is_empty())
        {
            // Open a `<think>` section on the first reasoning delta.
            if reasoning_state == 0 {
                handler.text("<think>\n")?;
                reasoning_state = 1;
            }
            handler.text(text)?;
        }
        if let (Some(function), index, id) = (
            data["choices"][0]["delta"]["tool_calls"][0]["function"].as_object(),
            data["choices"][0]["delta"]["tool_calls"][0]["index"].as_u64(),
            data["choices"][0]["delta"]["tool_calls"][0]["id"]
                .as_str()
                .filter(|v| !v.is_empty()),
        ) {
            if reasoning_state == 1 {
                handler.text("\n</think>\n\n")?;
                reasoning_state = 0;
            }
            // A changed "{id}/{index}" marks the start of a new tool call.
            let maybe_call_id = format!("{}/{}", id.unwrap_or_default(), index.unwrap_or_default());
            if maybe_call_id != call_id && maybe_call_id.len() >= call_id.len() {
                // Emit the previous tool call (if any) before switching.
                if !function_name.is_empty() {
                    if function_arguments.is_empty() {
                        function_arguments = String::from("{}");
                    }
                    let arguments: Value = function_arguments.parse().with_context(|| {
                        format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                    })?;
                    handler.tool_call(ToolCall::new(
                        function_name.clone(),
                        arguments,
                        normalize_function_id(&function_id),
                    ))?;
                }
                function_name.clear();
                function_arguments.clear();
                function_id.clear();
                call_id = maybe_call_id;
            }
            if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
                // A name extending the current accumulation replaces it
                // (handles full resends); otherwise the fragment is appended.
                if name.starts_with(&function_name) {
                    function_name = name.to_string();
                } else {
                    function_name.push_str(name);
                }
            }
            if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) {
                // Argument deltas are raw string fragments; accumulate them.
                function_arguments.push_str(arguments);
            }
            if let Some(id) = id {
                function_id = id.to_string();
            }
        }
        Ok(false)
    };

    sse_stream(builder, handle).await
}
|
||||
|
||||
pub async fn openai_embeddings(
|
||||
builder: RequestBuilder,
|
||||
_model: &Model,
|
||||
) -> Result<EmbeddingsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
let res_body: EmbeddingsResBody =
|
||||
serde_json::from_value(data).context("Invalid embeddings data")?;
|
||||
let output = res_body.data.into_iter().map(|v| v.embedding).collect();
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Minimal shape of an OpenAI embeddings response (only fields we consume).
#[derive(Deserialize)]
struct EmbeddingsResBody {
    data: Vec<EmbeddingsResBodyEmbedding>,
}

/// A single embedding entry within the response's `data` array.
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
    embedding: Vec<f32>,
}
|
||||
|
||||
/// Builds the JSON body for an OpenAI-style chat-completions request,
/// converting internal `Message` values into the wire format and applying
/// model options (max tokens, temperature, top_p, stream, tools).
pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
    let ChatCompletionsData {
        messages,
        temperature,
        top_p,
        functions,
        stream,
    } = data;

    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results,
                    text: _,
                    sequence,
                }) => {
                    if !sequence {
                        // Non-sequential: one assistant message listing every
                        // tool call, followed by one "tool" message per result.
                        let tool_calls: Vec<_> = tool_results
                            .iter()
                            .map(|tool_result| {
                                json!({
                                    "id": tool_result.call.id,
                                    "type": "function",
                                    "function": {
                                        "name": tool_result.call.name,
                                        "arguments": tool_result.call.arguments.to_string(),
                                    },
                                })
                            })
                            .collect();
                        let mut messages = vec![
                            json!({ "role": MessageRole::Assistant, "tool_calls": tool_calls }),
                        ];
                        for tool_result in tool_results {
                            messages.push(json!({
                                "role": "tool",
                                "content": tool_result.output.to_string(),
                                "tool_call_id": tool_result.call.id,
                            }));
                        }
                        messages
                    } else {
                        // Sequential: alternating assistant(tool_call) /
                        // tool(result) pairs, one pair per tool result.
                        tool_results.into_iter().flat_map(|tool_result| {
                            vec![
                                json!({
                                    "role": MessageRole::Assistant,
                                    "tool_calls": [
                                        {
                                            "id": tool_result.call.id,
                                            "type": "function",
                                            "function": {
                                                "name": tool_result.call.name,
                                                "arguments": tool_result.call.arguments.to_string(),
                                            },
                                        }
                                    ]
                                }),
                                json!({
                                    "role": "tool",
                                    "content": tool_result.output.to_string(),
                                    "tool_call_id": tool_result.call.id,
                                })
                            ]
                        }).collect()
                    }
                }
                // Strip `<think>` sections from earlier assistant turns, but
                // leave the final message intact.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": strip_think_tag(&text) })]
                }
                _ => vec![json!({ "role": role, "content": content })],
            }
        })
        .collect();

    let mut body = json!({
        "model": &model.real_name(),
        "messages": messages,
    });

    if let Some(v) = model.max_tokens_param() {
        // A patch of `body.max_tokens: null` opts the model into the newer
        // `max_completion_tokens` field instead of `max_tokens`.
        if model
            .patch()
            .and_then(|v| v.get("body").and_then(|v| v.get("max_tokens")))
            == Some(&Value::Null)
        {
            body["max_completion_tokens"] = v.into();
        } else {
            body["max_tokens"] = v.into();
        }
    }
    if let Some(v) = temperature {
        body["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["top_p"] = v.into();
    }
    if stream {
        body["stream"] = true.into();
    }
    if let Some(functions) = functions {
        // Each function declaration is wrapped in the OpenAI "tools" envelope.
        body["tools"] = functions
            .iter()
            .map(|v| {
                json!({
                    "type": "function",
                    "function": v,
                })
            })
            .collect();
    }
    body
}
|
||||
|
||||
pub fn openai_build_embeddings_body(data: &EmbeddingsData, model: &Model) -> Value {
|
||||
json!({
|
||||
"input": data.texts,
|
||||
"model": model.real_name()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
let text = data["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
|
||||
let reasoning = data["choices"][0]["message"]["reasoning_content"]
|
||||
.as_str()
|
||||
.or_else(|| data["choices"][0]["message"]["reasoning"].as_str())
|
||||
.unwrap_or_default()
|
||||
.trim();
|
||||
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(calls) = data["choices"][0]["message"]["tool_calls"].as_array() {
|
||||
for call in calls {
|
||||
if let (Some(name), Some(arguments), Some(id)) = (
|
||||
call["function"]["name"].as_str(),
|
||||
call["function"]["arguments"].as_str(),
|
||||
call["id"].as_str(),
|
||||
) {
|
||||
let arguments: Value = arguments.parse().with_context(|| {
|
||||
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
|
||||
})?;
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
arguments,
|
||||
Some(id.to_string()),
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if text.is_empty() && tool_calls.is_empty() {
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
let text = if !reasoning.is_empty() {
|
||||
format!("<think>\n{reasoning}\n</think>\n\n{text}")
|
||||
} else {
|
||||
text.to_string()
|
||||
};
|
||||
let output = ChatCompletionsOutput {
|
||||
text,
|
||||
tool_calls,
|
||||
id: data["id"].as_str().map(|v| v.to_string()),
|
||||
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
|
||||
output_tokens: data["usage"]["completion_tokens"].as_u64(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Converts a possibly-empty tool-call id into an `Option`:
/// empty strings become `None`, anything else is wrapped in `Some`.
fn normalize_function_id(value: &str) -> Option<String> {
    (!value.is_empty()).then(|| value.to_string())
}
|
||||
@@ -0,0 +1,162 @@
|
||||
use super::openai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// User-facing configuration for a generic OpenAI-compatible provider.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAICompatibleConfig {
    /// Optional display/override name for this client instance.
    pub name: Option<String>,
    /// API base URL; may be omitted for providers in the built-in table
    /// (see `get_api_base_ext`).
    pub api_base: Option<String>,
    /// Optional API key; many self-hosted endpoints require none.
    pub api_key: Option<String>,
    /// Models declared for this client.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patch applied to outgoing payloads.
    pub patch: Option<RequestPatch>,
    /// Extra, client-agnostic settings.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl OpenAICompatibleClient {
    // Generate the `get_api_base` / `get_api_key` accessors from config fields.
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);

    // No interactive prompts: values come from config or the built-in
    // provider table (see `get_api_base_ext`).
    pub const PROMPTS: [PromptAction<'static>; 0] = [];
}
|
||||
|
||||
// Wire OpenAICompatibleClient into the common `Client` trait, reusing the
// OpenAI request/response handlers plus a generic rerank implementation.
impl_client_trait!(
    OpenAICompatibleClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (prepare_rerank, generic_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &OpenAICompatibleClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key().ok();
|
||||
let api_base = get_api_base_ext(self_)?;
|
||||
|
||||
let url = format!("{api_base}/chat/completions");
|
||||
|
||||
let body = openai_build_chat_completions_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
request_data.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(
|
||||
self_: &OpenAICompatibleClient,
|
||||
data: &EmbeddingsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key().ok();
|
||||
let api_base = get_api_base_ext(self_)?;
|
||||
|
||||
let url = format!("{api_base}/embeddings");
|
||||
|
||||
let body = openai_build_embeddings_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
request_data.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key().ok();
|
||||
let api_base = get_api_base_ext(self_)?;
|
||||
|
||||
let url = if self_.name().starts_with("ernie") {
|
||||
format!("{api_base}/rerankers")
|
||||
} else {
|
||||
format!("{api_base}/rerank")
|
||||
};
|
||||
|
||||
let body = generic_build_rerank_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
request_data.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn get_api_base_ext(self_: &OpenAICompatibleClient) -> Result<String> {
|
||||
let api_base = match self_.get_api_base() {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
match OPENAI_COMPATIBLE_PROVIDERS
|
||||
.into_iter()
|
||||
.find_map(|(name, api_base)| {
|
||||
if name == self_.model.client_name() {
|
||||
Some(api_base.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
Some(v) => v,
|
||||
None => return Err(err),
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(api_base.trim_end_matches('/').to_string())
|
||||
}
|
||||
|
||||
/// Sends a rerank request and parses the response.
///
/// Some providers return the result list under `data` instead of `results`;
/// the payload is normalized to `results` before deserializing.
pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let mut data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    // Normalize: move a top-level "data" value to "results" when "results"
    // is absent, so a single response shape covers all providers.
    if data.get("results").is_none() && data.get("data").is_some() {
        if let Some(data_obj) = data.as_object_mut() {
            if let Some(value) = data_obj.remove("data") {
                data_obj.insert("results".to_string(), value);
            }
        }
    }
    let res_body: GenericRerankResBody =
        serde_json::from_value(data).context("Invalid rerank data")?;
    Ok(res_body.results)
}
|
||||
|
||||
/// Wrapper over a rerank response: the provider's `results` array maps
/// directly onto `RerankOutput`.
#[derive(Deserialize)]
pub struct GenericRerankResBody {
    pub results: RerankOutput,
}
|
||||
|
||||
pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value {
|
||||
let RerankData {
|
||||
query,
|
||||
documents,
|
||||
top_n,
|
||||
} = data;
|
||||
|
||||
let mut body = json!({
|
||||
"model": model.real_name(),
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
});
|
||||
if model.client_name().starts_with("voyageai") {
|
||||
body["top_k"] = (*top_n).into()
|
||||
} else {
|
||||
body["top_n"] = (*top_n).into()
|
||||
}
|
||||
body
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
use super::{catch_error, ToolCall};
|
||||
use crate::utils::AbortSignal;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use futures_util::{Stream, StreamExt};
|
||||
use reqwest::RequestBuilder;
|
||||
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
|
||||
/// Collects a streamed chat reply: forwards text chunks to a channel for
/// live rendering while also accumulating the full text and any tool calls.
pub struct SseHandler {
    // Channel used to push events to the consumer/renderer.
    sender: UnboundedSender<SseEvent>,
    // Cooperative cancellation flag shared with the consumer.
    abort_signal: AbortSignal,
    // Full response text accumulated so far.
    buffer: String,
    // Tool calls emitted by the model during this stream.
    tool_calls: Vec<ToolCall>,
}
|
||||
|
||||
impl SseHandler {
    /// Creates a handler that forwards events to `sender` and observes
    /// `abort_signal` for cooperative cancellation.
    pub fn new(sender: UnboundedSender<SseEvent>, abort_signal: AbortSignal) -> Self {
        Self {
            sender,
            abort_signal,
            buffer: String::new(),
            tool_calls: Vec::new(),
        }
    }

    /// Appends a text chunk to the buffer and forwards it to the consumer.
    ///
    /// A send failure is swallowed when the stream was aborted (the receiver
    /// is gone on purpose); otherwise it is propagated.
    pub fn text(&mut self, text: &str) -> Result<()> {
        // debug!("HandleText: {}", text);
        if text.is_empty() {
            return Ok(());
        }
        self.buffer.push_str(text);
        let ret = self
            .sender
            .send(SseEvent::Text(text.to_string()))
            .with_context(|| "Failed to send SseEvent:Text");
        if let Err(err) = ret {
            if self.abort_signal.aborted() {
                return Ok(());
            }
            return Err(err);
        }
        Ok(())
    }

    /// Signals end-of-stream to the consumer. Failure only warns: there is
    /// nothing left to deliver, so there is no caller to propagate to.
    pub fn done(&mut self) {
        // debug!("HandleDone");
        let ret = self.sender.send(SseEvent::Done);
        if ret.is_err() {
            if self.abort_signal.aborted() {
                return;
            }
            warn!("Failed to send SseEvent:Done");
        }
    }

    /// Records a completed tool call (not forwarded over the channel).
    pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
        // debug!("HandleCall: {:?}", call);
        self.tool_calls.push(call);
        Ok(())
    }

    /// Returns a clone of the abort signal.
    pub fn abort(&self) -> AbortSignal {
        self.abort_signal.clone()
    }

    /// Tool calls collected so far.
    pub fn tool_calls(&self) -> &[ToolCall] {
        &self.tool_calls
    }

    /// Consumes the handler, yielding the accumulated text and tool calls.
    pub fn take(self) -> (String, Vec<ToolCall>) {
        let Self {
            buffer, tool_calls, ..
        } = self;
        (buffer, tool_calls)
    }
}
|
||||
|
||||
/// Events delivered from the stream handler to the consumer.
#[derive(Debug)]
pub enum SseEvent {
    /// A chunk of response text.
    Text(String),
    /// The stream has finished.
    Done,
}
|
||||
|
||||
/// A raw server-sent-event message (event name + data payload).
#[derive(Debug)]
pub struct SseMessage {
    // Kept for completeness; current handlers only inspect `data`.
    #[allow(unused)]
    pub event: String,
    pub data: String,
}
|
||||
|
||||
/// Drives an SSE request, invoking `handle` for every message until the
/// stream ends, `handle` returns `Ok(true)`, or an error occurs.
///
/// # Errors
/// Propagates handler errors, invalid status codes (after surfacing the
/// provider's error payload), invalid content types, and transport failures.
pub async fn sse_stream<F>(builder: RequestBuilder, mut handle: F) -> Result<()>
where
    F: FnMut(SseMessage) -> Result<bool>,
{
    let mut es = builder.eventsource()?;
    while let Some(event) = es.next().await {
        match event {
            Ok(Event::Open) => {}
            Ok(Event::Message(message)) => {
                let message = SseMessage {
                    event: message.event,
                    data: message.data,
                };
                // `Ok(true)` means the handler saw a terminal message.
                if handle(message)? {
                    break;
                }
            }
            Err(err) => {
                match err {
                    // Normal termination; not treated as an error.
                    EventSourceError::StreamEnded => {}
                    EventSourceError::InvalidStatusCode(status, res) => {
                        // Prefer a structured provider error; fall back to
                        // the raw response text when it is not JSON.
                        let text = res.text().await?;
                        let data: Value = match text.parse() {
                            Ok(data) => data,
                            Err(_) => {
                                bail!(
                                    "Invalid response data: {text} (status: {})",
                                    status.as_u16()
                                );
                            }
                        };
                        catch_error(&data, status.as_u16())?;
                    }
                    EventSourceError::InvalidContentType(header_value, res) => {
                        let text = res.text().await?;
                        bail!(
                            "Invalid response event-stream. content-type: {}, data: {text}",
                            header_value.to_str().unwrap_or_default()
                        );
                    }
                    _ => {
                        bail!("{}", err);
                    }
                }
                // Any Err variant (including StreamEnded) closes the source.
                es.close();
            }
        }
    }
    Ok(())
}
|
||||
|
||||
/// Consumes a byte stream of concatenated / newline-delimited /
/// array-wrapped JSON objects, invoking `handle` with the text of each
/// complete top-level object.
///
/// Bytes are buffered until they form valid UTF-8, so multi-byte characters
/// split across chunk boundaries are handled correctly.
pub async fn json_stream<S, F, E>(mut stream: S, mut handle: F) -> Result<()>
where
    S: Stream<Item = Result<bytes::Bytes, E>> + Unpin,
    F: FnMut(&str) -> Result<()>,
    E: std::error::Error,
{
    let mut parser = JsonStreamParser::default();
    // Bytes that do not yet form a valid UTF-8 sequence.
    let mut unparsed_bytes = vec![];
    while let Some(chunk_bytes) = stream.next().await {
        let chunk_bytes =
            chunk_bytes.map_err(|err| anyhow!("Failed to read json stream, {err}"))?;
        unparsed_bytes.extend(chunk_bytes);
        match std::str::from_utf8(&unparsed_bytes) {
            Ok(text) => {
                parser.process(text, &mut handle)?;
                unparsed_bytes.clear();
            }
            Err(_) => {
                // Likely a multi-byte character split across chunks; wait
                // for more data before retrying.
                continue;
            }
        }
    }
    // Flush trailing bytes; a UTF-8 error here is a genuine error.
    if !unparsed_bytes.is_empty() {
        let text = std::str::from_utf8(&unparsed_bytes)?;
        parser.process(text, &mut handle)?;
    }

    Ok(())
}
|
||||
|
||||
/// Incremental scanner that locates complete top-level JSON objects in a
/// growing character buffer (used by `json_stream`).
#[derive(Debug, Default)]
struct JsonStreamParser {
    // All characters seen so far.
    buffer: Vec<char>,
    // Index of the first character not yet scanned.
    cursor: usize,
    // Index of the '{' that opened the current top-level object, if any.
    start: Option<usize>,
    // Stack of currently open '{' / '[' delimiters.
    balances: Vec<char>,
    // True while inside a string literal.
    quoting: bool,
    // True when the previous character was an unconsumed backslash escape.
    escape: bool,
}
|
||||
|
||||
impl JsonStreamParser {
    /// Appends `text` to the buffer and scans any new characters, invoking
    /// `handle` with the full text of every top-level `{...}` object that is
    /// completed by this chunk.
    fn process<F>(&mut self, text: &str, handle: &mut F) -> Result<()>
    where
        F: FnMut(&str) -> Result<()>,
    {
        self.buffer.extend(text.chars());

        for i in self.cursor..self.buffer.len() {
            let ch = self.buffer[i];
            if self.quoting {
                // Inside a string literal: only an unescaped '"' ends it.
                if ch == '\\' {
                    // Toggling correctly handles runs of backslashes.
                    self.escape = !self.escape;
                } else {
                    if !self.escape && ch == '"' {
                        self.quoting = false;
                    }
                    self.escape = false;
                }
                continue;
            }
            match ch {
                '"' => {
                    self.quoting = true;
                    self.escape = false;
                }
                '{' => {
                    // A '{' at depth zero starts a new top-level object.
                    if self.balances.is_empty() {
                        self.start = Some(i);
                    }
                    self.balances.push(ch);
                }
                '[' => {
                    // Arrays are tracked only inside an object; a top-level
                    // '[' (array wrapper around the objects) is ignored.
                    if self.start.is_some() {
                        self.balances.push(ch);
                    }
                }
                '}' => {
                    self.balances.pop();
                    if self.balances.is_empty() {
                        // Depth back to zero: a top-level object just closed.
                        if let Some(start) = self.start.take() {
                            let value: String = self.buffer[start..=i].iter().collect();
                            handle(&value)?;
                        }
                    }
                }
                ']' => {
                    self.balances.pop();
                }
                _ => {}
            }
        }
        self.cursor = self.buffer.len();
        Ok(())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use futures_util::stream;
    use rand::Rng;

    /// Splits `text` into three chunks at random byte offsets, exercising
    /// arbitrary (possibly mid-token) chunk boundaries in `json_stream`.
    fn split_chunks(text: &str) -> Vec<Vec<u8>> {
        let mut rng = rand::rng();
        let len = text.len();
        let cut1 = rng.random_range(1..len - 1);
        let cut2 = rng.random_range(cut1 + 1..len);
        let chunk1 = text.as_bytes()[..cut1].to_vec();
        let chunk2 = text.as_bytes()[cut1..cut2].to_vec();
        let chunk3 = text.as_bytes()[cut2..].to_vec();
        vec![chunk1, chunk2, chunk3]
    }

    /// Feeds `$input` (randomly chunked) through `json_stream` and asserts
    /// that the extracted objects, joined by newlines, equal `$output`.
    macro_rules! assert_json_stream {
        ($input:expr, $output:expr) => {
            let chunks: Vec<_> = split_chunks($input)
                .into_iter()
                .map(|chunk| Ok::<_, std::convert::Infallible>(Bytes::from(chunk)))
                .collect();
            let stream = stream::iter(chunks);
            let mut output = vec![];
            let ret = json_stream(stream, |data| {
                output.push(data.to_string());
                Ok(())
            })
            .await;
            assert!(ret.is_ok());
            assert_eq!($output.replace("\r\n", "\n"), output.join("\n"))
        };
    }

    // Newline-delimited JSON objects pass through unchanged.
    #[tokio::test]
    async fn test_json_stream_ndjson() {
        let data = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
        assert_json_stream!(data, data);
    }

    // Objects inside a (possibly unterminated) array are extracted one by one.
    #[tokio::test]
    async fn test_json_stream_array() {
        let input = r#"[
{"key": "value"},
{"key": "value2"},
{"key": "value3"},"#;
        let output = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
        assert_json_stream!(input, output);
    }
}
|
||||
@@ -0,0 +1,537 @@
|
||||
use super::access_token::*;
|
||||
use super::claude::*;
|
||||
use super::openai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use chrono::{Duration, Utc};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::{path::PathBuf, str::FromStr};
|
||||
|
||||
/// User-facing configuration for the Google Vertex AI client.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
    /// Optional display/override name for this client instance.
    pub name: Option<String>,
    /// GCP project id used when building endpoint URLs.
    pub project_id: Option<String>,
    /// GCP location/region ("global" selects the location-less endpoint).
    pub location: Option<String>,
    /// Path to an Application Default Credentials file, passed to
    /// `prepare_gcloud_access_token`.
    pub adc_file: Option<String>,
    /// Models declared for this client.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patch applied to outgoing payloads.
    pub patch: Option<RequestPatch>,
    /// Extra, client-agnostic settings.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl VertexAIClient {
    // Generate accessors for the two required config fields.
    config_get_fn!(project_id, get_project_id);
    config_get_fn!(location, get_location);

    // Interactive setup asks for the project id and location.
    pub const PROMPTS: [PromptAction<'static>; 2] = [
        ("project_id", "Project ID", None),
        ("location", "Location", None),
    ];
}
|
||||
|
||||
#[async_trait::async_trait]
impl Client for VertexAIClient {
    client_common_fns!();

    /// Non-streaming chat: refreshes the gcloud access token, then dispatches
    /// to the Gemini/Claude/Mistral handler based on the model name.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        // The publisher category is inferred from the model name.
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => gemini_chat_completions(builder, model).await,
            ModelCategory::Claude => claude_chat_completions(builder, model).await,
            // Mistral models on Vertex are handled via the OpenAI wire format.
            ModelCategory::Mistral => openai_chat_completions(builder, model).await,
        }
    }

    /// Streaming chat: same token refresh and dispatch as
    /// `chat_completions_inner`, but with a streaming handler.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => {
                gemini_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Claude => {
                claude_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Mistral => {
                openai_chat_completions_streaming(builder, handler, model).await
            }
        }
    }

    /// Embeddings via the `:predict` endpoint (Google publisher models).
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<Vec<Vec<f32>>> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let request_data = prepare_embeddings(self, data)?;
        let builder = self.request_builder(client, request_data);
        embeddings(builder, self.model()).await
    }
}
|
||||
|
||||
/// Builds the Vertex AI chat request for the resolved publisher/model.
///
/// The URL shape depends on the publisher (google/anthropic/mistralai) and,
/// for Gemini/Mistral, on whether streaming is requested.
fn prepare_chat_completions(
    self_: &VertexAIClient,
    data: ChatCompletionsData,
    model_category: &ModelCategory,
) -> Result<RequestData> {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    // Token must have been minted earlier via `prepare_gcloud_access_token`.
    let access_token = get_access_token(self_.name())?;

    // "global" uses the location-less host; regional locations embed the
    // region in the hostname.
    let base_url = if location == "global" {
        format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
    } else {
        format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
    };

    let model_name = self_.model.real_name();

    let url = match model_category {
        ModelCategory::Gemini => {
            let func = match data.stream {
                true => "streamGenerateContent",
                false => "generateContent",
            };
            format!("{base_url}/google/models/{model_name}:{func}")
        }
        ModelCategory::Claude => {
            // Claude on Vertex is always called via streamRawPredict here.
            format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
        }
        ModelCategory::Mistral => {
            let func = match data.stream {
                true => "streamRawPredict",
                false => "rawPredict",
            };
            format!("{base_url}/mistralai/models/{model_name}:{func}")
        }
    };

    let body = match model_category {
        ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self_.model)?,
        ModelCategory::Claude => {
            let mut body = claude_build_chat_completions_body(data, &self_.model)?;
            // The model is addressed via the URL, so the body's "model"
            // field is removed; Vertex requires an anthropic_version instead.
            if let Some(body_obj) = body.as_object_mut() {
                body_obj.remove("model");
            }
            body["anthropic_version"] = "vertex-2023-10-16".into();
            body
        }
        ModelCategory::Mistral => {
            let mut body = openai_build_chat_completions_body(data, &self_.model);
            // Replace the body's model with the version-stripped name
            // (see `strip_model_version`).
            if let Some(body_obj) = body.as_object_mut() {
                body_obj["model"] = strip_model_version(self_.model.real_name()).into();
            }
            body
        }
    };

    let mut request_data = RequestData::new(url, body);

    request_data.bearer_auth(access_token);

    Ok(request_data)
}
|
||||
|
||||
/// Builds the Vertex AI embeddings request (`:predict` on a Google model).
fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<RequestData> {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    let access_token = get_access_token(self_.name())?;

    // NOTE(review): base-URL construction duplicates prepare_chat_completions;
    // a shared helper would keep the two in sync.
    let base_url = if location == "global" {
        format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
    } else {
        format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
    };
    let url = format!(
        "{base_url}/google/models/{}:predict",
        self_.model.real_name()
    );

    // One instance per input text, matching the predict API's request shape.
    let instances: Vec<_> = data.texts.iter().map(|v| json!({"content": v})).collect();

    let body = json!({
        "instances": instances,
    });

    let mut request_data = RequestData::new(url, body);

    request_data.bearer_auth(access_token);

    Ok(request_data)
}
|
||||
|
||||
/// Sends a non-streaming Gemini `generateContent` request and parses the reply.
pub async fn gemini_chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<ChatCompletionsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    gemini_extract_chat_completions_text(&data)
}
|
||||
|
||||
/// Streams a Gemini `streamGenerateContent` response (JSON objects delivered
/// incrementally), forwarding text parts and tool calls to `handler`.
pub async fn gemini_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if !status.is_success() {
        let data: Value = res.json().await?;
        catch_error(&data, status.as_u16())?;
    } else {
        // Invoked once per complete JSON object extracted from the byte stream.
        let handle = |value: &str| -> Result<()> {
            let data: Value = serde_json::from_str(value)?;
            debug!("stream-data: {data}");
            if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
                for (i, part) in parts.iter().enumerate() {
                    if let Some(text) = part["text"].as_str() {
                        // Separate multiple text parts with a blank line.
                        if i > 0 {
                            handler.text("\n\n")?;
                        }
                        handler.text(text)?;
                    } else if let (Some(name), Some(args)) = (
                        part["functionCall"]["name"].as_str(),
                        part["functionCall"]["args"].as_object(),
                    ) {
                        // Gemini function calls carry structured args and no id.
                        handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?;
                    }
                }
            } else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
                .as_str()
                .or_else(|| data["candidates"][0]["finishReason"].as_str())
            {
                bail!("Blocked due to safety")
            }

            Ok(())
        };
        json_stream(res.bytes_stream(), handle).await?;
    }
    Ok(())
}
|
||||
|
||||
/// Parses a Vertex AI `:predict` embeddings response into raw vectors.
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let res_body: EmbeddingsResBody =
        serde_json::from_value(data).context("Invalid embeddings data")?;
    // Flatten predictions down to just their embedding vectors.
    let output = res_body
        .predictions
        .into_iter()
        .map(|v| v.embeddings.values)
        .collect();
    Ok(output)
}
|
||||
|
||||
/// Minimal shape of a Vertex AI `:predict` embeddings response.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    predictions: Vec<EmbeddingsResBodyPrediction>,
}

/// One prediction entry (per input text).
#[derive(Deserialize)]
struct EmbeddingsResBodyPrediction {
    embeddings: EmbeddingsResBodyPredictionEmbeddings,
}

/// The embedding vector for a single prediction.
#[derive(Deserialize)]
struct EmbeddingsResBodyPredictionEmbeddings {
    values: Vec<f32>,
}
|
||||
|
||||
fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
let mut text_parts = vec![];
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
|
||||
for part in parts {
|
||||
if let Some(text) = part["text"].as_str() {
|
||||
text_parts.push(text);
|
||||
}
|
||||
if let (Some(name), Some(args)) = (
|
||||
part["functionCall"]["name"].as_str(),
|
||||
part["functionCall"]["args"].as_object(),
|
||||
) {
|
||||
tool_calls.push(ToolCall::new(name.to_string(), json!(args), None));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let text = text_parts.join("\n\n");
|
||||
if text.is_empty() && tool_calls.is_empty() {
|
||||
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
|
||||
.as_str()
|
||||
.or_else(|| data["candidates"][0]["finishReason"].as_str())
|
||||
{
|
||||
bail!("Blocked due to safety")
|
||||
} else {
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
}
|
||||
let output = ChatCompletionsOutput {
|
||||
text,
|
||||
tool_calls,
|
||||
id: None,
|
||||
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
|
||||
output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Builds the JSON request body for the Gemini chat-completions API from
/// generic `ChatCompletionsData`.
///
/// Messages become `contents` entries (roles mapped to "user"/"model"),
/// the system message goes to `systemInstruction`, sampling options go under
/// `generationConfig`, and function declarations under `tools`.
///
/// # Errors
/// Fails if any message references a network image URL (only inline base64
/// images are supported here).
pub fn gemini_build_chat_completions_body(
    data: ChatCompletionsData,
    model: &Model,
) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream: _, // streaming is selected via the URL, not the body
    } = data;

    let system_message = extract_system_message(&mut messages);

    // Collects any non-inline image URLs so they can be reported as one error.
    let mut network_image_urls = vec![];
    let contents: Vec<Value> = messages
        .into_iter()
        .flat_map(|message| {
            let Message { role, content } = message;
            // Gemini only knows "user" and "model" roles.
            let role = match role {
                MessageRole::User => "user",
                _ => "model",
            };
            match content {
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "parts": [{ "text": text }]
                })],
                MessageContent::Array(list) => {
                    let parts: Vec<Value> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => json!({"text": text}),
                            MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => {
                                // Only data URLs ("data:<mime>;base64,<data>") can be inlined.
                                if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
                                    json!({ "inline_data": { "mime_type": mime_type, "data": data } })
                                } else {
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            },
                        })
                        .collect();
                    vec![json!({ "role": role, "parts": parts })]
                },
                MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. }) => {
                    // Each tool round-trip expands into two contents entries:
                    // the model's functionCall parts, then the functionResponse parts.
                    let model_parts: Vec<Value> = tool_results.iter().map(|tool_result| {
                        json!({
                            "functionCall": {
                                "name": tool_result.call.name,
                                "args": tool_result.call.arguments,
                            }
                        })
                    }).collect();
                    let function_parts: Vec<Value> = tool_results.into_iter().map(|tool_result| {
                        json!({
                            "functionResponse": {
                                "name": tool_result.call.name,
                                "response": {
                                    "name": tool_result.call.name,
                                    "content": tool_result.output,
                                }
                            }
                        })
                    }).collect();
                    vec![
                        json!({ "role": "model", "parts": model_parts }),
                        json!({ "role": "function", "parts": function_parts }),
                    ]
                }
            }
        })
        .collect();

    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }

    let mut body = json!({ "contents": contents, "generationConfig": {} });

    if let Some(v) = system_message {
        body["systemInstruction"] = json!({ "parts": [{"text": v }] });
    }

    if let Some(v) = model.max_tokens_param() {
        body["generationConfig"]["maxOutputTokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["generationConfig"]["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["generationConfig"]["topP"] = v.into();
    }

    if let Some(functions) = functions {
        // Gemini doesn't support functions with parameters that have empty properties, so we need to patch it.
        let function_declarations: Vec<_> = functions
            .into_iter()
            .map(|function| {
                if function.parameters.is_empty_properties() {
                    json!({
                        "name": function.name,
                        "description": function.description,
                    })
                } else {
                    json!(function)
                }
            })
            .collect();
        body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
    }

    Ok(body)
}
|
||||
|
||||
/// Model families served through Vertex AI, selected by model-name prefix
/// (see the `FromStr` impl below).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
    // "gemini*" models.
    Gemini,
    // "claude*" models.
    Claude,
    // "mistral*" / "codestral*" models.
    Mistral,
}
|
||||
|
||||
impl FromStr for ModelCategory {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
if s.starts_with("gemini") {
|
||||
Ok(ModelCategory::Gemini)
|
||||
} else if s.starts_with("claude") {
|
||||
Ok(ModelCategory::Claude)
|
||||
} else if s.starts_with("mistral") || s.starts_with("codestral") {
|
||||
Ok(ModelCategory::Mistral)
|
||||
} else {
|
||||
unsupported_model!(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prepare_gcloud_access_token(
|
||||
client: &reqwest::Client,
|
||||
client_name: &str,
|
||||
adc_file: &Option<String>,
|
||||
) -> Result<()> {
|
||||
if !is_valid_access_token(client_name) {
|
||||
let (token, expires_in) = fetch_access_token(client, adc_file)
|
||||
.await
|
||||
.with_context(|| "Failed to fetch access token")?;
|
||||
let expires_at = Utc::now()
|
||||
+ Duration::try_seconds(expires_in)
|
||||
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
|
||||
set_access_token(client_name, token, expires_at.timestamp())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_access_token(
|
||||
client: &reqwest::Client,
|
||||
file: &Option<String>,
|
||||
) -> Result<(String, i64)> {
|
||||
let credentials = load_adc(file).await?;
|
||||
let value: Value = client
|
||||
.post("https://oauth2.googleapis.com/token")
|
||||
.json(&credentials)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if let (Some(access_token), Some(expires_in)) =
|
||||
(value["access_token"].as_str(), value["expires_in"].as_i64())
|
||||
{
|
||||
Ok((access_token.to_string(), expires_in))
|
||||
} else if let Some(err_msg) = value["error_description"].as_str() {
|
||||
bail!("{err_msg}")
|
||||
} else {
|
||||
bail!("Invalid response data: {value}")
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_adc(file: &Option<String>) -> Result<Value> {
|
||||
let adc_file = file
|
||||
.as_ref()
|
||||
.map(PathBuf::from)
|
||||
.or_else(default_adc_file)
|
||||
.ok_or_else(|| anyhow!("No application_default_credentials.json"))?;
|
||||
let data = tokio::fs::read_to_string(adc_file).await?;
|
||||
let data: Value = serde_json::from_str(&data)?;
|
||||
if let (Some(client_id), Some(client_secret), Some(refresh_token)) = (
|
||||
data["client_id"].as_str(),
|
||||
data["client_secret"].as_str(),
|
||||
data["refresh_token"].as_str(),
|
||||
) {
|
||||
Ok(json!({
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
}))
|
||||
} else {
|
||||
bail!("Invalid application_default_credentials.json")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
fn default_adc_file() -> Option<PathBuf> {
|
||||
let mut path = dirs::home_dir()?;
|
||||
path.push(".config");
|
||||
path.push("gcloud");
|
||||
path.push("application_default_credentials.json");
|
||||
Some(path)
|
||||
}
|
||||
|
||||
/// Default ADC location on Windows (under the user config directory):
/// `%APPDATA%\gcloud\application_default_credentials.json`.
/// Returns `None` when the config directory cannot be determined.
#[cfg(windows)]
fn default_adc_file() -> Option<PathBuf> {
    Some(
        dirs::config_dir()?
            .join("gcloud")
            .join("application_default_credentials.json"),
    )
}
|
||||
|
||||
/// Drops a pinned version suffix from a model id, e.g. "gemini-pro@001" ->
/// "gemini-pro"; names without '@' are returned unchanged.
fn strip_model_version(name: &str) -> &str {
    name.split_once('@').map_or(name, |(base, _version)| base)
}
|
||||
Reference in New Issue
Block a user