163 lines
4.2 KiB
Rust
163 lines
4.2 KiB
Rust
use super::openai::*;
|
|
use super::*;
|
|
|
|
use anyhow::{Context, Result};
|
|
use reqwest::RequestBuilder;
|
|
use serde::Deserialize;
|
|
use serde_json::{Value, json};
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
pub struct OpenAICompatibleConfig {
|
|
pub name: Option<String>,
|
|
pub api_base: Option<String>,
|
|
pub api_key: Option<String>,
|
|
#[serde(default)]
|
|
pub models: Vec<ModelData>,
|
|
pub patch: Option<RequestPatch>,
|
|
pub extra: Option<ExtraConfig>,
|
|
}
|
|
|
|
impl OpenAICompatibleClient {
|
|
config_get_fn!(api_base, get_api_base);
|
|
config_get_fn!(api_key, get_api_key);
|
|
|
|
create_client_config!([]);
|
|
}
|
|
|
|
// Hands `impl_client_trait!` one tuple per capability (chat completions,
// embeddings, rerank). Each tuple starts with the `prepare_*` request builder
// from this file, followed by the handler function(s) for that capability.
// NOTE(review): the generated impl itself lives in the macro definition.
impl_client_trait!(
    OpenAICompatibleClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (prepare_rerank, generic_rerank),
);
|
|
|
|
fn prepare_chat_completions(
|
|
self_: &OpenAICompatibleClient,
|
|
data: ChatCompletionsData,
|
|
) -> Result<RequestData> {
|
|
let api_key = self_.get_api_key().ok();
|
|
let api_base = get_api_base_ext(self_)?;
|
|
|
|
let url = format!("{api_base}/chat/completions");
|
|
|
|
let body = openai_build_chat_completions_body(data, &self_.model);
|
|
|
|
let mut request_data = RequestData::new(url, body);
|
|
|
|
if let Some(api_key) = api_key {
|
|
request_data.bearer_auth(api_key);
|
|
}
|
|
|
|
Ok(request_data)
|
|
}
|
|
|
|
fn prepare_embeddings(
|
|
self_: &OpenAICompatibleClient,
|
|
data: &EmbeddingsData,
|
|
) -> Result<RequestData> {
|
|
let api_key = self_.get_api_key().ok();
|
|
let api_base = get_api_base_ext(self_)?;
|
|
|
|
let url = format!("{api_base}/embeddings");
|
|
|
|
let body = openai_build_embeddings_body(data, &self_.model);
|
|
|
|
let mut request_data = RequestData::new(url, body);
|
|
|
|
if let Some(api_key) = api_key {
|
|
request_data.bearer_auth(api_key);
|
|
}
|
|
|
|
Ok(request_data)
|
|
}
|
|
|
|
fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result<RequestData> {
|
|
let api_key = self_.get_api_key().ok();
|
|
let api_base = get_api_base_ext(self_)?;
|
|
|
|
let url = if self_.name().starts_with("ernie") {
|
|
format!("{api_base}/rerankers")
|
|
} else {
|
|
format!("{api_base}/rerank")
|
|
};
|
|
|
|
let body = generic_build_rerank_body(data, &self_.model);
|
|
|
|
let mut request_data = RequestData::new(url, body);
|
|
|
|
if let Some(api_key) = api_key {
|
|
request_data.bearer_auth(api_key);
|
|
}
|
|
|
|
Ok(request_data)
|
|
}
|
|
|
|
fn get_api_base_ext(self_: &OpenAICompatibleClient) -> Result<String> {
|
|
let api_base = match self_.get_api_base() {
|
|
Ok(v) => v,
|
|
Err(err) => {
|
|
match OPENAI_COMPATIBLE_PROVIDERS
|
|
.into_iter()
|
|
.find_map(|(name, api_base)| {
|
|
if name == self_.model.client_name() {
|
|
Some(api_base.to_string())
|
|
} else {
|
|
None
|
|
}
|
|
}) {
|
|
Some(v) => v,
|
|
None => return Err(err),
|
|
}
|
|
}
|
|
};
|
|
Ok(api_base.trim_end_matches('/').to_string())
|
|
}
|
|
|
|
pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
|
|
let res = builder.send().await?;
|
|
let status = res.status();
|
|
let mut data: Value = res.json().await?;
|
|
if !status.is_success() {
|
|
catch_error(&data, status.as_u16())?;
|
|
}
|
|
if data.get("results").is_none()
|
|
&& data.get("data").is_some()
|
|
&& let Some(data_obj) = data.as_object_mut()
|
|
&& let Some(value) = data_obj.remove("data")
|
|
{
|
|
data_obj.insert("results".to_string(), value);
|
|
}
|
|
let res_body: GenericRerankResBody =
|
|
serde_json::from_value(data).context("Invalid rerank data")?;
|
|
Ok(res_body.results)
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct GenericRerankResBody {
|
|
pub results: RerankOutput,
|
|
}
|
|
|
|
pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value {
|
|
let RerankData {
|
|
query,
|
|
documents,
|
|
top_n,
|
|
} = data;
|
|
|
|
let mut body = json!({
|
|
"model": model.real_name(),
|
|
"query": query,
|
|
"documents": documents,
|
|
});
|
|
if model.client_name().starts_with("voyageai") {
|
|
body["top_k"] = (*top_n).into()
|
|
} else {
|
|
body["top_n"] = (*top_n).into()
|
|
}
|
|
body
|
|
}
|