From b2dbdfb4b116d785e5389635b8690be731a17f74 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Thu, 12 Mar 2026 13:29:47 -0600 Subject: [PATCH] feat: Support for Gemini OAuth --- src/client/gemini.rs | 153 ++++++++++++++++++++++++++++--------- src/client/gemini_oauth.rs | 49 ++++++++++++ src/client/mod.rs | 1 + src/main.rs | 4 +- src/repl/mod.rs | 5 +- 5 files changed, 174 insertions(+), 38 deletions(-) create mode 100644 src/client/gemini_oauth.rs diff --git a/src/client/gemini.rs b/src/client/gemini.rs index c0bafaa..2be9c66 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -1,8 +1,11 @@ +use super::access_token::get_access_token; +use super::gemini_oauth::GeminiOAuthProvider; +use super::oauth; use super::vertexai::*; use super::*; -use anyhow::{Context, Result}; -use reqwest::RequestBuilder; +use anyhow::{Context, Result, bail}; +use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{Value, json}; @@ -13,6 +16,7 @@ pub struct GeminiConfig { pub name: Option, pub api_key: Option, pub api_base: Option, + pub auth: Option, #[serde(default)] pub models: Vec, pub patch: Option, @@ -23,25 +27,64 @@ impl GeminiClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - create_client_config!([("api_key", "API Key", None, true)]); + create_oauth_supported_client_config!(); } -impl_client_trait!( - GeminiClient, - ( - prepare_chat_completions, - gemini_chat_completions, - gemini_chat_completions_streaming - ), - (prepare_embeddings, embeddings), - (noop_prepare_rerank, noop_rerank), -); +#[async_trait::async_trait] +impl Client for GeminiClient { + client_common_fns!(); -fn prepare_chat_completions( + fn supports_oauth(&self) -> bool { + self.config.auth.as_deref() == Some("oauth") + } + + async fn chat_completions_inner( + &self, + client: &ReqwestClient, + data: ChatCompletionsData, + ) -> Result { + let request_data = prepare_chat_completions(self, client, data).await?; + let builder = self.request_builder(client, request_data); + gemini_chat_completions(builder, self.model()).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut SseHandler, + data: ChatCompletionsData, + ) -> Result<()> { + let request_data = prepare_chat_completions(self, client, data).await?; + let builder = self.request_builder(client, request_data); + gemini_chat_completions_streaming(builder, handler, self.model()).await + } + + async fn embeddings_inner( + &self, + client: &ReqwestClient, + data: &EmbeddingsData, + ) -> Result { + let request_data = prepare_embeddings(self, client, data).await?; + let builder = self.request_builder(client, request_data); + embeddings(builder, self.model()).await + } + + async fn rerank_inner( + &self, + client: &ReqwestClient, + data: &RerankData, + ) -> Result { + let request_data = noop_prepare_rerank(self, data)?; + let builder = self.request_builder(client, request_data); + noop_rerank(builder, self.model()).await + } +} + +async fn prepare_chat_completions( self_: &GeminiClient, + client: &ReqwestClient, data: ChatCompletionsData, ) -> Result { - let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() .unwrap_or_else(|_| API_BASE.to_string()); @@ -59,26 +102,61 @@ fn prepare_chat_completions( ); let body = gemini_build_chat_completions_body(data, &self_.model)?; - let mut request_data = RequestData::new(url, body); - request_data.header("x-goog-api-key", api_key); + let uses_oauth = self_.config.auth.as_deref() == Some("oauth"); + + if uses_oauth { + let provider = GeminiOAuthProvider; + let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?; + if !ready { + bail!( + "OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}", + self_.name(), + self_.name() + ); + } + let token = get_access_token(self_.name())?; + request_data.bearer_auth(token); + } else if let Ok(api_key) = self_.get_api_key() { + request_data.header("x-goog-api-key", api_key); + } else { + bail!( + "No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.", + self_.name(), + self_.name() + ); + } Ok(request_data) } -fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result { - let api_key = self_.get_api_key()?; +async fn prepare_embeddings( + self_: &GeminiClient, + client: &ReqwestClient, + data: &EmbeddingsData, +) -> Result { let api_base = self_ .get_api_base() .unwrap_or_else(|_| API_BASE.to_string()); - let url = format!( - "{}/models/{}:batchEmbedContents?key={}", - api_base.trim_end_matches('/'), - self_.model.real_name(), - api_key - ); + let uses_oauth = self_.config.auth.as_deref() == Some("oauth"); + + let url = if uses_oauth { + format!( + "{}/models/{}:batchEmbedContents", + api_base.trim_end_matches('/'), + self_.model.real_name(), + ) + } else { + let api_key = self_.get_api_key()?; + format!( + "{}/models/{}:batchEmbedContents?key={}", + api_base.trim_end_matches('/'), + self_.model.real_name(), + api_key + ) + }; let model_id = format!("models/{}", self_.model.real_name()); @@ -89,21 +167,28 @@ fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result &str { + "gemini" + } + + fn client_id(&self) -> &str { + GEMINI_CLIENT_ID + } + + fn authorize_url(&self) -> &str { + "https://accounts.google.com/o/oauth2/v2/auth" + } + + fn token_url(&self) -> &str { + "https://oauth2.googleapis.com/token" + } + + fn redirect_uri(&self) -> &str { + "" + } + + fn scopes(&self) -> &str { + "https://www.googleapis.com/auth/generative-language.peruserquota https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/userinfo.email" + } + + fn client_secret(&self) -> Option<&str> { + Some(GEMINI_CLIENT_SECRET) + } + + fn extra_authorize_params(&self) -> Vec<(&str, &str)> { + vec![("access_type", "offline"), ("prompt", "consent")] + } + + fn token_request_format(&self) -> TokenRequestFormat { + TokenRequestFormat::FormUrlEncoded + } + + fn uses_localhost_redirect(&self) -> bool { + true + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index c1b4aa3..0cd9365 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,6 +1,7 @@ mod access_token; mod claude_oauth; mod common; +mod gemini_oauth; mod message; pub mod oauth; #[macro_use] diff --git a/src/main.rs b/src/main.rs index f788a90..d465a58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -86,7 +86,7 @@ async fn main() -> Result<()> { if let Some(client_arg) = &cli.authenticate { let config = Config::init_bare()?; let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?; - oauth::run_oauth_flow(&provider, &client_name).await?; + oauth::run_oauth_flow(&*provider, &client_name).await?; return Ok(()); } @@ -517,7 +517,7 @@ fn init_console_logger( fn resolve_oauth_client( explicit: Option<&str>, clients: &[ClientConfig], -) -> Result<(String, impl OAuthProvider)> { +) -> Result<(String, Box)> { if let Some(name) = explicit { let provider_type = oauth::resolve_provider_type(name, clients) .ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?; diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 452a0fd..8c0652f 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -428,7 +428,8 @@ pub async fn run_repl_command( None => println!("Usage: .model "), }, ".authenticate" => { - let client = init_client(config, None)?; + let current_model = config.read().current_model().clone(); + let client = init_client(config, Some(current_model))?; if !client.supports_oauth() { bail!( "Client '{}' doesn't either support OAuth or isn't configured to use it (i.e. uses an API key instead)", @@ -437,7 +438,7 @@ pub async fn run_repl_command( } let clients = config.read().clients.clone(); let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?; - oauth::run_oauth_flow(&provider, &client_name).await?; + oauth::run_oauth_flow(&*provider, &client_name).await?; } ".prompt" => match args { Some(text) => {