Baseline project

2025-10-07 10:45:42 -06:00
parent 88288a98b6
commit 650dbd92e0
54 changed files with 18982 additions and 0 deletions
@@ -0,0 +1,32 @@
use anyhow::{anyhow, Result};
use chrono::Utc;
use indexmap::IndexMap;
use parking_lot::RwLock;
use std::sync::LazyLock;
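// Process-wide cache of access tokens, keyed by client name.
// Each entry stores the token string and its expiry as a Unix timestamp in seconds.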
static ACCESS_TOKENS: LazyLock<RwLock<IndexMap<String, (String, i64)>>> =
LazyLock::new(|| RwLock::new(IndexMap::new()));
pub fn get_access_token(client_name: &str) -> Result<String> {
ACCESS_TOKENS
.read()
.get(client_name)
.map(|(token, _)| token.clone())
.ok_or_else(|| anyhow!("Invalid access token"))
}
pub fn is_valid_access_token(client_name: &str) -> bool {
let access_tokens = ACCESS_TOKENS.read();
let (token, expires_at) = match access_tokens.get(client_name) {
Some(v) => v,
None => return false,
};
!token.is_empty() && Utc::now().timestamp() < *expires_at
}
pub fn set_access_token(client_name: &str, token: String, expires_at: i64) {
let mut access_tokens = ACCESS_TOKENS.write();
let entry = access_tokens.entry(client_name.to_string()).or_default();
entry.0 = token;
entry.1 = expires_at;
}
@@ -0,0 +1,82 @@
use super::openai::*;
use super::*;
use anyhow::Result;
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl AzureOpenAIClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptAction<'static>; 2] = [
(
"api_base",
"API Base",
Some("e.g. https://{RESOURCE}.openai.azure.com"),
),
("api_key", "API Key", None),
];
}
impl_client_trait!(
AzureOpenAIClient,
(
prepare_chat_completions,
openai_chat_completions,
openai_chat_completions_streaming
),
(prepare_embeddings, openai_embeddings),
(noop_prepare_rerank, noop_rerank),
);
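// Builds the Azure OpenAI chat-completions request: the deployment segment of the URL comes
// from the model's real name, and authentication uses the `api-key` header instead of a bearer token.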
fn prepare_chat_completions(
self_: &AzureOpenAIClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_base = self_.get_api_base()?;
let api_key = self_.get_api_key()?;
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2024-12-01-preview",
&api_base,
self_.model.real_name()
);
let body = openai_build_chat_completions_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
request_data.header("api-key", api_key);
Ok(request_data)
}
fn prepare_embeddings(self_: &AzureOpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_base = self_.get_api_base()?;
let api_key = self_.get_api_key()?;
let url = format!(
"{}/openai/deployments/{}/embeddings?api-version=2024-10-21",
&api_base,
self_.model.real_name()
);
let body = openai_build_embeddings_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
request_data.header("api-key", api_key);
Ok(request_data)
}
@@ -0,0 +1,643 @@
use super::*;
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256, strip_think_tag};
use anyhow::{bail, Context, Result};
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use bytes::BytesMut;
use chrono::{DateTime, Utc};
use futures_util::StreamExt;
use indexmap::IndexMap;
use reqwest::{Client as ReqwestClient, Method, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
#[derive(Debug, Clone, Deserialize)]
pub struct BedrockConfig {
pub name: Option<String>,
pub access_key_id: Option<String>,
pub secret_access_key: Option<String>,
pub region: Option<String>,
pub session_token: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl BedrockClient {
config_get_fn!(access_key_id, get_access_key_id);
config_get_fn!(secret_access_key, get_secret_access_key);
config_get_fn!(region, get_region);
config_get_fn!(session_token, get_session_token);
pub const PROMPTS: [PromptAction<'static>; 3] = [
("access_key_id", "AWS Access Key ID", None),
("secret_access_key", "AWS Secret Access Key", None),
("region", "AWS Region", None),
];
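// Builds a SigV4-signed request against the Bedrock Converse API, choosing the
// `converse-stream` or `converse` endpoint depending on whether streaming was requested.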
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let access_key_id = self.get_access_key_id()?;
let secret_access_key = self.get_secret_access_key()?;
let region = self.get_region()?;
let session_token = self.get_session_token().ok();
let host = format!("bedrock-runtime.{region}.amazonaws.com");
let model_name = &self.model.real_name();
let uri = if data.stream {
format!("/model/{model_name}/converse-stream")
} else {
format!("/model/{model_name}/converse")
};
let body = build_chat_completions_body(data, &self.model)?;
let mut request_data = RequestData::new("", body);
self.patch_request_data(&mut request_data);
let RequestData {
url: _,
headers,
body,
} = request_data;
let builder = aws_fetch(
client,
&AwsCredentials {
access_key_id,
secret_access_key,
region,
session_token,
},
AwsRequest {
method: Method::POST,
host,
service: "bedrock".into(),
uri,
querystring: "".into(),
headers,
body: body.to_string(),
},
)?;
Ok(builder)
}
fn embeddings_builder(
&self,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<RequestBuilder> {
let access_key_id = self.get_access_key_id()?;
let secret_access_key = self.get_secret_access_key()?;
let region = self.get_region()?;
let session_token = self.get_session_token().ok();
let host = format!("bedrock-runtime.{region}.amazonaws.com");
let uri = format!("/model/{}/invoke", self.model.real_name());
let input_type = match data.query {
true => "search_query",
false => "search_document",
};
let body = json!({
"texts": data.texts,
"input_type": input_type,
});
let mut request_data = RequestData::new("", body);
self.patch_request_data(&mut request_data);
let RequestData {
url: _,
headers,
body,
} = request_data;
let builder = aws_fetch(
client,
&AwsCredentials {
access_key_id,
secret_access_key,
region,
session_token,
},
AwsRequest {
method: Method::POST,
host,
service: "bedrock".into(),
uri,
querystring: "".into(),
headers,
body: body.to_string(),
},
)?;
Ok(builder)
}
}
#[async_trait::async_trait]
impl Client for BedrockClient {
client_common_fns!();
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let builder = self.chat_completions_builder(client, data)?;
chat_completions(builder).await
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let builder = self.chat_completions_builder(client, data)?;
chat_completions_streaming(builder, handler).await
}
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
let builder = self.embeddings_builder(client, data)?;
embeddings(builder).await
}
}
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
extract_chat_completions(&data)
}
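// Streaming responses arrive as an AWS event stream; frames are decoded incrementally and
// dispatched by smithy type (contentBlockStart/Delta/Stop). Reasoning text is wrapped in
// <think> tags and tool-call arguments are accumulated until the content block ends.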
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if !status.is_success() {
let data: Value = res.json().await?;
catch_error(&data, status.as_u16())?;
bail!("Invalid response data: {data}");
}
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let mut reasoning_state = 0;
let mut stream = res.bytes_stream();
let mut buffer = BytesMut::new();
let mut decoder = MessageFrameDecoder::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
buffer.extend_from_slice(&chunk);
while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? {
let response_headers = parse_response_headers(&message)?;
let message_type = response_headers.message_type.as_str();
let smithy_type = response_headers.smithy_type.as_str();
match (message_type, smithy_type) {
("event", _) => {
let data: Value = serde_json::from_slice(message.payload())?;
debug!("stream-data: {smithy_type} {data}");
match smithy_type {
"contentBlockStart" => {
if let Some(tool_use) = data["start"]["toolUse"].as_object() {
if let (Some(id), Some(name)) = (
json_str_from_map(tool_use, "toolUseId"),
json_str_from_map(tool_use, "name"),
) {
if !function_name.is_empty() {
if function_arguments.is_empty() {
function_arguments = String::from("{}");
}
let arguments: Value =
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_arguments.clear();
function_name = name.into();
function_id = id.into();
}
}
}
"contentBlockDelta" => {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
} else if let Some(text) =
data["delta"]["reasoningContent"]["text"].as_str()
{
if reasoning_state == 0 {
handler.text("<think>\n")?;
reasoning_state = 1;
}
handler.text(text)?;
} else if let Some(input) = data["delta"]["toolUse"]["input"].as_str() {
function_arguments.push_str(input);
}
}
"contentBlockStop" => {
if reasoning_state == 1 {
handler.text("\n</think>\n\n")?;
reasoning_state = 0;
}
if !function_name.is_empty() {
if function_arguments.is_empty() {
function_arguments = String::from("{}");
}
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
}
_ => {}
}
}
("exception", _) => {
let payload = base64_decode(message.payload())?;
let data = String::from_utf8_lossy(&payload);
bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
}
_ => {
bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
}
}
}
}
Ok(())
}
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
Ok(res_body.embeddings)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
embeddings: Vec<Vec<f32>>,
}
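// Converts the internal message list into a Bedrock Converse request body: the system message
// is lifted into `system`, tool calls/results become assistant/user turn pairs, and sampling
// parameters are placed under `inferenceConfig`.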
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
mut messages,
temperature,
top_p,
functions,
stream: _,
} = data;
let system_message = extract_system_message(&mut messages);
let mut network_image_urls = vec![];
let messages_len = messages.len();
let messages: Vec<Value> = messages
.into_iter()
.enumerate()
.flat_map(|(i, message)| {
let Message { role, content } = message;
match content {
MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
vec![json!({ "role": role, "content": [ { "text": strip_think_tag(&text) } ] })]
}
MessageContent::Text(text) => vec![json!({
"role": role,
"content": [
{
"text": text,
}
],
})],
MessageContent::Array(list) => {
let content: Vec<_> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => {
json!({"text": text})
}
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"image": {
"format": mime_type.replace("image/", ""),
"source": {
"bytes": data,
}
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
}
})
.collect();
vec![json!({
"role": role,
"content": content,
})]
}
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
let mut assistant_parts = vec![];
let mut user_parts = vec![];
if !text.is_empty() {
assistant_parts.push(json!({
"text": text,
}))
}
for tool_result in tool_results {
assistant_parts.push(json!({
"toolUse": {
"toolUseId": tool_result.call.id,
"name": tool_result.call.name,
"input": tool_result.call.arguments,
}
}));
user_parts.push(json!({
"toolResult": {
"toolUseId": tool_result.call.id,
"content": [
{
"json": tool_result.output,
}
]
}
}));
}
vec![
json!({
"role": "assistant",
"content": assistant_parts,
}),
json!({
"role": "user",
"content": user_parts,
}),
]
}
}
})
.collect();
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
network_image_urls
);
}
let mut body = json!({
"inferenceConfig": {},
"messages": messages,
});
if let Some(v) = system_message {
body["system"] = json!([
{
"text": v,
}
])
}
if let Some(v) = model.max_tokens_param() {
body["inferenceConfig"]["maxTokens"] = v.into();
}
if let Some(v) = temperature {
body["inferenceConfig"]["temperature"] = v.into();
}
if let Some(v) = top_p {
body["inferenceConfig"]["topP"] = v.into();
}
if let Some(functions) = functions {
let tools: Vec<_> = functions
.iter()
.map(|v| {
json!({
"toolSpec": {
"name": v.name,
"description": v.description,
"inputSchema": {
"json": v.parameters,
},
}
})
})
.collect();
body["toolConfig"] = json!({
"tools": tools,
})
}
Ok(body)
}
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let mut text = String::new();
let mut reasoning = None;
let mut tool_calls = vec![];
if let Some(array) = data["output"]["message"]["content"].as_array() {
for item in array {
if let Some(v) = item["text"].as_str() {
if !text.is_empty() {
text.push_str("\n\n");
}
text.push_str(v);
} else if let Some(reasoning_text) =
item["reasoningContent"]["reasoningText"].as_object()
{
if let Some(text) = json_str_from_map(reasoning_text, "text") {
reasoning = Some(text.to_string());
}
} else if let Some(tool_use) = item["toolUse"].as_object() {
if let (Some(id), Some(name), Some(input)) = (
json_str_from_map(tool_use, "toolUseId"),
json_str_from_map(tool_use, "name"),
tool_use.get("input"),
) {
tool_calls.push(ToolCall::new(
name.to_string(),
input.clone(),
Some(id.to_string()),
))
}
}
}
}
if let Some(reasoning) = reasoning {
text = format!("<think>\n{reasoning}\n</think>\n\n{text}")
}
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = ChatCompletionsOutput {
text,
tool_calls,
id: None,
input_tokens: data["usage"]["inputTokens"].as_u64(),
output_tokens: data["usage"]["outputTokens"].as_u64(),
};
Ok(output)
}
#[derive(Debug)]
struct AwsCredentials {
access_key_id: String,
secret_access_key: String,
region: String,
session_token: Option<String>,
}
#[derive(Debug)]
struct AwsRequest {
method: Method,
host: String,
service: String,
uri: String,
querystring: String,
headers: IndexMap<String, String>,
body: String,
}
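// Signs and builds the request using AWS Signature Version 4: hash the payload, assemble the
// canonical request and string-to-sign, derive the signing key for the date/region/service,
// then attach the resulting `authorization` header.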
fn aws_fetch(
client: &ReqwestClient,
credentials: &AwsCredentials,
request: AwsRequest,
) -> Result<RequestBuilder> {
let AwsRequest {
method,
host,
service,
uri,
querystring,
mut headers,
body,
} = request;
let region = &credentials.region;
let endpoint = format!("https://{host}{uri}");
let now: DateTime<Utc> = Utc::now();
let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
let date_stamp = amz_date[0..8].to_string();
headers.insert("host".into(), host.clone());
headers.insert("x-amz-date".into(), amz_date.clone());
if let Some(token) = credentials.session_token.clone() {
headers.insert("x-amz-security-token".into(), token);
}
let canonical_headers = headers
.iter()
.map(|(key, value)| format!("{key}:{value}\n"))
.collect::<Vec<_>>()
.join("");
let signed_headers = headers
.iter()
.map(|(key, _)| key.as_str())
.collect::<Vec<_>>()
.join(";");
let payload_hash = sha256(&body);
let canonical_request = format!(
"{}\n{}\n{}\n{}\n{}\n{}",
method,
encode_uri(&uri),
querystring,
canonical_headers,
signed_headers,
payload_hash
);
let algorithm = "AWS4-HMAC-SHA256";
let credential_scope = format!("{date_stamp}/{region}/{service}/aws4_request");
let string_to_sign = format!(
"{}\n{}\n{}\n{}",
algorithm,
amz_date,
credential_scope,
sha256(&canonical_request)
);
let signing_key = gen_signing_key(
&credentials.secret_access_key,
&date_stamp,
region,
&service,
);
let signature = hmac_sha256(&signing_key, &string_to_sign);
let signature = hex_encode(&signature);
let authorization_header = format!(
"{} Credential={}/{}, SignedHeaders={}, Signature={}",
algorithm, credentials.access_key_id, credential_scope, signed_headers, signature
);
headers.insert("authorization".into(), authorization_header);
debug!("Request {endpoint} {body}");
let mut request_builder = client.request(method, endpoint).body(body);
for (key, value) in &headers {
request_builder = request_builder.header(key, value);
}
Ok(request_builder)
}
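// SigV4 key derivation chain: HMAC("AWS4" + secret, date) -> region -> service -> "aws4_request".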
fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
let k_date = hmac_sha256(format!("AWS4{key}").as_bytes(), date_stamp);
let k_region = hmac_sha256(&k_date, region);
let k_service = hmac_sha256(&k_region, service);
hmac_sha256(&k_service, "aws4_request")
}
@@ -0,0 +1,353 @@
use super::*;
use crate::utils::strip_think_tag;
use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
const API_BASE: &str = "https://api.anthropic.com/v1";
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl ClaudeClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
impl_client_trait!(
ClaudeClient,
(
prepare_chat_completions,
claude_chat_completions,
claude_chat_completions_streaming
),
(noop_prepare_embeddings, noop_embeddings),
(noop_prepare_rerank, noop_rerank),
);
fn prepare_chat_completions(
self_: &ClaudeClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!("{}/messages", api_base.trim_end_matches('/'));
let body = claude_build_chat_completions_body(data, &self_.model)?;
let mut request_data = RequestData::new(url, body);
request_data.header("anthropic-version", "2023-06-01");
request_data.header("x-api-key", api_key);
Ok(request_data)
}
pub async fn claude_chat_completions(
builder: RequestBuilder,
_model: &Model,
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
claude_extract_chat_completions(&data)
}
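// Handles Anthropic SSE events: text and thinking deltas are forwarded (thinking wrapped in
// <think> tags), while tool_use blocks accumulate partial JSON until content_block_stop.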
pub async fn claude_chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
_model: &Model,
) -> Result<()> {
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let mut reasoning_state = 0;
let handle = |message: SseMessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(typ) = data["type"].as_str() {
match typ {
"content_block_start" => {
if let (Some("tool_use"), Some(name), Some(id)) = (
data["content_block"]["type"].as_str(),
data["content_block"]["name"].as_str(),
data["content_block"]["id"].as_str(),
) {
if !function_name.is_empty() {
let arguments: Value =
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_name = name.into();
function_arguments.clear();
function_id = id.into();
}
}
"content_block_delta" => {
if let Some(text) = data["delta"]["text"].as_str() {
handler.text(text)?;
} else if let Some(text) = data["delta"]["thinking"].as_str() {
if reasoning_state == 0 {
handler.text("<think>\n")?;
reasoning_state = 1;
}
handler.text(text)?;
} else if let (true, Some(partial_json)) = (
!function_name.is_empty(),
data["delta"]["partial_json"].as_str(),
) {
function_arguments.push_str(partial_json);
}
}
"content_block_stop" => {
if reasoning_state == 1 {
handler.text("\n</think>\n\n")?;
reasoning_state = 0;
}
if !function_name.is_empty() {
let arguments: Value = if function_arguments.is_empty() {
json!({})
} else {
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?
};
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
}
_ => {}
}
}
Ok(false)
};
sse_stream(builder, handle).await
}
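// Builds the Anthropic Messages API body: inline data URLs become base64 image sources,
// network image URLs are rejected, and tool results are split into assistant (tool_use)
// and user (tool_result) messages.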
pub fn claude_build_chat_completions_body(
data: ChatCompletionsData,
model: &Model,
) -> Result<Value> {
let ChatCompletionsData {
mut messages,
temperature,
top_p,
functions,
stream,
} = data;
let system_message = extract_system_message(&mut messages);
let mut network_image_urls = vec![];
let messages_len = messages.len();
let messages: Vec<Value> = messages
.into_iter()
.enumerate()
.flat_map(|(i, message)| {
let Message { role, content } = message;
match content {
MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
vec![json!({ "role": role, "content": strip_think_tag(&text) })]
}
MessageContent::Text(text) => vec![json!({
"role": role,
"content": text,
})],
MessageContent::Array(list) => {
let content: Vec<_> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => {
json!({"type": "text", "text": text})
}
MessageContentPart::ImageUrl {
image_url: ImageUrl { url },
} => {
if let Some((mime_type, data)) = url
.strip_prefix("data:")
.and_then(|v| v.split_once(";base64,"))
{
json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": data,
}
})
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
}
})
.collect();
vec![json!({
"role": role,
"content": content,
})]
}
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
let mut assistant_parts = vec![];
let mut user_parts = vec![];
if !text.is_empty() {
assistant_parts.push(json!({
"type": "text",
"text": text,
}))
}
for tool_result in tool_results {
assistant_parts.push(json!({
"type": "tool_use",
"id": tool_result.call.id,
"name": tool_result.call.name,
"input": tool_result.call.arguments,
}));
user_parts.push(json!({
"type": "tool_result",
"tool_use_id": tool_result.call.id,
"content": tool_result.output.to_string(),
}));
}
vec![
json!({
"role": "assistant",
"content": assistant_parts,
}),
json!({
"role": "user",
"content": user_parts,
}),
]
}
}
})
.collect();
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
network_image_urls
);
}
let mut body = json!({
"model": model.real_name(),
"messages": messages,
});
if let Some(v) = system_message {
body["system"] = v.into();
}
if let Some(v) = model.max_tokens_param() {
body["max_tokens"] = v.into();
}
if let Some(v) = temperature {
body["temperature"] = v.into();
}
if let Some(v) = top_p {
body["top_p"] = v.into();
}
if stream {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
json!({
"name": v.name,
"description": v.description,
"input_schema": v.parameters,
})
})
.collect();
}
Ok(body)
}
pub fn claude_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let mut text = String::new();
let mut reasoning = None;
let mut tool_calls = vec![];
if let Some(list) = data["content"].as_array() {
for item in list {
match item["type"].as_str() {
Some("thinking") => {
if let Some(v) = item["thinking"].as_str() {
reasoning = Some(v.to_string());
}
}
Some("text") => {
if let Some(v) = item["text"].as_str() {
if !text.is_empty() {
text.push_str("\n\n");
}
text.push_str(v);
}
}
Some("tool_use") => {
if let (Some(name), Some(input), Some(id)) = (
item["name"].as_str(),
item.get("input"),
item["id"].as_str(),
) {
tool_calls.push(ToolCall::new(
name.to_string(),
input.clone(),
Some(id.to_string()),
));
}
}
_ => {}
}
}
}
if let Some(reasoning) = reasoning {
text = format!("<think>\n{reasoning}\n</think>\n\n{text}")
}
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["output_tokens"].as_u64(),
};
Ok(output)
}
@@ -0,0 +1,255 @@
use super::openai::*;
use super::openai_compatible::*;
use super::*;
use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
const API_BASE: &str = "https://api.cohere.ai/v2";
#[derive(Debug, Clone, Deserialize, Default)]
pub struct CohereConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl CohereClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
impl_client_trait!(
CohereClient,
(
prepare_chat_completions,
chat_completions,
chat_completions_streaming
),
(prepare_embeddings, embeddings),
(prepare_rerank, generic_rerank),
);
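// Cohere's v2 chat API accepts an OpenAI-style body, so the shared builder is reused;
// the only adjustment is renaming `top_p` to Cohere's `p` parameter.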
fn prepare_chat_completions(
self_: &CohereClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!("{}/chat", api_base.trim_end_matches('/'));
let mut body = openai_build_chat_completions_body(data, &self_.model);
if let Some(obj) = body.as_object_mut() {
if let Some(top_p) = obj.remove("top_p") {
obj.insert("p".to_string(), top_p);
}
}
let mut request_data = RequestData::new(url, body);
request_data.bearer_auth(api_key);
Ok(request_data)
}
fn prepare_embeddings(self_: &CohereClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!("{}/embed", api_base.trim_end_matches('/'));
let input_type = match data.query {
true => "search_query",
false => "search_document",
};
let body = json!({
"model": self_.model.real_name(),
"texts": data.texts,
"input_type": input_type,
"embedding_types": ["float"],
});
let mut request_data = RequestData::new(url, body);
request_data.bearer_auth(api_key);
Ok(request_data)
}
fn prepare_rerank(self_: &CohereClient, data: &RerankData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!("{}/rerank", api_base.trim_end_matches('/'));
let body = generic_build_rerank_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
request_data.bearer_auth(api_key);
Ok(request_data)
}
async fn chat_completions(
builder: RequestBuilder,
_model: &Model,
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
extract_chat_completions(&data)
}
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
_model: &Model,
) -> Result<()> {
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let handle = |message: SseMessage| -> Result<bool> {
if message.data == "[DONE]" {
return Ok(true);
}
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(typ) = data["type"].as_str() {
match typ {
"content-delta" => {
if let Some(text) = data["delta"]["message"]["content"]["text"].as_str() {
handler.text(text)?;
}
}
"tool-plan-delta" => {
if let Some(text) = data["delta"]["message"]["tool_plan"].as_str() {
handler.text(text)?;
}
}
"tool-call-start" => {
if let (Some(function), Some(id)) = (
data["delta"]["message"]["tool_calls"]["function"].as_object(),
data["delta"]["message"]["tool_calls"]["id"].as_str(),
) {
if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
function_name = name.to_string();
}
function_id = id.to_string();
}
}
"tool-call-delta" => {
if let Some(text) =
data["delta"]["message"]["tool_calls"]["function"]["arguments"].as_str()
{
function_arguments.push_str(text);
}
}
"tool-call-end" => {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_name.clear();
function_arguments.clear();
function_id.clear();
}
_ => {}
}
}
Ok(false)
};
sse_stream(builder, handle).await
}
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
Ok(res_body.embeddings.float)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
embeddings: EmbeddingsResBodyEmbeddings,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbeddings {
float: Vec<Vec<f32>>,
}
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let mut text = data["message"]["content"][0]["text"]
.as_str()
.unwrap_or_default()
.to_string();
let mut tool_calls = vec![];
if let Some(calls) = data["message"]["tool_calls"].as_array() {
if text.is_empty() {
if let Some(tool_plan) = data["message"]["tool_plan"].as_str() {
text = tool_plan.to_string();
}
}
for call in calls {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
})?;
tool_calls.push(ToolCall::new(
name.to_string(),
arguments,
Some(id.to_string()),
));
}
}
}
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = ChatCompletionsOutput {
text,
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(),
output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(),
};
Ok(output)
}
@@ -0,0 +1,678 @@
use super::*;
use crate::{
config::{Config, GlobalConfig, Input},
function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult},
render::render_stream,
utils::*,
};
use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use indexmap::IndexMap;
use inquire::{
list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text,
};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::mpsc::unbounded_channel;
const MODELS_YAML: &str = include_str!("../../models.yaml");
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
Config::local_models_override()
.ok()
.unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
});
static EMBEDDING_MODEL_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"((^|/)(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap()
});
static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/").unwrap());
#[async_trait::async_trait]
pub trait Client: Sync + Send {
fn global_config(&self) -> &GlobalConfig;
fn extra_config(&self) -> Option<&ExtraConfig>;
fn patch_config(&self) -> Option<&RequestPatch>;
fn name(&self) -> &str;
fn model(&self) -> &Model;
fn model_mut(&mut self) -> &mut Model;
fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder();
let extra = self.extra_config();
let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
builder = set_proxy(builder, proxy)?;
}
if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
builder = builder.user_agent(user_agent);
}
let client = builder
.connect_timeout(Duration::from_secs(timeout))
.build()
.with_context(|| "Failed to build client")?;
Ok(client)
}
async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
if self.global_config().read().dry_run {
let content = input.echo_messages();
return Ok(ChatCompletionsOutput::new(&content));
}
let client = self.build_client()?;
let data = input.prepare_completion_data(self.model(), false)?;
self.chat_completions_inner(&client, data)
.await
.with_context(|| "Failed to call chat-completions api")
}
async fn chat_completions_streaming(
&self,
input: &Input,
handler: &mut SseHandler,
) -> Result<()> {
let abort_signal = handler.abort();
let input = input.clone();
tokio::select! {
ret = async {
if self.global_config().read().dry_run {
let content = input.echo_messages();
handler.text(&content)?;
return Ok(());
}
let client = self.build_client()?;
let data = input.prepare_completion_data(self.model(), true)?;
self.chat_completions_streaming_inner(&client, handler, data).await
} => {
handler.done();
ret.with_context(|| "Failed to call chat-completions api")
}
_ = wait_abort_signal(&abort_signal) => {
handler.done();
Ok(())
},
}
}
async fn embeddings(&self, data: &EmbeddingsData) -> Result<Vec<Vec<f32>>> {
let client = self.build_client()?;
self.embeddings_inner(&client, data)
.await
.context("Failed to call embeddings api")
}
async fn rerank(&self, data: &RerankData) -> Result<RerankOutput> {
let client = self.build_client()?;
self.rerank_inner(&client, data)
.await
.context("Failed to call rerank api")
}
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput>;
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()>;
async fn embeddings_inner(
&self,
_client: &ReqwestClient,
_data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
bail!("The client doesn't support embeddings api")
}
async fn rerank_inner(
&self,
_client: &ReqwestClient,
_data: &RerankData,
) -> Result<RerankOutput> {
bail!("The client doesn't support rerank api")
}
fn request_builder(
&self,
client: &reqwest::Client,
mut request_data: RequestData,
) -> RequestBuilder {
self.patch_request_data(&mut request_data);
request_data.into_builder(client)
}
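// Applies request patches: a per-model patch is applied first, then a patch map resolved from
// an environment variable (derived from the client and API names) or from the client's patch
// config; each map key is treated as a regex matched against the full model name.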
fn patch_request_data(&self, request_data: &mut RequestData) {
let model_type = self.model().model_type();
if let Some(patch) = self.model().patch() {
request_data.apply_patch(patch.clone());
}
let patch_map = std::env::var(get_env_name(&format!(
"patch_{}_{}",
self.model().client_name(),
model_type.api_name(),
)))
.ok()
.and_then(|v| serde_json::from_str(&v).ok())
.or_else(|| {
self.patch_config()
.and_then(|v| model_type.extract_patch(v))
.cloned()
});
let patch_map = match patch_map {
Some(v) => v,
_ => return,
};
for (key, patch) in patch_map {
let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
if let Ok(regex) = Regex::new(&format!("^({key})$")) {
if let Ok(true) = regex.is_match(self.model().name()) {
request_data.apply_patch(patch);
return;
}
}
}
}
}
impl Default for ClientConfig {
fn default() -> Self {
Self::OpenAIConfig(OpenAIConfig::default())
}
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
pub proxy: Option<String>,
pub connect_timeout: Option<u64>,
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RequestPatch {
pub chat_completions: Option<ApiPatch>,
pub embeddings: Option<ApiPatch>,
pub rerank: Option<ApiPatch>,
}
pub type ApiPatch = IndexMap<String, Value>;
pub struct RequestData {
pub url: String,
pub headers: IndexMap<String, String>,
pub body: Value,
}
impl RequestData {
pub fn new<T>(url: T, body: Value) -> Self
where
T: std::fmt::Display,
{
Self {
url: url.to_string(),
headers: Default::default(),
body,
}
}
pub fn bearer_auth<T>(&mut self, auth: T)
where
T: std::fmt::Display,
{
self.headers
.insert("authorization".into(), format!("Bearer {auth}"));
}
pub fn header<K, V>(&mut self, key: K, value: V)
where
K: std::fmt::Display,
V: std::fmt::Display,
{
self.headers.insert(key.to_string(), value.to_string());
}
pub fn into_builder(self, client: &ReqwestClient) -> RequestBuilder {
let RequestData { url, headers, body } = self;
debug!("Request {url} {body}");
let mut builder = client.post(url);
for (key, value) in headers {
builder = builder.header(key, value);
}
builder = builder.json(&body);
builder
}
pub fn apply_patch(&mut self, patch: Value) {
if let Some(patch_url) = patch["url"].as_str() {
self.url = patch_url.into();
}
if let Some(patch_body) = patch.get("body") {
json_patch::merge(&mut self.body, patch_body)
}
if let Some(patch_headers) = patch["headers"].as_object() {
for (key, value) in patch_headers {
if let Some(value) = value.as_str() {
self.header(key, value)
} else if value.is_null() {
self.headers.swap_remove(key);
}
}
}
}
}
#[derive(Debug)]
pub struct ChatCompletionsData {
pub messages: Vec<Message>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub functions: Option<Vec<FunctionDeclaration>>,
pub stream: bool,
}
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionsOutput {
pub text: String,
pub tool_calls: Vec<ToolCall>,
pub id: Option<String>,
pub input_tokens: Option<u64>,
pub output_tokens: Option<u64>,
}
impl ChatCompletionsOutput {
pub fn new(text: &str) -> Self {
Self {
text: text.to_string(),
..Default::default()
}
}
}
#[derive(Debug)]
pub struct EmbeddingsData {
pub texts: Vec<String>,
pub query: bool,
}
impl EmbeddingsData {
pub fn new(texts: Vec<String>, query: bool) -> Self {
Self { texts, query }
}
}
pub type EmbeddingsOutput = Vec<Vec<f32>>;
#[derive(Debug)]
pub struct RerankData {
pub query: String,
pub documents: Vec<String>,
pub top_n: usize,
}
impl RerankData {
pub fn new(query: String, documents: Vec<String>, top_n: usize) -> Self {
Self {
query,
documents,
top_n,
}
}
}
pub type RerankOutput = Vec<RerankResult>;
#[derive(Debug, Deserialize)]
pub struct RerankResult {
pub index: usize,
pub relevance_score: f64,
}
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>);
pub async fn create_config(
prompts: &[PromptAction<'static>],
client: &str,
) -> Result<(String, Value)> {
let mut config = json!({
"type": client,
});
for (key, desc, help_message) in prompts {
let env_name = format!("{client}_{key}").to_ascii_uppercase();
let required = std::env::var(&env_name).is_err();
let value = prompt_input_string(desc, required, *help_message)?;
if !value.is_empty() {
config[key] = value.into();
}
}
let model = set_client_models_config(&mut config, client).await?;
let clients = json!(vec![config]);
Ok((model, clients))
}
pub async fn create_openai_compatible_client_config(
client: &str,
) -> Result<Option<(String, Value)>> {
let api_base = OPENAI_COMPATIBLE_PROVIDERS
.into_iter()
.find(|(name, _)| client == *name)
.map(|(_, api_base)| api_base)
.unwrap_or("http(s)://{API_ADDR}/v1");
let name = if client == OpenAICompatibleClient::NAME {
let value = prompt_input_string("Provider Name", true, None)?;
value.replace(' ', "-")
} else {
client.to_string()
};
let mut config = json!({
"type": OpenAICompatibleClient::NAME,
"name": &name,
});
let api_base = if api_base.contains('{') {
prompt_input_string("API Base", true, Some(&format!("e.g. {api_base}")))?
} else {
api_base.to_string()
};
config["api_base"] = api_base.into();
let api_key = prompt_input_string("API Key", false, None)?;
if !api_key.is_empty() {
config["api_key"] = api_key.into();
}
let model = set_client_models_config(&mut config, &name).await?;
let clients = json!(vec![config]);
Ok(Some((model, clients)))
}
pub async fn call_chat_completions(
input: &Input,
print: bool,
extract_code: bool,
client: &dyn Client,
abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
let ret = abortable_run_with_spinner(
client.chat_completions(input.clone()),
"Generating",
abort_signal,
)
.await;
match ret {
Ok(ret) => {
let ChatCompletionsOutput {
mut text,
tool_calls,
..
} = ret;
if !text.is_empty() {
if extract_code {
text = extract_code_block(&strip_think_tag(&text)).to_string();
}
if print {
client.global_config().read().print_markdown(&text)?;
}
}
Ok((
text,
eval_tool_calls(client.global_config(), tool_calls).await?,
))
}
Err(err) => Err(err),
}
}
pub async fn call_chat_completions_streaming(
input: &Input,
client: &dyn Client,
abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
let (tx, rx) = unbounded_channel();
let mut handler = SseHandler::new(tx, abort_signal.clone());
let (send_ret, render_ret) = tokio::join!(
client.chat_completions_streaming(input, &mut handler),
render_stream(rx, client.global_config(), abort_signal.clone()),
);
if handler.abort().aborted() {
bail!("Aborted.");
}
render_ret?;
let (text, tool_calls) = handler.take();
match send_ret {
Ok(_) => {
if !text.is_empty() && !text.ends_with('\n') {
println!();
}
Ok((
text,
eval_tool_calls(client.global_config(), tool_calls).await?,
))
}
Err(err) => {
if !text.is_empty() {
println!();
}
Err(err)
}
}
}
pub fn noop_prepare_embeddings<T>(_client: &T, _data: &EmbeddingsData) -> Result<RequestData> {
bail!("The client doesn't support embeddings api")
}
pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
bail!("The client doesn't support embeddings api")
}
pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
bail!("The client doesn't support rerank api")
}
pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
bail!("The client doesn't support rerank api")
}
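// Normalizes error responses across providers by probing the common payload shapes
// (`error`, `errors[0]`, `[0].error`, `detail`/`status`, plain `message`) before giving up.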
pub fn catch_error(data: &Value, status: u16) -> Result<()> {
if (200..300).contains(&status) {
return Ok(());
}
debug!("Invalid response, status: {status}, data: {data}");
if let Some(error) = data["error"].as_object() {
if let (Some(typ), Some(message)) = (
json_str_from_map(error, "type"),
json_str_from_map(error, "message"),
) {
bail!("{message} (type: {typ})");
} else if let (Some(code), Some(message)) = (
json_str_from_map(error, "code"),
json_str_from_map(error, "message"),
) {
bail!("{message} (code: {code})");
}
} else if let Some(error) = data["errors"][0].as_object() {
if let (Some(code), Some(message)) = (
error.get("code").and_then(|v| v.as_u64()),
json_str_from_map(error, "message"),
) {
bail!("{message} (status: {code})")
}
} else if let Some(error) = data[0]["error"].as_object() {
if let (Some(status), Some(message)) = (
json_str_from_map(error, "status"),
json_str_from_map(error, "message"),
) {
bail!("{message} (status: {status})")
}
} else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
{
bail!("{detail} (status: {status})");
} else if let Some(error) = data["error"].as_str() {
bail!("{error}");
} else if let Some(message) = data["message"].as_str() {
bail!("{message}");
}
bail!("Invalid response data: {data} (status: {status})");
}
pub fn json_str_from_map<'a>(
map: &'a serde_json::Map<String, Value>,
field_name: &str,
) -> Option<&'a str> {
map.get(field_name).and_then(|v| v.as_str())
}
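// Resolves the client's model list: known providers offer their built-in chat models, while
// OpenAI-compatible endpoints are queried for models (falling back to manual entry), with
// simple name heuristics used to tag rerankers, embedding models, and vision models.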
async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
let models: Vec<String> = provider
.models
.iter()
.filter(|v| v.model_type == "chat")
.map(|v| v.name.clone())
.collect();
let model_name = select_model(models)?;
return Ok(format!("{client}:{model_name}"));
}
let mut model_names = vec![];
if let (Some(true), Some(api_base), api_key) = (
client_config["type"]
.as_str()
.map(|v| v == OpenAICompatibleClient::NAME),
client_config["api_base"].as_str(),
client_config["api_key"]
.as_str()
.map(|v| v.to_string())
.or_else(|| {
let env_name = format!("{client}_api_key").to_ascii_uppercase();
std::env::var(&env_name).ok()
}),
) {
match abortable_run_with_spinner(
fetch_models(api_base, api_key.as_deref()),
"Fetching models",
create_abort_signal(),
)
.await
{
Ok(fetched_models) => {
model_names = MultiSelect::new("LLMs to include (required):", fetched_models)
.with_validator(|list: &[ListOption<&String>]| {
if list.is_empty() {
Ok(Validation::Invalid(
"At least one item must be selected".into(),
))
} else {
Ok(Validation::Valid)
}
})
.prompt()?;
}
Err(err) => {
eprintln!("✗ Fetch models failed: {err}");
}
}
}
if model_names.is_empty() {
model_names = prompt_input_string(
"LLMs to add",
true,
Some("Separated by commas, e.g. llama3.3,qwen2.5"),
)?
.split(',')
.filter_map(|v| {
let v = v.trim();
if v.is_empty() {
None
} else {
Some(v.to_string())
}
})
.collect::<Vec<_>>();
}
if model_names.is_empty() {
bail!("No models");
}
let models: Vec<Value> = model_names
.iter()
.map(|v| {
let l = v.to_lowercase();
if l.contains("rank") {
json!({
"name": v,
"type": "reranker",
})
} else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) {
json!({
"name": v,
"type": "embedding",
"default_chunk_size": 1000,
"max_batch_size": 100
})
} else if v.contains("vision") {
json!({
"name": v,
"supports_vision": true
})
} else {
json!({
"name": v,
})
}
})
.collect();
client_config["models"] = models.into();
let model_name = select_model(model_names)?;
Ok(format!("{client}:{model_name}"))
}
fn select_model(model_names: Vec<String>) -> Result<String> {
if model_names.is_empty() {
bail!("No models");
}
let model = if model_names.len() == 1 {
model_names[0].clone()
} else {
Select::new("Default Model (required):", model_names).prompt()?
};
Ok(model)
}
fn prompt_input_string(desc: &str, required: bool, help_message: Option<&str>) -> Result<String> {
let desc = if required {
format!("{desc} (required):")
} else {
format!("{desc} (optional):")
};
let mut text = Text::new(&desc);
if required {
text = text.with_validator(required!("This field is required"))
}
if let Some(help_message) = help_message {
text = text.with_help_message(help_message);
}
let text = text.prompt()?;
Ok(text)
}
@@ -0,0 +1,136 @@
use super::vertexai::*;
use super::*;
use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl GeminiClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
impl_client_trait!(
GeminiClient,
(
prepare_chat_completions,
gemini_chat_completions,
gemini_chat_completions_streaming
),
(prepare_embeddings, embeddings),
(noop_prepare_rerank, noop_rerank),
);
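// Gemini chat requests target `generateContent` (or `streamGenerateContent` when streaming)
// and authenticate with the `x-goog-api-key` header.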
fn prepare_chat_completions(
self_: &GeminiClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let func = match data.stream {
true => "streamGenerateContent",
false => "generateContent",
};
let url = format!(
"{}/models/{}:{}",
api_base.trim_end_matches('/'),
self_.model.real_name(),
func
);
let body = gemini_build_chat_completions_body(data, &self_.model)?;
let mut request_data = RequestData::new(url, body);
request_data.header("x-goog-api-key", api_key);
Ok(request_data)
}
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!(
"{}/models/{}:batchEmbedContents?key={}",
api_base.trim_end_matches('/'),
self_.model.real_name(),
api_key
);
let model_id = format!("models/{}", self_.model.real_name());
let requests: Vec<_> = data
.texts
.iter()
.map(|text| {
json!({
"model": model_id,
"content": {
"parts": [
{
"text": text
}
]
},
})
})
.collect();
let body = json!({
"requests": requests,
});
let request_data = RequestData::new(url, body);
Ok(request_data)
}
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
let output = res_body
.embeddings
.into_iter()
.map(|embedding| embedding.values)
.collect();
Ok(output)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
embeddings: Vec<EmbeddingsResBodyEmbedding>,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
values: Vec<f32>,
}
@@ -0,0 +1,245 @@
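// register_client! wires every provider module into the ClientConfig enum and generates the
// per-client structs plus the init/list helpers; impl_client_trait! and config_get_fn! below
// fill in the shared Client trait methods and env-var-aware config getters.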
#[macro_export]
macro_rules! register_client {
(
$(($module:ident, $name:literal, $config:ident, $client:ident),)+
) => {
$(
mod $module;
)+
$(
use self::$module::$config;
)+
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(tag = "type")]
pub enum ClientConfig {
$(
#[serde(rename = $name)]
$config($config),
)+
#[serde(other)]
Unknown,
}
$(
#[derive(Debug)]
pub struct $client {
global_config: $crate::config::GlobalConfig,
config: $config,
model: $crate::client::Model,
}
impl $client {
pub const NAME: &'static str = $name;
pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
let config = global_config.read().clients.iter().find_map(|client_config| {
if let ClientConfig::$config(c) = client_config {
if Self::name(c) == model.client_name() {
return Some(c.clone())
}
}
None
})?;
Some(Box::new(Self {
global_config: global_config.clone(),
config,
model: model.clone(),
}))
}
pub fn list_models(local_config: &$config) -> Vec<Model> {
let client_name = Self::name(local_config);
if local_config.models.is_empty() {
if let Some(v) = $crate::client::ALL_PROVIDER_MODELS.iter().find(|v| {
v.provider == $name ||
($name == OpenAICompatibleClient::NAME
&& local_config.name.as_ref().map(|name| name.starts_with(&v.provider)).unwrap_or_default())
}) {
return Model::from_config(client_name, &v.models);
}
vec![]
} else {
Model::from_config(client_name, &local_config.models)
}
}
pub fn name(local_config: &$config) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
}
)+
pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
let model = model.unwrap_or_else(|| config.read().model.clone());
None
$(.or_else(|| $client::init(config, &model)))+
.ok_or_else(|| {
anyhow::anyhow!("Invalid model '{}'", model.id())
})
}
pub fn list_client_types() -> Vec<&'static str> {
let mut client_types: Vec<_> = vec![$($client::NAME,)+];
client_types.extend($crate::client::OPENAI_COMPATIBLE_PROVIDERS.iter().map(|(name, _)| *name));
client_types
}
pub async fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
$(
if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
return create_config(&$client::PROMPTS, $client::NAME).await
}
)+
if let Some(ret) = create_openai_compatible_client_config(client).await? {
return Ok(ret);
}
anyhow::bail!("Unknown client '{}'", client)
}
static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();
pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
let names = ALL_CLIENT_NAMES.get_or_init(|| {
config
.clients
.iter()
.flat_map(|v| match v {
$(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+
ClientConfig::Unknown => vec![],
})
.collect()
});
names.iter().collect()
}
static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();
pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
let models = ALL_MODELS.get_or_init(|| {
config
.clients
.iter()
.flat_map(|v| match v {
$(ClientConfig::$config(c) => $client::list_models(c),)+
ClientConfig::Unknown => vec![],
})
.collect()
});
models.iter().collect()
}
pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
}
};
}
#[macro_export]
macro_rules! client_common_fns {
() => {
fn global_config(&self) -> &$crate::config::GlobalConfig {
&self.global_config
}
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
self.config.extra.as_ref()
}
fn patch_config(&self) -> Option<&$crate::client::RequestPatch> {
self.config.patch.as_ref()
}
fn name(&self) -> &str {
Self::name(&self.config)
}
fn model(&self) -> &Model {
&self.model
}
fn model_mut(&mut self) -> &mut Model {
&mut self.model
}
};
}
#[macro_export]
macro_rules! impl_client_trait {
(
$client:ident,
($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path),
($prepare_embeddings:path, $embeddings:path),
($prepare_rerank:path, $rerank:path),
) => {
#[async_trait::async_trait]
impl $crate::client::Client for $crate::client::$client {
client_common_fns!();
async fn chat_completions_inner(
&self,
client: &reqwest::Client,
data: $crate::client::ChatCompletionsData,
) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
let request_data = $prepare_chat_completions(self, data)?;
let builder = self.request_builder(client, request_data);
$chat_completions(builder, self.model()).await
}
async fn chat_completions_streaming_inner(
&self,
client: &reqwest::Client,
handler: &mut $crate::client::SseHandler,
data: $crate::client::ChatCompletionsData,
) -> Result<()> {
let request_data = $prepare_chat_completions(self, data)?;
let builder = self.request_builder(client, request_data);
$chat_completions_streaming(builder, handler, self.model()).await
}
async fn embeddings_inner(
&self,
client: &reqwest::Client,
data: &$crate::client::EmbeddingsData,
) -> Result<$crate::client::EmbeddingsOutput> {
let request_data = $prepare_embeddings(self, data)?;
let builder = self.request_builder(client, request_data);
$embeddings(builder, self.model()).await
}
async fn rerank_inner(
&self,
client: &reqwest::Client,
data: &$crate::client::RerankData,
) -> Result<$crate::client::RerankOutput> {
let request_data = $prepare_rerank(self, data)?;
let builder = self.request_builder(client, request_data);
$rerank(builder, self.model()).await
}
}
};
}
#[macro_export]
macro_rules! config_get_fn {
($field_name:ident, $fn_name:ident) => {
fn $fn_name(&self) -> anyhow::Result<String> {
let env_prefix = Self::name(&self.config);
let env_name =
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
std::env::var(&env_name)
.ok()
.or_else(|| self.config.$field_name.clone())
.ok_or_else(|| anyhow::anyhow!("Miss '{}'", stringify!($field_name)))
}
};
}
#[macro_export]
macro_rules! unsupported_model {
($name:expr) => {
anyhow::bail!("Unsupported model '{}'", $name)
};
}
@@ -0,0 +1,235 @@
use super::Model;
use crate::{function::ToolResult, multiline_text, utils::dimmed_text};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
pub role: MessageRole,
pub content: MessageContent,
}
impl Default for Message {
fn default() -> Self {
Self {
role: MessageRole::User,
content: MessageContent::Text(String::new()),
}
}
}
impl Message {
pub fn new(role: MessageRole, content: MessageContent) -> Self {
Self { role, content }
}
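// Merges system content into this message by placing it ahead of the existing content,
// converting plain text into a multi-part array when needed.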
pub fn merge_system(&mut self, system: MessageContent) {
match (&mut self.content, system) {
(MessageContent::Text(text), MessageContent::Text(system_text)) => {
self.content = MessageContent::Array(vec![
MessageContentPart::Text { text: system_text },
MessageContentPart::Text {
text: text.to_string(),
},
])
}
(MessageContent::Array(list), MessageContent::Text(system_text)) => {
list.insert(0, MessageContentPart::Text { text: system_text })
}
(MessageContent::Text(text), MessageContent::Array(mut system_list)) => {
system_list.push(MessageContentPart::Text {
text: text.to_string(),
});
self.content = MessageContent::Array(system_list);
}
(MessageContent::Array(list), MessageContent::Array(mut system_list)) => {
system_list.append(list);
self.content = MessageContent::Array(system_list);
}
_ => {}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
System,
Assistant,
User,
Tool,
}
#[allow(dead_code)]
impl MessageRole {
pub fn is_system(&self) -> bool {
matches!(self, MessageRole::System)
}
pub fn is_user(&self) -> bool {
matches!(self, MessageRole::User)
}
pub fn is_assistant(&self) -> bool {
matches!(self, MessageRole::Assistant)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
Text(String),
Array(Vec<MessageContentPart>),
// Note: This type is primarily for convenience and does not exist in OpenAI's API.
ToolCalls(MessageContentToolCalls),
}
impl MessageContent {
pub fn render_input(
&self,
resolve_url_fn: impl Fn(&str) -> String,
agent_info: &Option<(String, Vec<String>)>,
) -> String {
match self {
MessageContent::Text(text) => multiline_text(text),
MessageContent::Array(list) => {
let (mut concatenated_text, mut files) = (String::new(), vec![]);
for item in list {
match item {
MessageContentPart::Text { text } => {
concatenated_text = format!("{concatenated_text} {text}")
}
MessageContentPart::ImageUrl { image_url } => {
files.push(resolve_url_fn(&image_url.url))
}
}
}
if !concatenated_text.is_empty() {
concatenated_text = format!(" -- {}", multiline_text(&concatenated_text))
}
format!(".file {}{}", files.join(" "), concatenated_text)
}
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
let mut lines = vec![];
if !text.is_empty() {
lines.push(text.clone())
}
for tool_result in tool_results {
let mut parts = vec!["Call".to_string()];
if let Some((agent_name, functions)) = agent_info {
if functions.contains(&tool_result.call.name) {
parts.push(agent_name.clone())
}
}
parts.push(tool_result.call.name.clone());
parts.push(tool_result.call.arguments.to_string());
lines.push(dimmed_text(&parts.join(" ")));
}
lines.join("\n")
}
}
}
pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) {
match self {
MessageContent::Text(text) => *text = replace_fn(text),
MessageContent::Array(list) => {
if list.is_empty() {
list.push(MessageContentPart::Text {
text: replace_fn(""),
})
} else if let Some(MessageContentPart::Text { text }) = list.get_mut(0) {
*text = replace_fn(text)
}
}
MessageContent::ToolCalls(_) => {}
}
}
pub fn to_text(&self) -> String {
match self {
MessageContent::Text(text) => text.to_string(),
MessageContent::Array(list) => {
let mut parts = vec![];
for item in list {
if let MessageContentPart::Text { text } = item {
parts.push(text.clone())
}
}
parts.join("\n\n")
}
MessageContent::ToolCalls(_) => String::new(),
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessageContentPart {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
pub url: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContentToolCalls {
pub tool_results: Vec<ToolResult>,
pub text: String,
pub sequence: bool,
}
impl MessageContentToolCalls {
pub fn new(tool_results: Vec<ToolResult>, text: String) -> Self {
Self {
tool_results,
text,
sequence: false,
}
}
pub fn merge(&mut self, tool_results: Vec<ToolResult>, _text: String) {
self.tool_results.extend(tool_results);
self.text.clear();
self.sequence = true;
}
}
pub fn patch_messages(messages: &mut Vec<Message>, model: &Model) {
if messages.is_empty() {
return;
}
if let Some(prefix) = model.system_prompt_prefix() {
if messages[0].role.is_system() {
messages[0].merge_system(MessageContent::Text(prefix.to_string()));
} else {
messages.insert(
0,
Message {
role: MessageRole::System,
content: MessageContent::Text(prefix.to_string()),
},
);
}
}
if model.no_system_message() && messages[0].role.is_system() {
let system_message = messages.remove(0);
if let (Some(message), system) = (messages.get_mut(0), system_message.content) {
message.merge_system(system);
}
}
}
pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
if messages.is_empty() {
return None;
}
if messages[0].role.is_system() {
let system_message = messages.remove(0);
return Some(system_message.content.to_text());
}
None
}
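// A minimal test sketch (not in the original commit) showing how
// `merge_system` and `extract_system_message` behave; the message texts are
// made up for illustration.
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_extract_system_message() {
let mut messages = vec![
Message::new(MessageRole::System, MessageContent::Text("be brief".into())),
Message::new(MessageRole::User, MessageContent::Text("hi".into())),
];
assert_eq!(extract_system_message(&mut messages), Some("be brief".to_string()));
assert_eq!(messages.len(), 1);
}

#[test]
fn test_merge_system_into_text() {
let mut message = Message::new(MessageRole::User, MessageContent::Text("hi".into()));
message.merge_system(MessageContent::Text("be brief".into()));
match &message.content {
MessageContent::Array(parts) => assert_eq!(parts.len(), 2),
_ => panic!("expected the merged content to become an array"),
}
}
}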
+62
View File
@@ -0,0 +1,62 @@
mod access_token;
mod common;
mod message;
#[macro_use]
mod macros;
mod model;
mod stream;
pub use crate::function::ToolCall;
pub use common::*;
pub use message::*;
pub use model::*;
pub use stream::*;
register_client!(
(openai, "openai", OpenAIConfig, OpenAIClient),
(
openai_compatible,
"openai-compatible",
OpenAICompatibleConfig,
OpenAICompatibleClient
),
(gemini, "gemini", GeminiConfig, GeminiClient),
(claude, "claude", ClaudeConfig, ClaudeClient),
(cohere, "cohere", CohereConfig, CohereClient),
(
azure_openai,
"azure-openai",
AzureOpenAIConfig,
AzureOpenAIClient
),
(vertexai, "vertexai", VertexAIConfig, VertexAIClient),
(bedrock, "bedrock", BedrockConfig, BedrockClient),
);
pub const OPENAI_COMPATIBLE_PROVIDERS: [(&str, &str); 18] = [
("ai21", "https://api.ai21.com/studio/v1"),
(
"cloudflare",
"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/v1",
),
("deepinfra", "https://api.deepinfra.com/v1/openai"),
("deepseek", "https://api.deepseek.com"),
("ernie", "https://qianfan.baidubce.com/v2"),
("github", "https://models.inference.ai.azure.com"),
("groq", "https://api.groq.com/openai/v1"),
("hunyuan", "https://api.hunyuan.cloud.tencent.com/v1"),
("minimax", "https://api.minimax.chat/v1"),
("mistral", "https://api.mistral.ai/v1"),
("moonshot", "https://api.moonshot.cn/v1"),
("openrouter", "https://openrouter.ai/api/v1"),
("perplexity", "https://api.perplexity.ai"),
(
"qianwen",
"https://dashscope.aliyuncs.com/compatible-mode/v1",
),
("xai", "https://api.x.ai/v1"),
("zhipuai", "https://open.bigmodel.cn/api/paas/v4"),
// RAG-dedicated
("jina", "https://api.jina.ai/v1"),
("voyageai", "https://api.voyageai.com/v1"),
];
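// For illustration (the config shape is assumed, not defined in this file): an
// `openai-compatible` client configured with only `name: groq` and an API key
// resolves its base URL from the table above, since `get_api_base_ext` in
// openai_compatible.rs falls back to this list whenever no `api_base` is set.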
+407
View File
@@ -0,0 +1,407 @@
use super::{
list_all_models, list_client_names,
message::{Message, MessageContent, MessageContentPart},
ApiPatch, MessageContentToolCalls, RequestPatch,
};
use crate::config::Config;
use crate::utils::{estimate_token_length, strip_think_tag};
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Display;
const PER_MESSAGES_TOKENS: usize = 5;
const BASIS_TOKENS: usize = 2;
#[derive(Debug, Clone)]
pub struct Model {
client_name: String,
data: ModelData,
}
impl Default for Model {
fn default() -> Self {
Model::new("", "")
}
}
impl Model {
pub fn new(client_name: &str, name: &str) -> Self {
Self {
client_name: client_name.into(),
data: ModelData::new(name),
}
}
pub fn from_config(client_name: &str, models: &[ModelData]) -> Vec<Self> {
models
.iter()
.map(|v| Model {
client_name: client_name.to_string(),
data: v.clone(),
})
.collect()
}
pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result<Self> {
let models = list_all_models(config);
let (client_name, model_name) = match model_id.split_once(':') {
Some((client_name, model_name)) => {
if model_name.is_empty() {
(client_name, None)
} else {
(client_name, Some(model_name))
}
}
None => (model_id, None),
};
match model_name {
Some(model_name) => {
if let Some(model) = models.iter().find(|v| v.id() == model_id) {
if model.model_type() == model_type {
return Ok((*model).clone());
} else {
bail!("Model '{model_id}' is not a {model_type} model")
}
}
if list_client_names(config)
.into_iter()
.any(|v| *v == client_name)
&& model_type.can_create_from_name()
{
let mut new_model = Self::new(client_name, model_name);
new_model.data.model_type = model_type.to_string();
return Ok(new_model);
}
}
None => {
if let Some(found) = models
.iter()
.find(|v| v.client_name == client_name && v.model_type() == model_type)
{
return Ok((*found).clone());
}
}
};
bail!("Unknown {model_type} model '{model_id}'")
}
pub fn id(&self) -> String {
if self.data.name.is_empty() {
self.client_name.to_string()
} else {
format!("{}:{}", self.client_name, self.data.name)
}
}
pub fn client_name(&self) -> &str {
&self.client_name
}
pub fn name(&self) -> &str {
&self.data.name
}
pub fn real_name(&self) -> &str {
self.data.real_name.as_deref().unwrap_or(&self.data.name)
}
pub fn model_type(&self) -> ModelType {
if self.data.model_type.starts_with("embed") {
ModelType::Embedding
} else if self.data.model_type.starts_with("rerank") {
ModelType::Reranker
} else {
ModelType::Chat
}
}
pub fn data(&self) -> &ModelData {
&self.data
}
pub fn data_mut(&mut self) -> &mut ModelData {
&mut self.data
}
pub fn description(&self) -> String {
match self.model_type() {
ModelType::Chat => {
let ModelData {
max_input_tokens,
max_output_tokens,
input_price,
output_price,
supports_vision,
supports_function_calling,
..
} = &self.data;
let max_input_tokens = stringify_option_value(max_input_tokens);
let max_output_tokens = stringify_option_value(max_output_tokens);
let input_price = stringify_option_value(input_price);
let output_price = stringify_option_value(output_price);
let mut capabilities = vec![];
if *supports_vision {
capabilities.push('👁');
};
if *supports_function_calling {
capabilities.push('⚒');
};
let capabilities: String = capabilities
.into_iter()
.map(|v| format!("{v} "))
.collect::<Vec<String>>()
.join("");
format!(
"{max_input_tokens:>8} / {max_output_tokens:>8} | {input_price:>6} / {output_price:>6} {capabilities:>6}"
)
}
ModelType::Embedding => {
let ModelData {
input_price,
max_tokens_per_chunk,
max_batch_size,
..
} = &self.data;
let max_tokens = stringify_option_value(max_tokens_per_chunk);
let max_batch = stringify_option_value(max_batch_size);
let price = stringify_option_value(input_price);
format!("max-tokens:{max_tokens};max-batch:{max_batch};price:{price}")
}
ModelType::Reranker => String::new(),
}
}
pub fn patch(&self) -> Option<&Value> {
self.data.patch.as_ref()
}
pub fn max_input_tokens(&self) -> Option<usize> {
self.data.max_input_tokens
}
pub fn max_output_tokens(&self) -> Option<isize> {
self.data.max_output_tokens
}
pub fn no_stream(&self) -> bool {
self.data.no_stream
}
pub fn no_system_message(&self) -> bool {
self.data.no_system_message
}
pub fn system_prompt_prefix(&self) -> Option<&str> {
self.data.system_prompt_prefix.as_deref()
}
pub fn max_tokens_per_chunk(&self) -> Option<usize> {
self.data.max_tokens_per_chunk
}
pub fn default_chunk_size(&self) -> usize {
self.data.default_chunk_size.unwrap_or(1000)
}
pub fn max_batch_size(&self) -> Option<usize> {
self.data.max_batch_size
}
pub fn max_tokens_param(&self) -> Option<isize> {
if self.data.require_max_tokens {
self.data.max_output_tokens
} else {
None
}
}
pub fn set_max_tokens(
&mut self,
max_output_tokens: Option<isize>,
require_max_tokens: bool,
) -> &mut Self {
match max_output_tokens {
None | Some(0) => self.data.max_output_tokens = None,
_ => self.data.max_output_tokens = max_output_tokens,
}
self.data.require_max_tokens = require_max_tokens;
self
}
pub fn messages_tokens(&self, messages: &[Message]) -> usize {
let messages_len = messages.len();
messages
.iter()
.enumerate()
.map(|(i, v)| match &v.content {
MessageContent::Text(text) => {
if v.role.is_assistant() && i != messages_len - 1 {
estimate_token_length(&strip_think_tag(text))
} else {
estimate_token_length(text)
}
}
MessageContent::Array(list) => list
.iter()
.map(|v| match v {
MessageContentPart::Text { text } => estimate_token_length(text),
MessageContentPart::ImageUrl { .. } => 0,
})
.sum(),
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
estimate_token_length(text)
+ tool_results
.iter()
.map(|v| {
serde_json::to_string(v)
.map(|v| estimate_token_length(&v))
.unwrap_or_default()
})
.sum::<usize>()
}
})
.sum()
}
pub fn total_tokens(&self, messages: &[Message]) -> usize {
if messages.is_empty() {
return 0;
}
let num_messages = messages.len();
let message_tokens = self.messages_tokens(messages);
if messages[num_messages - 1].role.is_user() {
num_messages * PER_MESSAGES_TOKENS + message_tokens
} else {
(num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
}
}
pub fn guard_max_input_tokens(&self, messages: &[Message]) -> Result<()> {
let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
if let Some(max_input_tokens) = self.data.max_input_tokens {
if total_tokens >= max_input_tokens {
bail!("Exceed max_input_tokens limit")
}
}
Ok(())
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelData {
pub name: String,
#[serde(default = "default_model_type", rename = "type")]
pub model_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub real_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_input_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub input_price: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_price: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub patch: Option<Value>,
// chat-only properties
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<isize>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub require_max_tokens: bool,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub supports_vision: bool,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub supports_function_calling: bool,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
no_stream: bool,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
no_system_message: bool,
#[serde(skip_serializing_if = "Option::is_none")]
system_prompt_prefix: Option<String>,
// embedding-only properties
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens_per_chunk: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default_chunk_size: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_batch_size: Option<usize>,
}
impl ModelData {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
model_type: default_model_type(),
..Default::default()
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderModels {
pub provider: String,
pub models: Vec<ModelData>,
}
fn default_model_type() -> String {
"chat".into()
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelType {
Chat,
Embedding,
Reranker,
}
impl Display for ModelType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelType::Chat => write!(f, "chat"),
ModelType::Embedding => write!(f, "embedding"),
ModelType::Reranker => write!(f, "reranker"),
}
}
}
impl ModelType {
pub fn can_create_from_name(self) -> bool {
match self {
ModelType::Chat => true,
ModelType::Embedding => false,
ModelType::Reranker => true,
}
}
pub fn api_name(self) -> &'static str {
match self {
ModelType::Chat => "chat_completions",
ModelType::Embedding => "embeddings",
ModelType::Reranker => "rerank",
}
}
pub fn extract_patch(self, patch: &RequestPatch) -> Option<&ApiPatch> {
match self {
ModelType::Chat => patch.chat_completions.as_ref(),
ModelType::Embedding => patch.embeddings.as_ref(),
ModelType::Reranker => patch.rerank.as_ref(),
}
}
}
fn stringify_option_value<T>(value: &Option<T>) -> String
where
T: Display,
{
match value {
Some(value) => value.to_string(),
None => "-".to_string(),
}
}
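// A small test sketch (added for illustration, not in the original commit)
// covering the `client:name` id format and the default model type; the model
// names here are arbitrary.
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_model_id_and_type() {
let model = Model::new("openai", "gpt-4o");
assert_eq!(model.id(), "openai:gpt-4o");
assert_eq!(model.model_type(), ModelType::Chat);
assert_eq!(model.real_name(), "gpt-4o");

let nameless = Model::new("openai", "");
assert_eq!(nameless.id(), "openai");
}
}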
+408
View File
@@ -0,0 +1,408 @@
use super::*;
use crate::utils::strip_think_tag;
use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
const API_BASE: &str = "https://api.openai.com/v1";
#[derive(Debug, Clone, Deserialize, Default)]
pub struct OpenAIConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
pub organization_id: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl OpenAIClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
impl_client_trait!(
OpenAIClient,
(
prepare_chat_completions,
openai_chat_completions,
openai_chat_completions_streaming
),
(prepare_embeddings, openai_embeddings),
(noop_prepare_rerank, noop_rerank),
);
fn prepare_chat_completions(
self_: &OpenAIClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!("{}/chat/completions", api_base.trim_end_matches('/'));
let body = openai_build_chat_completions_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
request_data.bearer_auth(api_key);
if let Some(organization_id) = &self_.config.organization_id {
request_data.header("OpenAI-Organization", organization_id);
}
Ok(request_data)
}
fn prepare_embeddings(self_: &OpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!("{api_base}/embeddings");
let body = openai_build_embeddings_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
request_data.bearer_auth(api_key);
if let Some(organization_id) = &self_.config.organization_id {
request_data.header("OpenAI-Organization", organization_id);
}
Ok(request_data)
}
pub async fn openai_chat_completions(
builder: RequestBuilder,
_model: &Model,
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
openai_extract_chat_completions(&data)
}
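// The streaming handler below reconstructs three kinds of deltas from the SSE
// stream: plain content, reasoning deltas (`reasoning_content` or `reasoning`)
// which get wrapped in a `<think>...</think>` block, and tool-call fragments,
// which are accumulated per `id/index` key and flushed whenever the key
// changes or the `[DONE]` sentinel arrives.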
pub async fn openai_chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
_model: &Model,
) -> Result<()> {
let mut call_id = String::new();
let mut function_name = String::new();
let mut function_arguments = String::new();
let mut function_id = String::new();
let mut reasoning_state = 0;
let handle = |message: SseMessage| -> Result<bool> {
if message.data == "[DONE]" {
if !function_name.is_empty() {
if function_arguments.is_empty() {
function_arguments = String::from("{}");
}
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
normalize_function_id(&function_id),
))?;
}
return Ok(true);
}
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
if let Some(text) = data["choices"][0]["delta"]["content"]
.as_str()
.filter(|v| !v.is_empty())
{
if reasoning_state == 1 {
handler.text("\n</think>\n\n")?;
reasoning_state = 0;
}
handler.text(text)?;
} else if let Some(text) = data["choices"][0]["delta"]["reasoning_content"]
.as_str()
.or_else(|| data["choices"][0]["delta"]["reasoning"].as_str())
.filter(|v| !v.is_empty())
{
if reasoning_state == 0 {
handler.text("<think>\n")?;
reasoning_state = 1;
}
handler.text(text)?;
}
if let (Some(function), index, id) = (
data["choices"][0]["delta"]["tool_calls"][0]["function"].as_object(),
data["choices"][0]["delta"]["tool_calls"][0]["index"].as_u64(),
data["choices"][0]["delta"]["tool_calls"][0]["id"]
.as_str()
.filter(|v| !v.is_empty()),
) {
if reasoning_state == 1 {
handler.text("\n</think>\n\n")?;
reasoning_state = 0;
}
let maybe_call_id = format!("{}/{}", id.unwrap_or_default(), index.unwrap_or_default());
if maybe_call_id != call_id && maybe_call_id.len() >= call_id.len() {
if !function_name.is_empty() {
if function_arguments.is_empty() {
function_arguments = String::from("{}");
}
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
normalize_function_id(&function_id),
))?;
}
function_name.clear();
function_arguments.clear();
function_id.clear();
call_id = maybe_call_id;
}
if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
if name.starts_with(&function_name) {
function_name = name.to_string();
} else {
function_name.push_str(name);
}
}
if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) {
function_arguments.push_str(arguments);
}
if let Some(id) = id {
function_id = id.to_string();
}
}
Ok(false)
};
sse_stream(builder, handle).await
}
pub async fn openai_embeddings(
builder: RequestBuilder,
_model: &Model,
) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
let output = res_body.data.into_iter().map(|v| v.embedding).collect();
Ok(output)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
data: Vec<EmbeddingsResBodyEmbedding>,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
embedding: Vec<f32>,
}
pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
let ChatCompletionsData {
messages,
temperature,
top_p,
functions,
stream,
} = data;
let messages_len = messages.len();
let messages: Vec<Value> = messages
.into_iter()
.enumerate()
.flat_map(|(i, message)| {
let Message { role, content } = message;
match content {
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results,
text: _,
sequence,
}) => {
if !sequence {
let tool_calls: Vec<_> = tool_results
.iter()
.map(|tool_result| {
json!({
"id": tool_result.call.id,
"type": "function",
"function": {
"name": tool_result.call.name,
"arguments": tool_result.call.arguments.to_string(),
},
})
})
.collect();
let mut messages = vec![
json!({ "role": MessageRole::Assistant, "tool_calls": tool_calls }),
];
for tool_result in tool_results {
messages.push(json!({
"role": "tool",
"content": tool_result.output.to_string(),
"tool_call_id": tool_result.call.id,
}));
}
messages
} else {
tool_results.into_iter().flat_map(|tool_result| {
vec![
json!({
"role": MessageRole::Assistant,
"tool_calls": [
{
"id": tool_result.call.id,
"type": "function",
"function": {
"name": tool_result.call.name,
"arguments": tool_result.call.arguments.to_string(),
},
}
]
}),
json!({
"role": "tool",
"content": tool_result.output.to_string(),
"tool_call_id": tool_result.call.id,
})
]
}).collect()
}
}
MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
vec![json!({ "role": role, "content": strip_think_tag(&text) }
)]
}
_ => vec![json!({ "role": role, "content": content })],
}
})
.collect();
let mut body = json!({
"model": &model.real_name(),
"messages": messages,
});
if let Some(v) = model.max_tokens_param() {
if model
.patch()
.and_then(|v| v.get("body").and_then(|v| v.get("max_tokens")))
== Some(&Value::Null)
{
body["max_completion_tokens"] = v.into();
} else {
body["max_tokens"] = v.into();
}
}
if let Some(v) = temperature {
body["temperature"] = v.into();
}
if let Some(v) = top_p {
body["top_p"] = v.into();
}
if stream {
body["stream"] = true.into();
}
if let Some(functions) = functions {
body["tools"] = functions
.iter()
.map(|v| {
json!({
"type": "function",
"function": v,
})
})
.collect();
}
body
}
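// Note on the max-tokens handling above (the patch shape is inferred from how
// it is read here, not documented in this file): a model patch whose body sets
// `max_tokens` to `null` signals that the endpoint rejects `max_tokens`, so
// the builder emits `max_completion_tokens` instead; otherwise the plain
// `max_tokens` field is sent.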
pub fn openai_build_embeddings_body(data: &EmbeddingsData, model: &Model) -> Value {
json!({
"input": data.texts,
"model": model.real_name()
})
}
pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["choices"][0]["message"]["content"]
.as_str()
.unwrap_or_default();
let reasoning = data["choices"][0]["message"]["reasoning_content"]
.as_str()
.or_else(|| data["choices"][0]["message"]["reasoning"].as_str())
.unwrap_or_default()
.trim();
let mut tool_calls = vec![];
if let Some(calls) = data["choices"][0]["message"]["tool_calls"].as_array() {
for call in calls {
if let (Some(name), Some(arguments), Some(id)) = (
call["function"]["name"].as_str(),
call["function"]["arguments"].as_str(),
call["id"].as_str(),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
})?;
tool_calls.push(ToolCall::new(
name.to_string(),
arguments,
Some(id.to_string()),
));
}
}
};
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let text = if !reasoning.is_empty() {
format!("<think>\n{reasoning}\n</think>\n\n{text}")
} else {
text.to_string()
};
let output = ChatCompletionsOutput {
text,
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
output_tokens: data["usage"]["completion_tokens"].as_u64(),
};
Ok(output)
}
fn normalize_function_id(value: &str) -> Option<String> {
if value.is_empty() {
None
} else {
Some(value.to_string())
}
}
+162
View File
@@ -0,0 +1,162 @@
use super::openai::*;
use super::*;
use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAICompatibleConfig {
pub name: Option<String>,
pub api_base: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl OpenAICompatibleClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptAction<'static>; 0] = [];
}
impl_client_trait!(
OpenAICompatibleClient,
(
prepare_chat_completions,
openai_chat_completions,
openai_chat_completions_streaming
),
(prepare_embeddings, openai_embeddings),
(prepare_rerank, generic_rerank),
);
fn prepare_chat_completions(
self_: &OpenAICompatibleClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();
let api_base = get_api_base_ext(self_)?;
let url = format!("{api_base}/chat/completions");
let body = openai_build_chat_completions_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
if let Some(api_key) = api_key {
request_data.bearer_auth(api_key);
}
Ok(request_data)
}
fn prepare_embeddings(
self_: &OpenAICompatibleClient,
data: &EmbeddingsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();
let api_base = get_api_base_ext(self_)?;
let url = format!("{api_base}/embeddings");
let body = openai_build_embeddings_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
if let Some(api_key) = api_key {
request_data.bearer_auth(api_key);
}
Ok(request_data)
}
fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();
let api_base = get_api_base_ext(self_)?;
let url = if self_.name().starts_with("ernie") {
format!("{api_base}/rerankers")
} else {
format!("{api_base}/rerank")
};
let body = generic_build_rerank_body(data, &self_.model);
let mut request_data = RequestData::new(url, body);
if let Some(api_key) = api_key {
request_data.bearer_auth(api_key);
}
Ok(request_data)
}
fn get_api_base_ext(self_: &OpenAICompatibleClient) -> Result<String> {
let api_base = match self_.get_api_base() {
Ok(v) => v,
Err(err) => {
match OPENAI_COMPATIBLE_PROVIDERS
.into_iter()
.find_map(|(name, api_base)| {
if name == self_.model.client_name() {
Some(api_base.to_string())
} else {
None
}
}) {
Some(v) => v,
None => return Err(err),
}
}
};
Ok(api_base.trim_end_matches('/').to_string())
}
pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
let res = builder.send().await?;
let status = res.status();
let mut data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
if data.get("results").is_none() && data.get("data").is_some() {
if let Some(data_obj) = data.as_object_mut() {
if let Some(value) = data_obj.remove("data") {
data_obj.insert("results".to_string(), value);
}
}
}
let res_body: GenericRerankResBody =
serde_json::from_value(data).context("Invalid rerank data")?;
Ok(res_body.results)
}
#[derive(Deserialize)]
pub struct GenericRerankResBody {
pub results: RerankOutput,
}
pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value {
let RerankData {
query,
documents,
top_n,
} = data;
let mut body = json!({
"model": model.real_name(),
"query": query,
"documents": documents,
});
if model.client_name().starts_with("voyageai") {
body["top_k"] = (*top_n).into()
} else {
body["top_n"] = (*top_n).into()
}
body
}
+296
View File
@@ -0,0 +1,296 @@
use super::{catch_error, ToolCall};
use crate::utils::AbortSignal;
use anyhow::{anyhow, bail, Context, Result};
use futures_util::{Stream, StreamExt};
use reqwest::RequestBuilder;
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde_json::Value;
use tokio::sync::mpsc::UnboundedSender;
pub struct SseHandler {
sender: UnboundedSender<SseEvent>,
abort_signal: AbortSignal,
buffer: String,
tool_calls: Vec<ToolCall>,
}
impl SseHandler {
pub fn new(sender: UnboundedSender<SseEvent>, abort_signal: AbortSignal) -> Self {
Self {
sender,
abort_signal,
buffer: String::new(),
tool_calls: Vec::new(),
}
}
pub fn text(&mut self, text: &str) -> Result<()> {
// debug!("HandleText: {}", text);
if text.is_empty() {
return Ok(());
}
self.buffer.push_str(text);
let ret = self
.sender
.send(SseEvent::Text(text.to_string()))
.with_context(|| "Failed to send SseEvent:Text");
if let Err(err) = ret {
if self.abort_signal.aborted() {
return Ok(());
}
return Err(err);
}
Ok(())
}
pub fn done(&mut self) {
// debug!("HandleDone");
let ret = self.sender.send(SseEvent::Done);
if ret.is_err() {
if self.abort_signal.aborted() {
return;
}
warn!("Failed to send SseEvent:Done");
}
}
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
// debug!("HandleCall: {:?}", call);
self.tool_calls.push(call);
Ok(())
}
pub fn abort(&self) -> AbortSignal {
self.abort_signal.clone()
}
pub fn tool_calls(&self) -> &[ToolCall] {
&self.tool_calls
}
pub fn take(self) -> (String, Vec<ToolCall>) {
let Self {
buffer, tool_calls, ..
} = self;
(buffer, tool_calls)
}
}
#[derive(Debug)]
pub enum SseEvent {
Text(String),
Done,
}
#[derive(Debug)]
pub struct SseMessage {
#[allow(unused)]
pub event: String,
pub data: String,
}
pub async fn sse_stream<F>(builder: RequestBuilder, mut handle: F) -> Result<()>
where
F: FnMut(SseMessage) -> Result<bool>,
{
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
let message = SseMessage {
event: message.event,
data: message.data,
};
if handle(message)? {
break;
}
}
Err(err) => {
match err {
EventSourceError::StreamEnded => {}
EventSourceError::InvalidStatusCode(status, res) => {
let text = res.text().await?;
let data: Value = match text.parse() {
Ok(data) => data,
Err(_) => {
bail!(
"Invalid response data: {text} (status: {})",
status.as_u16()
);
}
};
catch_error(&data, status.as_u16())?;
}
EventSourceError::InvalidContentType(header_value, res) => {
let text = res.text().await?;
bail!(
"Invalid response event-stream. content-type: {}, data: {text}",
header_value.to_str().unwrap_or_default()
);
}
_ => {
bail!("{}", err);
}
}
es.close();
}
}
}
Ok(())
}
pub async fn json_stream<S, F, E>(mut stream: S, mut handle: F) -> Result<()>
where
S: Stream<Item = Result<bytes::Bytes, E>> + Unpin,
F: FnMut(&str) -> Result<()>,
E: std::error::Error,
{
let mut parser = JsonStreamParser::default();
let mut unparsed_bytes = vec![];
while let Some(chunk_bytes) = stream.next().await {
let chunk_bytes =
chunk_bytes.map_err(|err| anyhow!("Failed to read json stream, {err}"))?;
unparsed_bytes.extend(chunk_bytes);
match std::str::from_utf8(&unparsed_bytes) {
Ok(text) => {
parser.process(text, &mut handle)?;
unparsed_bytes.clear();
}
Err(_) => {
continue;
}
}
}
if !unparsed_bytes.is_empty() {
let text = std::str::from_utf8(&unparsed_bytes)?;
parser.process(text, &mut handle)?;
}
Ok(())
}
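// Incremental parser used by `json_stream`: it tracks string/escape state and
// top-level brace balance so that each complete top-level JSON object is
// handed to the callback exactly once, even when objects are split across
// chunks or wrapped in an enclosing array (see the tests below).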
#[derive(Debug, Default)]
struct JsonStreamParser {
buffer: Vec<char>,
cursor: usize,
start: Option<usize>,
balances: Vec<char>,
quoting: bool,
escape: bool,
}
impl JsonStreamParser {
fn process<F>(&mut self, text: &str, handle: &mut F) -> Result<()>
where
F: FnMut(&str) -> Result<()>,
{
self.buffer.extend(text.chars());
for i in self.cursor..self.buffer.len() {
let ch = self.buffer[i];
if self.quoting {
if ch == '\\' {
self.escape = !self.escape;
} else {
if !self.escape && ch == '"' {
self.quoting = false;
}
self.escape = false;
}
continue;
}
match ch {
'"' => {
self.quoting = true;
self.escape = false;
}
'{' => {
if self.balances.is_empty() {
self.start = Some(i);
}
self.balances.push(ch);
}
'[' => {
if self.start.is_some() {
self.balances.push(ch);
}
}
'}' => {
self.balances.pop();
if self.balances.is_empty() {
if let Some(start) = self.start.take() {
let value: String = self.buffer[start..=i].iter().collect();
handle(&value)?;
}
}
}
']' => {
self.balances.pop();
}
_ => {}
}
}
self.cursor = self.buffer.len();
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use futures_util::stream;
use rand::Rng;
fn split_chunks(text: &str) -> Vec<Vec<u8>> {
let mut rng = rand::rng();
let len = text.len();
let cut1 = rng.random_range(1..len - 1);
let cut2 = rng.random_range(cut1 + 1..len);
let chunk1 = text.as_bytes()[..cut1].to_vec();
let chunk2 = text.as_bytes()[cut1..cut2].to_vec();
let chunk3 = text.as_bytes()[cut2..].to_vec();
vec![chunk1, chunk2, chunk3]
}
macro_rules! assert_json_stream {
($input:expr, $output:expr) => {
let chunks: Vec<_> = split_chunks($input)
.into_iter()
.map(|chunk| Ok::<_, std::convert::Infallible>(Bytes::from(chunk)))
.collect();
let stream = stream::iter(chunks);
let mut output = vec![];
let ret = json_stream(stream, |data| {
output.push(data.to_string());
Ok(())
})
.await;
assert!(ret.is_ok());
assert_eq!($output.replace("\r\n", "\n"), output.join("\n"))
};
}
#[tokio::test]
async fn test_json_stream_ndjson() {
let data = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
assert_json_stream!(data, data);
}
#[tokio::test]
async fn test_json_stream_array() {
let input = r#"[
{"key": "value"},
{"key": "value2"},
{"key": "value3"},"#;
let output = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
assert_json_stream!(input, output);
}
}
+537
View File
@@ -0,0 +1,537 @@
use super::access_token::*;
use super::claude::*;
use super::openai::*;
use super::*;
use anyhow::{anyhow, bail, Context, Result};
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{path::PathBuf, str::FromStr};
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
pub name: Option<String>,
pub project_id: Option<String>,
pub location: Option<String>,
pub adc_file: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}
impl VertexAIClient {
config_get_fn!(project_id, get_project_id);
config_get_fn!(location, get_location);
pub const PROMPTS: [PromptAction<'static>; 2] = [
("project_id", "Project ID", None),
("location", "Location", None),
];
}
#[async_trait::async_trait]
impl Client for VertexAIClient {
client_common_fns!();
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let model = self.model();
let model_category = ModelCategory::from_str(model.real_name())?;
let request_data = prepare_chat_completions(self, data, &model_category)?;
let builder = self.request_builder(client, request_data);
match model_category {
ModelCategory::Gemini => gemini_chat_completions(builder, model).await,
ModelCategory::Claude => claude_chat_completions(builder, model).await,
ModelCategory::Mistral => openai_chat_completions(builder, model).await,
}
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let model = self.model();
let model_category = ModelCategory::from_str(model.real_name())?;
let request_data = prepare_chat_completions(self, data, &model_category)?;
let builder = self.request_builder(client, request_data);
match model_category {
ModelCategory::Gemini => {
gemini_chat_completions_streaming(builder, handler, model).await
}
ModelCategory::Claude => {
claude_chat_completions_streaming(builder, handler, model).await
}
ModelCategory::Mistral => {
openai_chat_completions_streaming(builder, handler, model).await
}
}
}
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<Vec<Vec<f32>>> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let request_data = prepare_embeddings(self, data)?;
let builder = self.request_builder(client, request_data);
embeddings(builder, self.model()).await
}
}
fn prepare_chat_completions(
self_: &VertexAIClient,
data: ChatCompletionsData,
model_category: &ModelCategory,
) -> Result<RequestData> {
let project_id = self_.get_project_id()?;
let location = self_.get_location()?;
let access_token = get_access_token(self_.name())?;
let base_url = if location == "global" {
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
} else {
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
};
let model_name = self_.model.real_name();
let url = match model_category {
ModelCategory::Gemini => {
let func = match data.stream {
true => "streamGenerateContent",
false => "generateContent",
};
format!("{base_url}/google/models/{model_name}:{func}")
}
ModelCategory::Claude => {
format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
}
ModelCategory::Mistral => {
let func = match data.stream {
true => "streamRawPredict",
false => "rawPredict",
};
format!("{base_url}/mistralai/models/{model_name}:{func}")
}
};
let body = match model_category {
ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self_.model)?,
ModelCategory::Claude => {
let mut body = claude_build_chat_completions_body(data, &self_.model)?;
if let Some(body_obj) = body.as_object_mut() {
body_obj.remove("model");
}
body["anthropic_version"] = "vertex-2023-10-16".into();
body
}
ModelCategory::Mistral => {
let mut body = openai_build_chat_completions_body(data, &self_.model);
if let Some(body_obj) = body.as_object_mut() {
body_obj["model"] = strip_model_version(self_.model.real_name()).into();
}
body
}
};
let mut request_data = RequestData::new(url, body);
request_data.bearer_auth(access_token);
Ok(request_data)
}
fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<RequestData> {
let project_id = self_.get_project_id()?;
let location = self_.get_location()?;
let access_token = get_access_token(self_.name())?;
let base_url = if location == "global" {
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
} else {
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
};
let url = format!(
"{base_url}/google/models/{}:predict",
self_.model.real_name()
);
let instances: Vec<_> = data.texts.iter().map(|v| json!({"content": v})).collect();
let body = json!({
"instances": instances,
});
let mut request_data = RequestData::new(url, body);
request_data.bearer_auth(access_token);
Ok(request_data)
}
pub async fn gemini_chat_completions(
builder: RequestBuilder,
_model: &Model,
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
gemini_extract_chat_completions_text(&data)
}
pub async fn gemini_chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
_model: &Model,
) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if !status.is_success() {
let data: Value = res.json().await?;
catch_error(&data, status.as_u16())?;
} else {
let handle = |value: &str| -> Result<()> {
let data: Value = serde_json::from_str(value)?;
debug!("stream-data: {data}");
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
for (i, part) in parts.iter().enumerate() {
if let Some(text) = part["text"].as_str() {
if i > 0 {
handler.text("\n\n")?;
}
handler.text(text)?;
} else if let (Some(name), Some(args)) = (
part["functionCall"]["name"].as_str(),
part["functionCall"]["args"].as_object(),
) {
handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?;
}
}
} else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Blocked due to safety")
}
Ok(())
};
json_stream(res.bytes_stream(), handle).await?;
}
Ok(())
}
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
let output = res_body
.predictions
.into_iter()
.map(|v| v.embeddings.values)
.collect();
Ok(output)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
predictions: Vec<EmbeddingsResBodyPrediction>,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyPrediction {
embeddings: EmbeddingsResBodyPredictionEmbeddings,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyPredictionEmbeddings {
values: Vec<f32>,
}
fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
let mut text_parts = vec![];
let mut tool_calls = vec![];
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
for part in parts {
if let Some(text) = part["text"].as_str() {
text_parts.push(text);
}
if let (Some(name), Some(args)) = (
part["functionCall"]["name"].as_str(),
part["functionCall"]["args"].as_object(),
) {
tool_calls.push(ToolCall::new(name.to_string(), json!(args), None));
}
}
}
let text = text_parts.join("\n\n");
if text.is_empty() && tool_calls.is_empty() {
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Blocked due to safety")
} else {
bail!("Invalid response data: {data}");
}
}
let output = ChatCompletionsOutput {
text,
tool_calls,
id: None,
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
};
Ok(output)
}
pub fn gemini_build_chat_completions_body(
data: ChatCompletionsData,
model: &Model,
) -> Result<Value> {
let ChatCompletionsData {
mut messages,
temperature,
top_p,
functions,
stream: _,
} = data;
let system_message = extract_system_message(&mut messages);
let mut network_image_urls = vec![];
let contents: Vec<Value> = messages
.into_iter()
.flat_map(|message| {
let Message { role, content } = message;
let role = match role {
MessageRole::User => "user",
_ => "model",
};
match content {
MessageContent::Text(text) => vec![json!({
"role": role,
"parts": [{ "text": text }]
})],
MessageContent::Array(list) => {
let parts: Vec<Value> = list
.into_iter()
.map(|item| match item {
MessageContentPart::Text { text } => json!({"text": text}),
MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => {
if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
json!({ "inline_data": { "mime_type": mime_type, "data": data } })
} else {
network_image_urls.push(url.clone());
json!({ "url": url })
}
},
})
.collect();
vec![json!({ "role": role, "parts": parts })]
},
MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. }) => {
let model_parts: Vec<Value> = tool_results.iter().map(|tool_result| {
json!({
"functionCall": {
"name": tool_result.call.name,
"args": tool_result.call.arguments,
}
})
}).collect();
let function_parts: Vec<Value> = tool_results.into_iter().map(|tool_result| {
json!({
"functionResponse": {
"name": tool_result.call.name,
"response": {
"name": tool_result.call.name,
"content": tool_result.output,
}
}
})
}).collect();
vec![
json!({ "role": "model", "parts": model_parts }),
json!({ "role": "function", "parts": function_parts }),
]
}
}
})
.collect();
if !network_image_urls.is_empty() {
bail!(
"The model does not support network images: {:?}",
network_image_urls
);
}
let mut body = json!({ "contents": contents, "generationConfig": {} });
if let Some(v) = system_message {
body["systemInstruction"] = json!({ "parts": [{"text": v }] });
}
if let Some(v) = model.max_tokens_param() {
body["generationConfig"]["maxOutputTokens"] = v.into();
}
if let Some(v) = temperature {
body["generationConfig"]["temperature"] = v.into();
}
if let Some(v) = top_p {
body["generationConfig"]["topP"] = v.into();
}
if let Some(functions) = functions {
// Gemini rejects function declarations whose parameters object has no properties, so only the name and description are sent in that case.
let function_declarations: Vec<_> = functions
.into_iter()
.map(|function| {
if function.parameters.is_empty_properties() {
json!({
"name": function.name,
"description": function.description,
})
} else {
json!(function)
}
})
.collect();
body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
}
Ok(body)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
Gemini,
Claude,
Mistral,
}
impl FromStr for ModelCategory {
type Err = anyhow::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
if s.starts_with("gemini") {
Ok(ModelCategory::Gemini)
} else if s.starts_with("claude") {
Ok(ModelCategory::Claude)
} else if s.starts_with("mistral") || s.starts_with("codestral") {
Ok(ModelCategory::Mistral)
} else {
unsupported_model!(s)
}
}
}
pub async fn prepare_gcloud_access_token(
client: &reqwest::Client,
client_name: &str,
adc_file: &Option<String>,
) -> Result<()> {
if !is_valid_access_token(client_name) {
let (token, expires_in) = fetch_access_token(client, adc_file)
.await
.with_context(|| "Failed to fetch access token")?;
let expires_at = Utc::now()
+ Duration::try_seconds(expires_in)
.ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
set_access_token(client_name, token, expires_at.timestamp())
}
Ok(())
}
async fn fetch_access_token(
client: &reqwest::Client,
file: &Option<String>,
) -> Result<(String, i64)> {
let credentials = load_adc(file).await?;
let value: Value = client
.post("https://oauth2.googleapis.com/token")
.json(&credentials)
.send()
.await?
.json()
.await?;
if let (Some(access_token), Some(expires_in)) =
(value["access_token"].as_str(), value["expires_in"].as_i64())
{
Ok((access_token.to_string(), expires_in))
} else if let Some(err_msg) = value["error_description"].as_str() {
bail!("{err_msg}")
} else {
bail!("Invalid response data: {value}")
}
}
async fn load_adc(file: &Option<String>) -> Result<Value> {
let adc_file = file
.as_ref()
.map(PathBuf::from)
.or_else(default_adc_file)
.ok_or_else(|| anyhow!("No application_default_credentials.json"))?;
let data = tokio::fs::read_to_string(adc_file).await?;
let data: Value = serde_json::from_str(&data)?;
if let (Some(client_id), Some(client_secret), Some(refresh_token)) = (
data["client_id"].as_str(),
data["client_secret"].as_str(),
data["refresh_token"].as_str(),
) {
Ok(json!({
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}))
} else {
bail!("Invalid application_default_credentials.json")
}
}
#[cfg(not(windows))]
fn default_adc_file() -> Option<PathBuf> {
let mut path = dirs::home_dir()?;
path.push(".config");
path.push("gcloud");
path.push("application_default_credentials.json");
Some(path)
}
#[cfg(windows)]
fn default_adc_file() -> Option<PathBuf> {
let mut path = dirs::config_dir()?;
path.push("gcloud");
path.push("application_default_credentials.json");
Some(path)
}
fn strip_model_version(name: &str) -> &str {
match name.split_once('@') {
Some((v, _)) => v,
None => name,
}
}
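// A minimal test sketch (not part of the original commit) for the two helpers
// above; the model names are illustrative.
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_model_category() {
assert_eq!("gemini-2.0-flash".parse::<ModelCategory>().unwrap(), ModelCategory::Gemini);
assert_eq!("claude-3-5-sonnet@20240620".parse::<ModelCategory>().unwrap(), ModelCategory::Claude);
assert!("llama-3".parse::<ModelCategory>().is_err());
}

#[test]
fn test_strip_model_version() {
assert_eq!(strip_model_version("claude-3-5-sonnet@20240620"), "claude-3-5-sonnet");
assert_eq!(strip_model_version("mistral-large"), "mistral-large");
}
}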