// Vertex AI client: builds chat-completion and embedding requests for Gemini, Claude and
// Mistral-family models, handles the Gemini responses, and obtains a Google Cloud access
// token from an application-default-credentials (ADC) file.
//
// NOTE(review): several type positions below read as bare `Option`, `Vec`, `Result`,
// `Result>>` or `std::result::Result` with no generic arguments. It looks like the
// angle-bracket parameters were lost from this copy of the file; confirm against the
// upstream source before building. The code tokens are left exactly as found.
use super::access_token::*;
use super::claude::*;
use super::openai::*;
use super::*;
use anyhow::{Context, Result, anyhow, bail};
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{Value, json};
use std::{path::PathBuf, str::FromStr};

/// Deserializable configuration for a Vertex AI client.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
    /// Optional client name. NOTE(review): presumably what `name()` reports — confirm.
    pub name: Option,
    /// Google Cloud project id; interpolated into the request URL
    /// (presumably read through the generated `get_project_id` accessor — confirm).
    pub project_id: Option,
    /// Vertex AI location. The literal value `"global"` selects the non-regional
    /// `aiplatform.googleapis.com` host; any other value is used as a regional host prefix.
    pub location: Option,
    /// Optional path to the ADC JSON file. When absent, `load_adc` falls back to
    /// `default_adc_file()`.
    pub adc_file: Option,
    /// Model entries; defaults to an empty collection when the field is missing.
    #[serde(default)]
    pub models: Vec,
    /// Optional patch settings (not interpreted in this file).
    pub patch: Option,
    /// Optional extra settings (not interpreted in this file).
    pub extra: Option,
}

impl VertexAIClient {
    // Macros defined elsewhere. They are expected to generate the `get_project_id` /
    // `get_location` accessors used by the `prepare_*` functions below, and to declare
    // the prompts for the two settings. TODO confirm against the macro definitions.
    config_get_fn!(project_id, get_project_id);
    config_get_fn!(location, get_location);
    create_client_config!([
        ("project_id", "Project ID", None, false),
        ("location", "Location", None, false),
    ]);
}

#[async_trait::async_trait]
impl Client for VertexAIClient {
    // Shared trait method implementations, generated by a macro defined elsewhere.
    client_common_fns!();

    /// Non-streaming chat completion.
    ///
    /// Ensures an access token is available, classifies the model by its real name,
    /// builds the request, and dispatches to the handler for that model family.
    ///
    /// # Errors
    /// Propagates failures from token preparation, model classification, request
    /// preparation, and the family-specific handler.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => gemini_chat_completions(builder, model).await,
            ModelCategory::Claude => claude_chat_completions(builder, model).await,
            ModelCategory::Mistral => openai_chat_completions(builder, model).await,
        }
    }

    /// Streaming chat completion.
    ///
    /// Same preparation steps as `chat_completions_inner`, but dispatches to the
    /// streaming handler of the model family, which reports output through `handler`.
    ///
    /// # Errors
    /// Propagates failures from token preparation, model classification, request
    /// preparation, and the family-specific streaming handler.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => {
                gemini_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Claude => {
                claude_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Mistral => {
                openai_chat_completions_streaming(builder, handler, model).await
            }
        }
    }

    /// Computes embeddings for `data`.
    ///
    /// No model-family dispatch happens here: the request always targets the
    /// `google` publisher `:predict` endpoint (see `prepare_embeddings`).
    ///
    /// # Errors
    /// Propagates failures from token preparation, request preparation, and the
    /// embeddings call.
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result>> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let request_data = prepare_embeddings(self, data)?;
        let builder = self.request_builder(client, request_data);
        embeddings(builder, self.model()).await
    }
}

/// Builds the URL, JSON body and bearer authorization for a chat-completions request.
///
/// * `self_` - client supplying project id, location, model and name.
/// * `data` - chat request; `data.stream` selects the streaming or non-streaming
///   endpoint for Gemini and Mistral.
/// * `model_category` - model family, which decides the publisher path and body format.
///
/// # Errors
/// Fails when project id, location or the access token cannot be obtained, or when
/// the Gemini/Claude body builder fails.
fn prepare_chat_completions(
    self_: &VertexAIClient,
    data: ChatCompletionsData,
    model_category: &ModelCategory,
) -> Result {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    // Looked up by client name. NOTE(review): presumably the token stored by
    // `prepare_gcloud_access_token` via `set_access_token` — confirm in `access_token`.
    let access_token = get_access_token(self_.name())?;
    // The "global" location uses the host without a region prefix.
    let base_url = if location == "global" {
        format!(
            "https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
        )
    } else {
        format!(
            "https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
        )
    };
    let model_name = self_.model.real_name();
    let url = match model_category {
        ModelCategory::Gemini => {
            let func = match data.stream {
                true => "streamGenerateContent",
                false => "generateContent",
            };
            format!("{base_url}/google/models/{model_name}:{func}")
        }
        ModelCategory::Claude => {
            // `data.stream` is not consulted here: Claude requests always go to
            // `streamRawPredict`.
            format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
        }
        ModelCategory::Mistral => {
            let func = match data.stream {
                true => "streamRawPredict",
                false => "rawPredict",
            };
            format!("{base_url}/mistralai/models/{model_name}:{func}")
        }
    };
    let body = match model_category {
        ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self_.model)?,
        ModelCategory::Claude => {
            let mut body = claude_build_chat_completions_body(data, &self_.model)?;
            // The model is already named in the URL: drop it from the body and add
            // the Vertex-specific `anthropic_version` field instead.
            if let Some(body_obj) = body.as_object_mut() {
                body_obj.remove("model");
            }
            body["anthropic_version"] = "vertex-2023-10-16".into();
            body
        }
        ModelCategory::Mistral => {
            let mut body = openai_build_chat_completions_body(data, &self_.model);
            // Send the model name without its `@version` suffix.
            // NOTE(review): this is index-assignment on the object map. If `body_obj` is a
            // `serde_json::Map`, its `IndexMut` panics when the key is absent, so this
            // relies on the OpenAI body builder always setting "model" — confirm.
            if let Some(body_obj) = body.as_object_mut() {
                body_obj["model"] = strip_model_version(self_.model.real_name()).into();
            }
            body
        }
    };
    let mut request_data = RequestData::new(url, body);
    request_data.bearer_auth(access_token);
    Ok(request_data)
}

/// Builds the URL, JSON body and bearer authorization for an embeddings request.
///
/// The body has the shape `{"instances": [{"content": <text>}, ...]}`, with one
/// instance per entry of `data.texts`.
///
/// # Errors
/// Fails when project id, location or the access token cannot be obtained.
fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    let access_token = get_access_token(self_.name())?;
    // Same host selection as in `prepare_chat_completions`.
    let base_url = if location == "global" {
        format!(
            "https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
        )
    } else {
        format!(
            "https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
        )
    };
    let url = format!(
        "{base_url}/google/models/{}:predict",
        self_.model.real_name()
    );
    let instances: Vec<_> = data.texts.iter().map(|v| json!({"content": v})).collect();
    let body = json!({
        "instances": instances,
    });
    let mut request_data = RequestData::new(url, body);
    request_data.bearer_auth(access_token);
    Ok(request_data)
}

/// Sends a non-streaming Gemini request and extracts text and tool calls from the reply.
///
/// The response body is parsed as JSON before the status is inspected, so a body that
/// is not JSON surfaces as a decode error whatever the status code was.
///
/// # Errors
/// Fails on transport or JSON errors, when `catch_error` reports a non-success
/// response, or when the payload holds neither text nor tool calls.
pub async fn gemini_chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        // `catch_error` is defined elsewhere; it receives the parsed body and the
        // numeric status. NOTE(review): presumably always returns `Err` — confirm.
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    gemini_extract_chat_completions_text(&data)
}

/// Sends a streaming Gemini request and forwards each chunk's content to `handler`.
///
/// On a non-success status the body is parsed as JSON and passed to `catch_error`.
/// Otherwise the byte stream is handed to `json_stream` together with a closure that
/// parses one JSON value and emits its text parts and function calls.
///
/// # Errors
/// Fails on transport or JSON errors, on errors returned by `handler`, when
/// `catch_error` reports a failure, or with "Blocked due to safety" when a chunk has
/// no parts and reports `SAFETY`.
pub async fn gemini_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if !status.is_success() {
        let data: Value = res.json().await?;
        catch_error(&data, status.as_u16())?;
    } else {
        // Handles one JSON value from the stream.
        let handle = |value: &str| -> Result<()> {
            let data: Value = serde_json::from_str(value)?;
            debug!("stream-data: {data}");
            if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
                for (i, part) in parts.iter().enumerate() {
                    if let Some(text) = part["text"].as_str() {
                        // Separate this text from earlier parts of the same chunk.
                        if i > 0 {
                            handler.text("\n\n")?;
                        }
                        handler.text(text)?;
                    } else if let (Some(name), Some(args)) = (
                        part["functionCall"]["name"].as_str(),
                        part["functionCall"]["args"].as_object(),
                    ) {
                        // Accept either the camelCase or the snake_case spelling.
                        let thought_signature = part["thoughtSignature"]
                            .as_str()
                            .or_else(|| part["thought_signature"].as_str())
                            .map(|s| s.to_string());
                        handler.tool_call(
                            ToolCall::new(name.to_string(), json!(args), None)
                                .with_thought_signature(thought_signature),
                        )?;
                    }
                }
            // Only reached when the chunk has no parts array: check the prompt block
            // reason first, then the first candidate's finish reason.
            } else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
                .as_str()
                .or_else(|| data["candidates"][0]["finishReason"].as_str())
            {
                bail!("Blocked due to safety")
            }
            Ok(())
        };
        json_stream(res.bytes_stream(), handle).await?;
    }
    Ok(())
}

/// Sends an embeddings request and collects `embeddings.values` of every prediction,
/// in response order.
///
/// # Errors
/// Fails on transport or JSON errors, when `catch_error` reports a non-success
/// response, or with "Invalid embeddings data" context when the body does not match
/// `EmbeddingsResBody`.
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let res_body: EmbeddingsResBody =
        serde_json::from_value(data).context("Invalid embeddings data")?;
    let output = res_body
        .predictions
        .into_iter()
        .map(|v| v.embeddings.values)
        .collect();
    Ok(output)
}

/// Top-level shape of the embeddings response: `{"predictions": [...]}`.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    predictions: Vec,
}

/// One prediction entry: `{"embeddings": {...}}`.
#[derive(Deserialize)]
struct EmbeddingsResBodyPrediction {
    embeddings: EmbeddingsResBodyPredictionEmbeddings,
}

/// Embedding payload of a prediction: `{"values": [...]}`.
#[derive(Deserialize)]
struct EmbeddingsResBodyPredictionEmbeddings {
    values: Vec,
}

/// Extracts text and tool calls from a non-streaming Gemini response.
///
/// Only `candidates[0].content.parts` is read. Text parts are joined with a blank
/// line. Unlike the streaming handler, the text check and the function-call check
/// are independent, so one part can contribute both.
///
/// # Errors
/// When there is neither text nor a tool call: "Blocked due to safety" if the prompt
/// block reason (or, failing that, the first candidate's finish reason) is `SAFETY`,
/// otherwise "Invalid response data" with the payload.
fn gemini_extract_chat_completions_text(data: &Value) -> Result {
    let mut text_parts = vec![];
    let mut tool_calls = vec![];
    if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
        for part in parts {
            if let Some(text) = part["text"].as_str() {
                text_parts.push(text);
            }
            if let (Some(name), Some(args)) = (
                part["functionCall"]["name"].as_str(),
                part["functionCall"]["args"].as_object(),
            ) {
                // Accept either the camelCase or the snake_case spelling.
                let thought_signature = part["thoughtSignature"]
                    .as_str()
                    .or_else(|| part["thought_signature"].as_str())
                    .map(|s| s.to_string());
                tool_calls.push(
                    ToolCall::new(name.to_string(), json!(args), None)
                        .with_thought_signature(thought_signature),
                );
            }
        }
    }
    let text = text_parts.join("\n\n");
    if text.is_empty() && tool_calls.is_empty() {
        if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
            .as_str()
            .or_else(|| data["candidates"][0]["finishReason"].as_str())
        {
            bail!("Blocked due to safety")
        } else {
            bail!("Invalid response data: {data}");
        }
    }
    let output = ChatCompletionsOutput { text, tool_calls };
    Ok(output)
}

/// Builds the JSON body of a Gemini `generateContent` / `streamGenerateContent` request.
///
/// * The system message is pulled out of `messages` and sent as `systemInstruction`.
/// * `MessageRole::User` maps to role `"user"`; every other role maps to `"model"`.
/// * `data:` image URLs containing `;base64,` become `inline_data` parts.
/// * A tool-call message expands into two entries: a `"model"` entry holding the
///   `functionCall` parts, then a `"function"` entry holding the `functionResponse` parts.
/// * `maxOutputTokens`, `temperature` and `topP` are set under `generationConfig` only
///   when a value is available.
/// * `data.stream` is ignored here; the endpoint is chosen by the caller.
///
/// # Errors
/// Fails with "The model does not support network images" when any image URL is not a
/// base64 `data:` URL.
pub fn gemini_build_chat_completions_body(
    data: ChatCompletionsData,
    model: &Model,
) -> Result {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream: _,
    } = data;
    let system_message = extract_system_message(&mut messages);
    // Collected while converting, so all offending URLs are reported together.
    let mut network_image_urls = vec![];
    let contents: Vec = messages
        .into_iter()
        .flat_map(|message| {
            let Message { role, content } = message;
            let role = match role {
                MessageRole::User => "user",
                _ => "model",
            };
            match content {
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "parts": [{ "text": text }]
                })],
                MessageContent::Array(list) => {
                    let parts: Vec = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => json!({"text": text}),
                            MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => {
                                if let Some((mime_type, data)) = url
                                    .strip_prefix("data:")
                                    .and_then(|v| v.split_once(";base64,"))
                                {
                                    json!({ "inline_data": { "mime_type": mime_type, "data": data } })
                                } else {
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            },
                        })
                        .collect();
                    vec![json!({ "role": role, "parts": parts })]
                },
                MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. }) => {
                    let model_parts: Vec = tool_results.iter().map(|tool_result| {
                        let mut part = json!({
                            "functionCall": {
                                "name": tool_result.call.name,
                                "args": tool_result.call.arguments,
                            }
                        });
                        if let Some(sig) = &tool_result.call.thought_signature {
                            part["thoughtSignature"] = json!(sig);
                        }
                        part
                    }).collect();
                    let function_parts: Vec = tool_results.into_iter().map(|tool_result| {
                        json!({
                            "functionResponse": {
                                "name": tool_result.call.name,
                                "response": {
                                    "name": tool_result.call.name,
                                    "content": tool_result.output,
                                }
                            }
                        })
                    }).collect();
                    vec![
                        json!({ "role": "model", "parts": model_parts }),
                        json!({ "role": "function", "parts": function_parts }),
                    ]
                }
            }
        })
        .collect();
    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }
    let mut body = json!({ "contents": contents, "generationConfig": {} });
    if let Some(v) = system_message {
        body["systemInstruction"] = json!({ "parts": [{"text": v }] });
    }
    if let Some(v) = model.max_tokens_param() {
        body["generationConfig"]["maxOutputTokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["generationConfig"]["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["generationConfig"]["topP"] = v.into();
    }
    if let Some(functions) = functions {
        // Gemini doesn't support functions whose parameters have empty properties, so
        // such functions are declared with only their name and description.
        let function_declarations: Vec<_> = functions
            .into_iter()
            .map(|function| {
                if function.parameters.is_empty_properties() {
                    json!({
                        "name": function.name,
                        "description": function.description,
                    })
                } else {
                    json!(function)
                }
            })
            .collect();
        body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
    }
    Ok(body)
}

/// Model family, which decides the publisher path, endpoint verb and body format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
    Gemini,
    Claude,
    Mistral,
}

impl FromStr for ModelCategory {
    type Err = anyhow::Error;

    /// Classifies a model name by prefix: `gemini*` → Gemini, `claude*` → Claude,
    /// `mistral*` or `codestral*` → Mistral. Any other name goes to the
    /// `unsupported_model!` macro (defined elsewhere; presumably yields an error — confirm).
    fn from_str(s: &str) -> std::result::Result {
        if s.starts_with("gemini") {
            Ok(ModelCategory::Gemini)
        } else if s.starts_with("claude") {
            Ok(ModelCategory::Claude)
        } else if s.starts_with("mistral") || s.starts_with("codestral") {
            Ok(ModelCategory::Mistral)
        } else {
            unsupported_model!(s)
        }
    }
}

/// Makes sure an access token is stored for `client_name`.
///
/// Does nothing when `is_valid_access_token(client_name)` is true. Otherwise fetches a
/// new token and stores it with an expiry of "now + `expires_in` seconds" as a Unix
/// timestamp.
///
/// * `client` - HTTP client used for the token request.
/// * `client_name` - key under which the token is checked and stored.
/// * `adc_file` - optional path to the ADC JSON file.
///
/// # Errors
/// Fails with "Failed to fetch access token" context when the fetch fails, or with
/// "Failed to parse expires_in of access_token" when the lifetime cannot be represented
/// as a `Duration`.
pub async fn prepare_gcloud_access_token(
    client: &reqwest::Client,
    client_name: &str,
    adc_file: &Option,
) -> Result<()> {
    if !is_valid_access_token(client_name) {
        let (token, expires_in) = fetch_access_token(client, adc_file)
            .await
            .with_context(|| "Failed to fetch access token")?;
        let expires_at = Utc::now()
            + Duration::try_seconds(expires_in)
                .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
        set_access_token(client_name, token, expires_at.timestamp())
    }
    Ok(())
}

/// Exchanges the ADC refresh token for an access token.
///
/// POSTs the payload built by `load_adc` as JSON to `https://oauth2.googleapis.com/token`
/// and returns `(access_token, expires_in)` from the reply. The HTTP status is not
/// inspected; the outcome is decided from the JSON body alone.
///
/// # Errors
/// Fails when the credentials cannot be loaded, on transport or JSON errors, with the
/// reply's `error_description` when present, or with "Invalid response data" otherwise.
async fn fetch_access_token(
    client: &reqwest::Client,
    file: &Option,
) -> Result<(String, i64)> {
    let credentials = load_adc(file).await?;
    let value: Value = client
        .post("https://oauth2.googleapis.com/token")
        .json(&credentials)
        .send()
        .await?
        .json()
        .await?;
    if let (Some(access_token), Some(expires_in)) =
        (value["access_token"].as_str(), value["expires_in"].as_i64())
    {
        Ok((access_token.to_string(), expires_in))
    } else if let Some(err_msg) = value["error_description"].as_str() {
        bail!("{err_msg}")
    } else {
        bail!("Invalid response data: {value}")
    }
}

/// Reads the ADC JSON file and builds the refresh-token grant payload.
///
/// Uses `file` when given, otherwise `default_adc_file()`. The file must contain string
/// fields `client_id`, `client_secret` and `refresh_token`; they are copied into the
/// payload together with `"grant_type": "refresh_token"`.
///
/// # Errors
/// Fails with "No application_default_credentials.json" when no path is available, on
/// read or JSON errors, or with "Invalid application_default_credentials.json" when a
/// required field is missing or not a string.
async fn load_adc(file: &Option) -> Result {
    let adc_file = file
        .as_ref()
        .map(PathBuf::from)
        .or_else(default_adc_file)
        .ok_or_else(|| anyhow!("No application_default_credentials.json"))?;
    let data = tokio::fs::read_to_string(adc_file).await?;
    let data: Value = serde_json::from_str(&data)?;
    if let (Some(client_id), Some(client_secret), Some(refresh_token)) = (
        data["client_id"].as_str(),
        data["client_secret"].as_str(),
        data["refresh_token"].as_str(),
    ) {
        Ok(json!({
            "client_id": client_id,
            "client_secret": client_secret,
            "refresh_token": refresh_token,
            "grant_type": "refresh_token",
        }))
    } else {
        bail!("Invalid application_default_credentials.json")
    }
}

/// Default ADC path on non-Windows targets:
/// `<home>/.config/gcloud/application_default_credentials.json`.
///
/// Returns `None` when `dirs::home_dir()` yields no directory.
#[cfg(not(windows))]
fn default_adc_file() -> Option {
    let mut path = dirs::home_dir()?;
    path.push(".config");
    path.push("gcloud");
    path.push("application_default_credentials.json");
    Some(path)
}

/// Default ADC path on Windows targets:
/// `<config dir>/gcloud/application_default_credentials.json`.
///
/// Returns `None` when `dirs::config_dir()` yields no directory.
#[cfg(windows)]
fn default_adc_file() -> Option {
    let mut path = dirs::config_dir()?;
    path.push("gcloud");
    path.push("application_default_credentials.json");
    Some(path)
}

/// Returns the part of `name` before the first `'@'`, or all of `name` when it
/// contains no `'@'`.
fn strip_model_version(name: &str) -> &str {
    match name.split_once('@') {
        Some((v, _)) => v,
        None => name,
    }
}