use super::vertexai::*; use super::*; use anyhow::{Context, Result}; use reqwest::RequestBuilder; use serde::Deserialize; use serde_json::{Value, json}; const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta"; #[derive(Debug, Clone, Deserialize, Default)] pub struct GeminiConfig { pub name: Option, pub api_key: Option, pub api_base: Option, #[serde(default)] pub models: Vec, pub patch: Option, pub extra: Option, } impl GeminiClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); create_client_config!([("api_key", "API Key", None, true)]); } impl_client_trait!( GeminiClient, ( prepare_chat_completions, gemini_chat_completions, gemini_chat_completions_streaming ), (prepare_embeddings, embeddings), (noop_prepare_rerank, noop_rerank), ); fn prepare_chat_completions( self_: &GeminiClient, data: ChatCompletionsData, ) -> Result { let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() .unwrap_or_else(|_| API_BASE.to_string()); let func = match data.stream { true => "streamGenerateContent", false => "generateContent", }; let url = format!( "{}/models/{}:{}", api_base.trim_end_matches('/'), self_.model.real_name(), func ); let body = gemini_build_chat_completions_body(data, &self_.model)?; let mut request_data = RequestData::new(url, body); request_data.header("x-goog-api-key", api_key); Ok(request_data) } fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result { let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() .unwrap_or_else(|_| API_BASE.to_string()); let url = format!( "{}/models/{}:batchEmbedContents?key={}", api_base.trim_end_matches('/'), self_.model.real_name(), api_key ); let model_id = format!("models/{}", self_.model.real_name()); let requests: Vec<_> = data .texts .iter() .map(|text| { json!({ "model": model_id, "content": { "parts": [ { "text": text } ] }, }) }) .collect(); let body = json!({ "requests": requests, }); let request_data = RequestData::new(url, body); Ok(request_data) } async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; if !status.is_success() { catch_error(&data, status.as_u16())?; } let res_body: EmbeddingsResBody = serde_json::from_value(data).context("Invalid embeddings data")?; let output = res_body .embeddings .into_iter() .map(|embedding| embedding.values) .collect(); Ok(output) } #[derive(Deserialize)] struct EmbeddingsResBody { embeddings: Vec, } #[derive(Deserialize)] struct EmbeddingsResBodyEmbedding { values: Vec, }