refactor: Made the oauth module more generic so it can support loopback OAuth (not just manual)

This commit is contained in:
2026-03-12 13:28:09 -06:00
parent 73cbe16ec1
commit 063e198f96
4 changed files with 204 additions and 48 deletions
Generated
+1
View File
@@ -3243,6 +3243,7 @@ dependencies = [
"tokio-stream", "tokio-stream",
"unicode-segmentation", "unicode-segmentation",
"unicode-width 0.2.2", "unicode-width 0.2.2",
"url",
"urlencoding", "urlencoding",
"uuid", "uuid",
"which", "which",
+1
View File
@@ -98,6 +98,7 @@ gman = "0.3.0"
clap_complete_nushell = "4.5.9" clap_complete_nushell = "4.5.9"
open = "5" open = "5"
rand = "0.9.0" rand = "0.9.0"
url = "2.5.8"
[dependencies.reqwest] [dependencies.reqwest]
version = "0.12.0" version = "0.12.0"
+4
View File
@@ -29,6 +29,10 @@ impl OAuthProvider for ClaudeOAuthProvider {
"org:create_api_key user:profile user:inference" "org:create_api_key user:profile user:inference"
} }
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![("code", "true")]
}
fn extra_token_headers(&self) -> Vec<(&str, &str)> { fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![("anthropic-beta", BETA_HEADER)] vec![("anthropic-beta", BETA_HEADER)]
} }
+193 -43
View File
@@ -1,18 +1,27 @@
use super::ClientConfig; use super::ClientConfig;
use super::access_token::{is_valid_access_token, set_access_token}; use super::access_token::{is_valid_access_token, set_access_token};
use crate::config::Config; use crate::config::Config;
use anyhow::{Result, bail}; use anyhow::{Result, anyhow, bail};
use base64::Engine; use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::Utc; use chrono::Utc;
use inquire::Text; use inquire::Text;
use reqwest::Client as ReqwestClient; use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Value, json}; use serde_json::Value;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs; use std::fs;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use url::Url;
use uuid::Uuid; use uuid::Uuid;
pub enum TokenRequestFormat {
Json,
FormUrlEncoded,
}
pub trait OAuthProvider: Send + Sync { pub trait OAuthProvider: Send + Sync {
fn provider_name(&self) -> &str; fn provider_name(&self) -> &str;
fn client_id(&self) -> &str; fn client_id(&self) -> &str;
@@ -21,6 +30,22 @@ pub trait OAuthProvider: Send + Sync {
fn redirect_uri(&self) -> &str; fn redirect_uri(&self) -> &str;
fn scopes(&self) -> &str; fn scopes(&self) -> &str;
fn client_secret(&self) -> Option<&str> {
None
}
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![]
}
fn token_request_format(&self) -> TokenRequestFormat {
TokenRequestFormat::Json
}
fn uses_localhost_redirect(&self) -> bool {
false
}
fn extra_token_headers(&self) -> Vec<(&str, &str)> { fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![] vec![]
} }
@@ -37,7 +62,7 @@ pub struct OAuthTokens {
pub expires_at: i64, pub expires_at: i64,
} }
pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> Result<()> { pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> Result<()> {
let random_bytes: [u8; 32] = rand::random::<[u8; 32]>(); let random_bytes: [u8; 32] = rand::random::<[u8; 32]>();
let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes); let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes);
@@ -47,11 +72,21 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
let state = Uuid::new_v4().to_string(); let state = Uuid::new_v4().to_string();
let encoded_scopes = urlencoding::encode(provider.scopes()); let redirect_uri = if provider.uses_localhost_redirect() {
let encoded_redirect = urlencoding::encode(provider.redirect_uri()); let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
let uri = format!("http://127.0.0.1:{port}/callback");
drop(listener);
uri
} else {
provider.redirect_uri().to_string()
};
let authorize_url = format!( let encoded_scopes = urlencoding::encode(provider.scopes());
"{}?code=true&client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}", let encoded_redirect = urlencoding::encode(&redirect_uri);
let mut authorize_url = format!(
"{}?client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
provider.authorize_url(), provider.authorize_url(),
provider.client_id(), provider.client_id(),
encoded_scopes, encoded_scopes,
@@ -60,6 +95,14 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
state state
); );
for (key, value) in provider.extra_authorize_params() {
authorize_url.push_str(&format!(
"&{}={}",
urlencoding::encode(key),
urlencoding::encode(value)
));
}
println!( println!(
"\nOpen this URL to authenticate with {} (client '{}'):\n", "\nOpen this URL to authenticate with {} (client '{}'):\n",
provider.provider_name(), provider.provider_name(),
@@ -69,14 +112,16 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
let _ = open::that(&authorize_url); let _ = open::that(&authorize_url);
let (code, returned_state) = if provider.uses_localhost_redirect() {
listen_for_oauth_callback(&redirect_uri)?
} else {
let input = Text::new("Paste the authorization code:").prompt()?; let input = Text::new("Paste the authorization code:").prompt()?;
let parts: Vec<&str> = input.splitn(2, '#').collect(); let parts: Vec<&str> = input.splitn(2, '#').collect();
if parts.len() != 2 { if parts.len() != 2 {
bail!("Invalid authorization code format. Expected format: <code>#<state>"); bail!("Invalid authorization code format. Expected format: <code>#<state>");
} }
let code = parts[0]; (parts[0].to_string(), parts[1].to_string())
let returned_state = parts[1]; };
if returned_state != state { if returned_state != state {
bail!( bail!(
@@ -86,32 +131,32 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
} }
let client = ReqwestClient::new(); let client = ReqwestClient::new();
let mut request = client.post(provider.token_url()).json(&json!({ let request = build_token_request(
"grant_type": "authorization_code", &client,
"client_id": provider.client_id(), provider,
"code": code, &[
"code_verifier": code_verifier, ("grant_type", "authorization_code"),
"redirect_uri": provider.redirect_uri(), ("client_id", provider.client_id()),
"state": state, ("code", &code),
})); ("code_verifier", &code_verifier),
("redirect_uri", &redirect_uri),
for (key, value) in provider.extra_token_headers() { ("state", &state),
request = request.header(key, value); ],
} );
let response: Value = request.send().await?.json().await?; let response: Value = request.send().await?.json().await?;
let access_token = response["access_token"] let access_token = response["access_token"]
.as_str() .as_str()
.ok_or_else(|| anyhow::anyhow!("Missing access_token in response: {response}"))? .ok_or_else(|| anyhow!("Missing access_token in response: {response}"))?
.to_string(); .to_string();
let refresh_token = response["refresh_token"] let refresh_token = response["refresh_token"]
.as_str() .as_str()
.ok_or_else(|| anyhow::anyhow!("Missing refresh_token in response: {response}"))? .ok_or_else(|| anyhow!("Missing refresh_token in response: {response}"))?
.to_string(); .to_string();
let expires_in = response["expires_in"] let expires_in = response["expires_in"]
.as_i64() .as_i64()
.ok_or_else(|| anyhow::anyhow!("Missing expires_in in response: {response}"))?; .ok_or_else(|| anyhow!("Missing expires_in in response: {response}"))?;
let expires_at = Utc::now().timestamp() + expires_in; let expires_at = Utc::now().timestamp() + expires_in;
@@ -150,33 +195,33 @@ fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
pub async fn refresh_oauth_token( pub async fn refresh_oauth_token(
client: &ReqwestClient, client: &ReqwestClient,
provider: &dyn OAuthProvider, provider: &impl OAuthProvider,
client_name: &str, client_name: &str,
tokens: &OAuthTokens, tokens: &OAuthTokens,
) -> Result<OAuthTokens> { ) -> Result<OAuthTokens> {
let mut request = client.post(provider.token_url()).json(&json!({ let request = build_token_request(
"grant_type": "refresh_token", client,
"client_id": provider.client_id(), provider,
"refresh_token": tokens.refresh_token, &[
})); ("grant_type", "refresh_token"),
("client_id", provider.client_id()),
for (key, value) in provider.extra_token_headers() { ("refresh_token", &tokens.refresh_token),
request = request.header(key, value); ],
} );
let response: Value = request.send().await?.json().await?; let response: Value = request.send().await?.json().await?;
let access_token = response["access_token"] let access_token = response["access_token"]
.as_str() .as_str()
.ok_or_else(|| anyhow::anyhow!("Missing access_token in refresh response: {response}"))? .ok_or_else(|| anyhow!("Missing access_token in refresh response: {response}"))?
.to_string(); .to_string();
let refresh_token = response["refresh_token"] let refresh_token = response["refresh_token"]
.as_str() .as_str()
.ok_or_else(|| anyhow::anyhow!("Missing refresh_token in refresh response: {response}"))? .map(|s| s.to_string())
.to_string(); .unwrap_or_else(|| tokens.refresh_token.clone());
let expires_in = response["expires_in"] let expires_in = response["expires_in"]
.as_i64() .as_i64()
.ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?; .ok_or_else(|| anyhow!("Missing expires_in in refresh response: {response}"))?;
let expires_at = Utc::now().timestamp() + expires_in; let expires_at = Utc::now().timestamp() + expires_in;
@@ -216,9 +261,110 @@ pub async fn prepare_oauth_access_token(
Ok(true) Ok(true)
} }
pub fn get_oauth_provider(provider_type: &str) -> Option<impl OAuthProvider> { fn build_token_request(
client: &ReqwestClient,
provider: &(impl OAuthProvider + ?Sized),
params: &[(&str, &str)],
) -> RequestBuilder {
let mut request = match provider.token_request_format() {
TokenRequestFormat::Json => {
let body: serde_json::Map<String, Value> = params
.iter()
.map(|(k, v)| (k.to_string(), Value::String(v.to_string())))
.collect();
if let Some(secret) = provider.client_secret() {
let mut body = body;
body.insert(
"client_secret".to_string(),
Value::String(secret.to_string()),
);
client.post(provider.token_url()).json(&body)
} else {
client.post(provider.token_url()).json(&body)
}
}
TokenRequestFormat::FormUrlEncoded => {
let mut form: HashMap<String, String> = params
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
if let Some(secret) = provider.client_secret() {
form.insert("client_secret".to_string(), secret.to_string());
}
client.post(provider.token_url()).form(&form)
}
};
for (key, value) in provider.extra_token_headers() {
request = request.header(key, value);
}
request
}
fn listen_for_oauth_callback(redirect_uri: &str) -> Result<(String, String)> {
let url: Url = redirect_uri.parse()?;
let host = url.host_str().unwrap_or("127.0.0.1");
let port = url
.port()
.ok_or_else(|| anyhow!("No port in redirect URI"))?;
let path = url.path();
println!("Waiting for OAuth callback on {redirect_uri} ...\n");
let listener = TcpListener::bind(format!("{host}:{port}"))?;
let (mut stream, _) = listener.accept()?;
let mut reader = BufReader::new(&stream);
let mut request_line = String::new();
reader.read_line(&mut request_line)?;
let request_path = request_line
.split_whitespace()
.nth(1)
.ok_or_else(|| anyhow!("Malformed HTTP request from OAuth callback"))?;
let full_url = format!("http://{host}:{port}{request_path}");
let parsed: Url = full_url.parse()?;
if !parsed.path().starts_with(path) {
bail!("Unexpected callback path: {}", parsed.path());
}
let code = parsed
.query_pairs()
.find(|(k, _)| k == "code")
.map(|(_, v)| v.to_string())
.ok_or_else(|| {
let error = parsed
.query_pairs()
.find(|(k, _)| k == "error")
.map(|(_, v)| v.to_string())
.unwrap_or_else(|| "unknown".to_string());
anyhow!("OAuth callback returned error: {error}")
})?;
let returned_state = parsed
.query_pairs()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow!("Missing state parameter in OAuth callback"))?;
let response_body = "<html><body><h2>Authentication successful!</h2><p>You can close this tab and return to your terminal.</p></body></html>";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
response_body.len(),
response_body
);
stream.write_all(response.as_bytes())?;
Ok((code, returned_state))
}
pub fn get_oauth_provider(provider_type: &str) -> Option<Box<dyn OAuthProvider>> {
match provider_type { match provider_type {
"claude" => Some(super::claude_oauth::ClaudeOAuthProvider), "claude" => Some(Box::new(super::claude_oauth::ClaudeOAuthProvider)),
"gemini" => Some(Box::new(super::gemini_oauth::GeminiOAuthProvider)),
_ => None, _ => None,
} }
} }
@@ -263,7 +409,11 @@ fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Opti
"openai-compatible", "openai-compatible",
None, None,
), ),
ClientConfig::GeminiConfig(c) => (c.name.as_deref().unwrap_or("gemini"), "gemini", None), ClientConfig::GeminiConfig(c) => (
c.name.as_deref().unwrap_or("gemini"),
"gemini",
c.auth.as_deref(),
),
ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None), ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None),
ClientConfig::AzureOpenAIConfig(c) => ( ClientConfig::AzureOpenAIConfig(c) => (
c.name.as_deref().unwrap_or("azure-openai"), c.name.as_deref().unwrap_or("azure-openai"),