From 83581d9d187ef93abe987bf6dfa5d9f9ed63b872 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Wed, 11 Mar 2026 14:18:55 -0600 Subject: [PATCH] feat: Initial scaffolding work for Gemini OAuth support (Claude generated, Alex updated) --- Cargo.lock | 1 + Cargo.toml | 1 + README.md | 2 +- config.example.yaml | 2 + docs/clients/CLIENTS.md | 7 +- src/client/claude_oauth.rs | 4 + src/client/gemini.rs | 153 ++++++++++++++++++------ src/client/gemini_oauth.rs | 50 ++++++++ src/client/mod.rs | 1 + src/client/oauth.rs | 233 ++++++++++++++++++++++++++++++------- src/main.rs | 4 +- src/repl/mod.rs | 2 +- 12 files changed, 379 insertions(+), 81 deletions(-) create mode 100644 src/client/gemini_oauth.rs diff --git a/Cargo.lock b/Cargo.lock index 09e752d..e87257b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3243,6 +3243,7 @@ dependencies = [ "tokio-stream", "unicode-segmentation", "unicode-width 0.2.2", + "url", "urlencoding", "uuid", "which", diff --git a/Cargo.toml b/Cargo.toml index 43bf852..3bc9980 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,6 +98,7 @@ gman = "0.3.0" clap_complete_nushell = "4.5.9" open = "5" rand = "0.9.0" +url = "2.5.8" [dependencies.reqwest] version = "0.12.0" diff --git a/README.md b/README.md index 026a718..df5e59e 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ loki --list-secrets ### Authentication Each client in your configuration needs authentication (with a few exceptions; e.g. ollama). Most clients use an API key (set via `api_key` in the config or through the [vault](./docs/VAULT.md)). For providers that support OAuth (e.g. Claude Pro/Max -subscribers), you can authenticate with your existing subscription instead: +subscribers, Google Gemini), you can authenticate with your existing subscription instead: ```yaml # In your config.yaml diff --git a/config.example.yaml b/config.example.yaml index 4b843d2..b6acb0a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -192,6 +192,8 @@ clients: - type: gemini api_base: https://generativelanguage.googleapis.com/v1beta api_key: '{{GEMINI_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault + auth: null # When set to 'oauth', Loki will use OAuth instead of an API key + # Authenticate with `loki --authenticate` or `.authenticate` in the REPL patch: chat_completions: '.*': diff --git a/docs/clients/CLIENTS.md b/docs/clients/CLIENTS.md index 576a6df..330e6d5 100644 --- a/docs/clients/CLIENTS.md +++ b/docs/clients/CLIENTS.md @@ -137,8 +137,10 @@ loki --authenticate Alternatively, you can use the REPL command `.authenticate`. -This opens your browser for the OAuth authorization flow. After authorizing, paste the authorization code back into -the terminal. Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes them when they expire. +This opens your browser for the OAuth authorization flow. Depending on the provider, Loki will either start a +temporary localhost server to capture the callback automatically (e.g. Gemini) or ask you to paste the authorization +code back into the terminal (e.g. Claude). Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes +them when they expire. **Step 3: Use normally** @@ -153,6 +155,7 @@ loki -m my-claude-oauth:claude-sonnet-4-20250514 "Hello!" ### Providers That Support OAuth * Claude +* Gemini ## Extra Settings Loki also lets you customize some extra settings for interacting with APIs: diff --git a/src/client/claude_oauth.rs b/src/client/claude_oauth.rs index fed4883..e552ba8 100644 --- a/src/client/claude_oauth.rs +++ b/src/client/claude_oauth.rs @@ -29,6 +29,10 @@ impl OAuthProvider for ClaudeOAuthProvider { "org:create_api_key user:profile user:inference" } + fn extra_authorize_params(&self) -> Vec<(&str, &str)> { + vec![("code", "true")] + } + fn extra_token_headers(&self) -> Vec<(&str, &str)> { vec![("anthropic-beta", BETA_HEADER)] } diff --git a/src/client/gemini.rs b/src/client/gemini.rs index c0bafaa..2be9c66 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -1,8 +1,11 @@ +use super::access_token::get_access_token; +use super::gemini_oauth::GeminiOAuthProvider; +use super::oauth; use super::vertexai::*; use super::*; -use anyhow::{Context, Result}; -use reqwest::RequestBuilder; +use anyhow::{Context, Result, bail}; +use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{Value, json}; @@ -13,6 +16,7 @@ pub struct GeminiConfig { pub name: Option, pub api_key: Option, pub api_base: Option, + pub auth: Option, #[serde(default)] pub models: Vec, pub patch: Option, @@ -23,25 +27,64 @@ impl GeminiClient { config_get_fn!(api_key, get_api_key); config_get_fn!(api_base, get_api_base); - create_client_config!([("api_key", "API Key", None, true)]); + create_oauth_supported_client_config!(); } -impl_client_trait!( - GeminiClient, - ( - prepare_chat_completions, - gemini_chat_completions, - gemini_chat_completions_streaming - ), - (prepare_embeddings, embeddings), - (noop_prepare_rerank, noop_rerank), -); +#[async_trait::async_trait] +impl Client for GeminiClient { + client_common_fns!(); -fn prepare_chat_completions( + fn supports_oauth(&self) -> bool { + self.config.auth.as_deref() == Some("oauth") + } + + async fn chat_completions_inner( + &self, + client: &ReqwestClient, + data: ChatCompletionsData, + ) -> Result { + let request_data = prepare_chat_completions(self, client, data).await?; + let builder = self.request_builder(client, request_data); + gemini_chat_completions(builder, self.model()).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut SseHandler, + data: ChatCompletionsData, + ) -> Result<()> { + let request_data = prepare_chat_completions(self, client, data).await?; + let builder = self.request_builder(client, request_data); + gemini_chat_completions_streaming(builder, handler, self.model()).await + } + + async fn embeddings_inner( + &self, + client: &ReqwestClient, + data: &EmbeddingsData, + ) -> Result { + let request_data = prepare_embeddings(self, client, data).await?; + let builder = self.request_builder(client, request_data); + embeddings(builder, self.model()).await + } + + async fn rerank_inner( + &self, + client: &ReqwestClient, + data: &RerankData, + ) -> Result { + let request_data = noop_prepare_rerank(self, data)?; + let builder = self.request_builder(client, request_data); + noop_rerank(builder, self.model()).await + } +} + +async fn prepare_chat_completions( self_: &GeminiClient, + client: &ReqwestClient, data: ChatCompletionsData, ) -> Result { - let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() .unwrap_or_else(|_| API_BASE.to_string()); @@ -59,26 +102,61 @@ fn prepare_chat_completions( ); let body = gemini_build_chat_completions_body(data, &self_.model)?; - let mut request_data = RequestData::new(url, body); - request_data.header("x-goog-api-key", api_key); + let uses_oauth = self_.config.auth.as_deref() == Some("oauth"); + + if uses_oauth { + let provider = GeminiOAuthProvider; + let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?; + if !ready { + bail!( + "OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}", + self_.name(), + self_.name() + ); + } + let token = get_access_token(self_.name())?; + request_data.bearer_auth(token); + } else if let Ok(api_key) = self_.get_api_key() { + request_data.header("x-goog-api-key", api_key); + } else { + bail!( + "No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.", + self_.name(), + self_.name() + ); + } Ok(request_data) } -fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result { - let api_key = self_.get_api_key()?; +async fn prepare_embeddings( + self_: &GeminiClient, + client: &ReqwestClient, + data: &EmbeddingsData, +) -> Result { let api_base = self_ .get_api_base() .unwrap_or_else(|_| API_BASE.to_string()); - let url = format!( - "{}/models/{}:batchEmbedContents?key={}", - api_base.trim_end_matches('/'), - self_.model.real_name(), - api_key - ); + let uses_oauth = self_.config.auth.as_deref() == Some("oauth"); + + let url = if uses_oauth { + format!( + "{}/models/{}:batchEmbedContents", + api_base.trim_end_matches('/'), + self_.model.real_name(), + ) + } else { + let api_key = self_.get_api_key()?; + format!( + "{}/models/{}:batchEmbedContents?key={}", + api_base.trim_end_matches('/'), + self_.model.real_name(), + api_key + ) + }; let model_id = format!("models/{}", self_.model.real_name()); @@ -89,21 +167,28 @@ fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result &str { + "gemini" + } + + fn client_id(&self) -> &str { + GEMINI_CLIENT_ID + } + + fn authorize_url(&self) -> &str { + "https://accounts.google.com/o/oauth2/v2/auth" + } + + fn token_url(&self) -> &str { + "https://oauth2.googleapis.com/token" + } + + fn redirect_uri(&self) -> &str { + "" + } + + fn scopes(&self) -> &str { + "https://www.googleapis.com/auth/cloud-platform.readonly https://www.googleapis.com/auth/userinfo.email" + } + + fn client_secret(&self) -> Option<&str> { + Some(GEMINI_CLIENT_SECRET) + } + + fn extra_authorize_params(&self) -> Vec<(&str, &str)> { + vec![("access_type", "offline"), ("prompt", "consent")] + } + + fn token_request_format(&self) -> TokenRequestFormat { + TokenRequestFormat::FormUrlEncoded + } + + fn uses_localhost_redirect(&self) -> bool { + true + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index c1b4aa3..0cd9365 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,6 +1,7 @@ mod access_token; mod claude_oauth; mod common; +mod gemini_oauth; mod message; pub mod oauth; #[macro_use] diff --git a/src/client/oauth.rs b/src/client/oauth.rs index 8ee117c..f4c97ba 100644 --- a/src/client/oauth.rs +++ b/src/client/oauth.rs @@ -8,11 +8,20 @@ use chrono::Utc; use inquire::Text; use reqwest::Client as ReqwestClient; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; +use serde_json::Value; use sha2::{Digest, Sha256}; +use std::collections::HashMap; use std::fs; +use std::io::{BufRead, BufReader, Write}; +use std::net::TcpListener; +use url::Url; use uuid::Uuid; +pub enum TokenRequestFormat { + Json, + FormUrlEncoded, +} + pub trait OAuthProvider: Send + Sync { fn provider_name(&self) -> &str; fn client_id(&self) -> &str; @@ -21,6 +30,22 @@ pub trait OAuthProvider: Send + Sync { fn redirect_uri(&self) -> &str; fn scopes(&self) -> &str; + fn client_secret(&self) -> Option<&str> { + None + } + + fn extra_authorize_params(&self) -> Vec<(&str, &str)> { + vec![] + } + + fn token_request_format(&self) -> TokenRequestFormat { + TokenRequestFormat::Json + } + + fn uses_localhost_redirect(&self) -> bool { + false + } + fn extra_token_headers(&self) -> Vec<(&str, &str)> { vec![] } @@ -37,7 +62,7 @@ pub struct OAuthTokens { pub expires_at: i64, } -pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> Result<()> { +pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> Result<()> { let random_bytes: [u8; 32] = rand::random::<[u8; 32]>(); let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes); @@ -47,11 +72,22 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> let state = Uuid::new_v4().to_string(); - let encoded_scopes = urlencoding::encode(provider.scopes()); - let encoded_redirect = urlencoding::encode(provider.redirect_uri()); + let redirect_uri = if provider.uses_localhost_redirect() { + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + let uri = format!("http://127.0.0.1:{port}/callback"); + // Drop the listener so run_oauth_flow can re-bind below + drop(listener); + uri + } else { + provider.redirect_uri().to_string() + }; - let authorize_url = format!( - "{}?code=true&client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}", + let encoded_scopes = urlencoding::encode(provider.scopes()); + let encoded_redirect = urlencoding::encode(&redirect_uri); + + let mut authorize_url = format!( + "{}?client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}", provider.authorize_url(), provider.client_id(), encoded_scopes, @@ -60,6 +96,14 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> state ); + for (key, value) in provider.extra_authorize_params() { + authorize_url.push_str(&format!( + "&{}={}", + urlencoding::encode(key), + urlencoding::encode(value) + )); + } + println!( "\nOpen this URL to authenticate with {} (client '{}'):\n", provider.provider_name(), @@ -69,14 +113,16 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> let _ = open::that(&authorize_url); - let input = Text::new("Paste the authorization code:").prompt()?; - - let parts: Vec<&str> = input.splitn(2, '#').collect(); - if parts.len() != 2 { - bail!("Invalid authorization code format. Expected format: #"); - } - let code = parts[0]; - let returned_state = parts[1]; + let (code, returned_state) = if provider.uses_localhost_redirect() { + listen_for_oauth_callback(&redirect_uri)? + } else { + let input = Text::new("Paste the authorization code:").prompt()?; + let parts: Vec<&str> = input.splitn(2, '#').collect(); + if parts.len() != 2 { + bail!("Invalid authorization code format. Expected format: #"); + } + (parts[0].to_string(), parts[1].to_string()) + }; if returned_state != state { bail!( @@ -86,18 +132,18 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> } let client = ReqwestClient::new(); - let mut request = client.post(provider.token_url()).json(&json!({ - "grant_type": "authorization_code", - "client_id": provider.client_id(), - "code": code, - "code_verifier": code_verifier, - "redirect_uri": provider.redirect_uri(), - "state": state, - })); - - for (key, value) in provider.extra_token_headers() { - request = request.header(key, value); - } + let request = build_token_request( + &client, + provider, + &[ + ("grant_type", "authorization_code"), + ("client_id", provider.client_id()), + ("code", &code), + ("code_verifier", &code_verifier), + ("redirect_uri", &redirect_uri), + ("state", &state), + ], + ); let response: Value = request.send().await?.json().await?; @@ -150,19 +196,19 @@ fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> { pub async fn refresh_oauth_token( client: &ReqwestClient, - provider: &dyn OAuthProvider, + provider: &impl OAuthProvider, client_name: &str, tokens: &OAuthTokens, ) -> Result { - let mut request = client.post(provider.token_url()).json(&json!({ - "grant_type": "refresh_token", - "client_id": provider.client_id(), - "refresh_token": tokens.refresh_token, - })); - - for (key, value) in provider.extra_token_headers() { - request = request.header(key, value); - } + let request = build_token_request( + client, + provider, + &[ + ("grant_type", "refresh_token"), + ("client_id", provider.client_id()), + ("refresh_token", &tokens.refresh_token), + ], + ); let response: Value = request.send().await?.json().await?; @@ -172,8 +218,8 @@ pub async fn refresh_oauth_token( .to_string(); let refresh_token = response["refresh_token"] .as_str() - .ok_or_else(|| anyhow::anyhow!("Missing refresh_token in refresh response: {response}"))? - .to_string(); + .map(|s| s.to_string()) + .unwrap_or_else(|| tokens.refresh_token.clone()); let expires_in = response["expires_in"] .as_i64() .ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?; @@ -216,9 +262,110 @@ pub async fn prepare_oauth_access_token( Ok(true) } -pub fn get_oauth_provider(provider_type: &str) -> Option { +fn build_token_request( + client: &ReqwestClient, + provider: &(impl OAuthProvider + ?Sized), + params: &[(&str, &str)], +) -> reqwest::RequestBuilder { + let mut request = match provider.token_request_format() { + TokenRequestFormat::Json => { + let body: serde_json::Map = params + .iter() + .map(|(k, v)| (k.to_string(), Value::String(v.to_string()))) + .collect(); + if let Some(secret) = provider.client_secret() { + let mut body = body; + body.insert( + "client_secret".to_string(), + Value::String(secret.to_string()), + ); + client.post(provider.token_url()).json(&body) + } else { + client.post(provider.token_url()).json(&body) + } + } + TokenRequestFormat::FormUrlEncoded => { + let mut form: HashMap = params + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + if let Some(secret) = provider.client_secret() { + form.insert("client_secret".to_string(), secret.to_string()); + } + client.post(provider.token_url()).form(&form) + } + }; + + for (key, value) in provider.extra_token_headers() { + request = request.header(key, value); + } + + request +} + +fn listen_for_oauth_callback(redirect_uri: &str) -> Result<(String, String)> { + let url: Url = redirect_uri.parse()?; + let host = url.host_str().unwrap_or("127.0.0.1"); + let port = url + .port() + .ok_or_else(|| anyhow::anyhow!("No port in redirect URI"))?; + let path = url.path(); + + println!("Waiting for OAuth callback on {redirect_uri} ...\n"); + + let listener = TcpListener::bind(format!("{host}:{port}"))?; + let (mut stream, _) = listener.accept()?; + + let mut reader = BufReader::new(&stream); + let mut request_line = String::new(); + reader.read_line(&mut request_line)?; + + let request_path = request_line + .split_whitespace() + .nth(1) + .ok_or_else(|| anyhow::anyhow!("Malformed HTTP request from OAuth callback"))?; + + let full_url = format!("http://{host}:{port}{request_path}"); + let parsed: Url = full_url.parse()?; + + let response_body = "

Authentication successful!

You can close this tab and return to your terminal.

"; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + response_body.len(), + response_body + ); + stream.write_all(response.as_bytes())?; + + if !parsed.path().starts_with(path) { + bail!("Unexpected callback path: {}", parsed.path()); + } + + let code = parsed + .query_pairs() + .find(|(k, _)| k == "code") + .map(|(_, v)| v.to_string()) + .ok_or_else(|| { + let error = parsed + .query_pairs() + .find(|(k, _)| k == "error") + .map(|(_, v)| v.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + anyhow::anyhow!("OAuth callback returned error: {error}") + })?; + + let returned_state = parsed + .query_pairs() + .find(|(k, _)| k == "state") + .map(|(_, v)| v.to_string()) + .ok_or_else(|| anyhow::anyhow!("Missing state parameter in OAuth callback"))?; + + Ok((code, returned_state)) +} + +pub fn get_oauth_provider(provider_type: &str) -> Option> { match provider_type { - "claude" => Some(super::claude_oauth::ClaudeOAuthProvider), + "claude" => Some(Box::new(super::claude_oauth::ClaudeOAuthProvider)), + "gemini" => Some(Box::new(super::gemini_oauth::GeminiOAuthProvider)), _ => None, } } @@ -263,7 +410,11 @@ fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Opti "openai-compatible", None, ), - ClientConfig::GeminiConfig(c) => (c.name.as_deref().unwrap_or("gemini"), "gemini", None), + ClientConfig::GeminiConfig(c) => ( + c.name.as_deref().unwrap_or("gemini"), + "gemini", + c.auth.as_deref(), + ), ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None), ClientConfig::AzureOpenAIConfig(c) => ( c.name.as_deref().unwrap_or("azure-openai"), diff --git a/src/main.rs b/src/main.rs index f788a90..d465a58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -86,7 +86,7 @@ async fn main() -> Result<()> { if let Some(client_arg) = &cli.authenticate { let config = Config::init_bare()?; let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?; - oauth::run_oauth_flow(&provider, &client_name).await?; + oauth::run_oauth_flow(&*provider, &client_name).await?; return Ok(()); } @@ -517,7 +517,7 @@ fn init_console_logger( fn resolve_oauth_client( explicit: Option<&str>, clients: &[ClientConfig], -) -> Result<(String, impl OAuthProvider)> { +) -> Result<(String, Box)> { if let Some(name) = explicit { let provider_type = oauth::resolve_provider_type(name, clients) .ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?; diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 452a0fd..5edc0e2 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -437,7 +437,7 @@ pub async fn run_repl_command( } let clients = config.read().clients.clone(); let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?; - oauth::run_oauth_flow(&provider, &client_name).await?; + oauth::run_oauth_flow(&*provider, &client_name).await?; } ".prompt" => match args { Some(text) => {