diff --git a/Cargo.lock b/Cargo.lock index 09e752d..e87257b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3243,6 +3243,7 @@ dependencies = [ "tokio-stream", "unicode-segmentation", "unicode-width 0.2.2", + "url", "urlencoding", "uuid", "which", diff --git a/Cargo.toml b/Cargo.toml index 43bf852..3bc9980 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,6 +98,7 @@ gman = "0.3.0" clap_complete_nushell = "4.5.9" open = "5" rand = "0.9.0" +url = "2.5.8" [dependencies.reqwest] version = "0.12.0" diff --git a/src/client/claude_oauth.rs b/src/client/claude_oauth.rs index fed4883..e552ba8 100644 --- a/src/client/claude_oauth.rs +++ b/src/client/claude_oauth.rs @@ -29,6 +29,10 @@ impl OAuthProvider for ClaudeOAuthProvider { "org:create_api_key user:profile user:inference" } + fn extra_authorize_params(&self) -> Vec<(&str, &str)> { + vec![("code", "true")] + } + fn extra_token_headers(&self) -> Vec<(&str, &str)> { vec![("anthropic-beta", BETA_HEADER)] } diff --git a/src/client/oauth.rs b/src/client/oauth.rs index 8ee117c..6390dc3 100644 --- a/src/client/oauth.rs +++ b/src/client/oauth.rs @@ -1,18 +1,27 @@ use super::ClientConfig; use super::access_token::{is_valid_access_token, set_access_token}; use crate::config::Config; -use anyhow::{Result, bail}; +use anyhow::{Result, anyhow, bail}; use base64::Engine; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use chrono::Utc; use inquire::Text; -use reqwest::Client as ReqwestClient; +use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; +use serde_json::Value; use sha2::{Digest, Sha256}; +use std::collections::HashMap; use std::fs; +use std::io::{BufRead, BufReader, Write}; +use std::net::TcpListener; +use url::Url; use uuid::Uuid; +pub enum TokenRequestFormat { + Json, + FormUrlEncoded, +} + pub trait OAuthProvider: Send + Sync { fn provider_name(&self) -> &str; fn client_id(&self) -> &str; @@ -21,6 +30,22 @@ pub trait OAuthProvider: Send + Sync { fn redirect_uri(&self) -> &str; fn scopes(&self) -> &str; + fn client_secret(&self) -> Option<&str> { + None + } + + fn extra_authorize_params(&self) -> Vec<(&str, &str)> { + vec![] + } + + fn token_request_format(&self) -> TokenRequestFormat { + TokenRequestFormat::Json + } + + fn uses_localhost_redirect(&self) -> bool { + false + } + fn extra_token_headers(&self) -> Vec<(&str, &str)> { vec![] } @@ -37,7 +62,7 @@ pub struct OAuthTokens { pub expires_at: i64, } -pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> Result<()> { +pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> Result<()> { let random_bytes: [u8; 32] = rand::random::<[u8; 32]>(); let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes); @@ -47,11 +72,21 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> let state = Uuid::new_v4().to_string(); - let encoded_scopes = urlencoding::encode(provider.scopes()); - let encoded_redirect = urlencoding::encode(provider.redirect_uri()); + let redirect_uri = if provider.uses_localhost_redirect() { + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + let uri = format!("http://127.0.0.1:{port}/callback"); + drop(listener); + uri + } else { + provider.redirect_uri().to_string() + }; - let authorize_url = format!( - "{}?code=true&client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}", + let encoded_scopes = urlencoding::encode(provider.scopes()); + let encoded_redirect = urlencoding::encode(&redirect_uri); + + let mut authorize_url = format!( + "{}?client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}", provider.authorize_url(), provider.client_id(), encoded_scopes, @@ -60,6 +95,14 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> state ); + for (key, value) in provider.extra_authorize_params() { + authorize_url.push_str(&format!( + "&{}={}", + urlencoding::encode(key), + urlencoding::encode(value) + )); + } + println!( "\nOpen this URL to authenticate with {} (client '{}'):\n", provider.provider_name(), @@ -69,14 +112,16 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> let _ = open::that(&authorize_url); - let input = Text::new("Paste the authorization code:").prompt()?; - - let parts: Vec<&str> = input.splitn(2, '#').collect(); - if parts.len() != 2 { - bail!("Invalid authorization code format. Expected format: #"); - } - let code = parts[0]; - let returned_state = parts[1]; + let (code, returned_state) = if provider.uses_localhost_redirect() { + listen_for_oauth_callback(&redirect_uri)? + } else { + let input = Text::new("Paste the authorization code:").prompt()?; + let parts: Vec<&str> = input.splitn(2, '#').collect(); + if parts.len() != 2 { + bail!("Invalid authorization code format. Expected format: #"); + } + (parts[0].to_string(), parts[1].to_string()) + }; if returned_state != state { bail!( @@ -86,32 +131,32 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> } let client = ReqwestClient::new(); - let mut request = client.post(provider.token_url()).json(&json!({ - "grant_type": "authorization_code", - "client_id": provider.client_id(), - "code": code, - "code_verifier": code_verifier, - "redirect_uri": provider.redirect_uri(), - "state": state, - })); - - for (key, value) in provider.extra_token_headers() { - request = request.header(key, value); - } + let request = build_token_request( + &client, + provider, + &[ + ("grant_type", "authorization_code"), + ("client_id", provider.client_id()), + ("code", &code), + ("code_verifier", &code_verifier), + ("redirect_uri", &redirect_uri), + ("state", &state), + ], + ); let response: Value = request.send().await?.json().await?; let access_token = response["access_token"] .as_str() - .ok_or_else(|| anyhow::anyhow!("Missing access_token in response: {response}"))? + .ok_or_else(|| anyhow!("Missing access_token in response: {response}"))? .to_string(); let refresh_token = response["refresh_token"] .as_str() - .ok_or_else(|| anyhow::anyhow!("Missing refresh_token in response: {response}"))? + .ok_or_else(|| anyhow!("Missing refresh_token in response: {response}"))? .to_string(); let expires_in = response["expires_in"] .as_i64() - .ok_or_else(|| anyhow::anyhow!("Missing expires_in in response: {response}"))?; + .ok_or_else(|| anyhow!("Missing expires_in in response: {response}"))?; let expires_at = Utc::now().timestamp() + expires_in; @@ -150,33 +195,33 @@ fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> { pub async fn refresh_oauth_token( client: &ReqwestClient, - provider: &dyn OAuthProvider, + provider: &impl OAuthProvider, client_name: &str, tokens: &OAuthTokens, ) -> Result { - let mut request = client.post(provider.token_url()).json(&json!({ - "grant_type": "refresh_token", - "client_id": provider.client_id(), - "refresh_token": tokens.refresh_token, - })); - - for (key, value) in provider.extra_token_headers() { - request = request.header(key, value); - } + let request = build_token_request( + client, + provider, + &[ + ("grant_type", "refresh_token"), + ("client_id", provider.client_id()), + ("refresh_token", &tokens.refresh_token), + ], + ); let response: Value = request.send().await?.json().await?; let access_token = response["access_token"] .as_str() - .ok_or_else(|| anyhow::anyhow!("Missing access_token in refresh response: {response}"))? + .ok_or_else(|| anyhow!("Missing access_token in refresh response: {response}"))? .to_string(); let refresh_token = response["refresh_token"] .as_str() - .ok_or_else(|| anyhow::anyhow!("Missing refresh_token in refresh response: {response}"))? - .to_string(); + .map(|s| s.to_string()) + .unwrap_or_else(|| tokens.refresh_token.clone()); let expires_in = response["expires_in"] .as_i64() - .ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?; + .ok_or_else(|| anyhow!("Missing expires_in in refresh response: {response}"))?; let expires_at = Utc::now().timestamp() + expires_in; @@ -216,9 +261,110 @@ pub async fn prepare_oauth_access_token( Ok(true) } -pub fn get_oauth_provider(provider_type: &str) -> Option { +fn build_token_request( + client: &ReqwestClient, + provider: &(impl OAuthProvider + ?Sized), + params: &[(&str, &str)], +) -> RequestBuilder { + let mut request = match provider.token_request_format() { + TokenRequestFormat::Json => { + let body: serde_json::Map = params + .iter() + .map(|(k, v)| (k.to_string(), Value::String(v.to_string()))) + .collect(); + if let Some(secret) = provider.client_secret() { + let mut body = body; + body.insert( + "client_secret".to_string(), + Value::String(secret.to_string()), + ); + client.post(provider.token_url()).json(&body) + } else { + client.post(provider.token_url()).json(&body) + } + } + TokenRequestFormat::FormUrlEncoded => { + let mut form: HashMap = params + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + if let Some(secret) = provider.client_secret() { + form.insert("client_secret".to_string(), secret.to_string()); + } + client.post(provider.token_url()).form(&form) + } + }; + + for (key, value) in provider.extra_token_headers() { + request = request.header(key, value); + } + + request +} + +fn listen_for_oauth_callback(redirect_uri: &str) -> Result<(String, String)> { + let url: Url = redirect_uri.parse()?; + let host = url.host_str().unwrap_or("127.0.0.1"); + let port = url + .port() + .ok_or_else(|| anyhow!("No port in redirect URI"))?; + let path = url.path(); + + println!("Waiting for OAuth callback on {redirect_uri} ...\n"); + + let listener = TcpListener::bind(format!("{host}:{port}"))?; + let (mut stream, _) = listener.accept()?; + + let mut reader = BufReader::new(&stream); + let mut request_line = String::new(); + reader.read_line(&mut request_line)?; + + let request_path = request_line + .split_whitespace() + .nth(1) + .ok_or_else(|| anyhow!("Malformed HTTP request from OAuth callback"))?; + + let full_url = format!("http://{host}:{port}{request_path}"); + let parsed: Url = full_url.parse()?; + + if !parsed.path().starts_with(path) { + bail!("Unexpected callback path: {}", parsed.path()); + } + + let code = parsed + .query_pairs() + .find(|(k, _)| k == "code") + .map(|(_, v)| v.to_string()) + .ok_or_else(|| { + let error = parsed + .query_pairs() + .find(|(k, _)| k == "error") + .map(|(_, v)| v.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + anyhow!("OAuth callback returned error: {error}") + })?; + + let returned_state = parsed + .query_pairs() + .find(|(k, _)| k == "state") + .map(|(_, v)| v.to_string()) + .ok_or_else(|| anyhow!("Missing state parameter in OAuth callback"))?; + + let response_body = "

Authentication successful!

You can close this tab and return to your terminal.

"; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + response_body.len(), + response_body + ); + stream.write_all(response.as_bytes())?; + + Ok((code, returned_state)) +} + +pub fn get_oauth_provider(provider_type: &str) -> Option> { match provider_type { - "claude" => Some(super::claude_oauth::ClaudeOAuthProvider), + "claude" => Some(Box::new(super::claude_oauth::ClaudeOAuthProvider)), + "gemini" => Some(Box::new(super::gemini_oauth::GeminiOAuthProvider)), _ => None, } } @@ -263,7 +409,11 @@ fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Opti "openai-compatible", None, ), - ClientConfig::GeminiConfig(c) => (c.name.as_deref().unwrap_or("gemini"), "gemini", None), + ClientConfig::GeminiConfig(c) => ( + c.name.as_deref().unwrap_or("gemini"), + "gemini", + c.auth.as_deref(), + ), ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None), ClientConfig::AzureOpenAIConfig(c) => ( c.name.as_deref().unwrap_or("azure-openai"),