From 83f66e106115ce72cacef84c834e640a864943c3 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Wed, 11 Mar 2026 11:10:48 -0600 Subject: [PATCH] feat: Support OAuth authentication flows for Claude --- Cargo.lock | 37 +++++ Cargo.toml | 3 +- README.md | 20 +++ config.example.yaml | 1 + docs/clients/CLIENTS.md | 81 ++++++++++- src/cli/mod.rs | 3 + src/client/claude.rs | 73 ++++++++-- src/client/claude_oauth.rs | 39 ++++++ src/client/common.rs | 8 -- src/client/mod.rs | 2 + src/client/oauth.rs | 279 +++++++++++++++++++++++++++++++++++++ src/config/mod.rs | 8 ++ src/main.rs | 45 +++++- 13 files changed, 567 insertions(+), 32 deletions(-) create mode 100644 src/client/claude_oauth.rs create mode 100644 src/client/oauth.rs diff --git a/Cargo.lock b/Cargo.lock index 560aaaf..09e752d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2869,6 +2869,15 @@ dependencies = [ "serde", ] +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + [[package]] name = "is-macro" version = "0.3.7" @@ -2892,6 +2901,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", +] + [[package]] name = "is_executable" version = "1.0.5" @@ -3193,6 +3212,7 @@ dependencies = [ "log4rs", "nu-ansi-term", "num_cpus", + "open", "os_info", "parking_lot", "path-absolutize", @@ -3869,6 +3889,17 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "open" +version = "5.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43bb73a7fa3799b198970490a51174027ba0d4ec504b03cd08caf513d40024bc" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "openssl" version = "0.10.75" @@ -4039,6 +4070,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "pem" version = "3.0.6" diff --git a/Cargo.toml b/Cargo.toml index 44bcf6c..43bf852 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,6 +96,8 @@ colored = "3.0.0" clap_complete = { version = "4.5.58", features = ["unstable-dynamic"] } gman = "0.3.0" clap_complete_nushell = "4.5.9" +open = "5" +rand = "0.9.0" [dependencies.reqwest] version = "0.12.0" @@ -126,7 +128,6 @@ arboard = { version = "3.3.0", default-features = false } [dev-dependencies] pretty_assertions = "1.4.0" -rand = "0.9.0" [[bin]] name = "loki" diff --git a/README.md b/README.md index 4d4a315..e1c2316 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ Coming from [AIChat](https://github.com/sigoden/aichat)? Follow the [migration g * [Todo System](./docs/TODO-SYSTEM.md): Built-in task tracking for improved agent reliability with smaller models. * [Environment Variables](./docs/ENVIRONMENT-VARIABLES.md): Override and customize your Loki configuration at runtime with environment variables. * [Client Configurations](./docs/clients/CLIENTS.md): Configuration instructions for various LLM providers. + * [Authentication (API Key & OAuth)](./docs/clients/CLIENTS.md#authentication): Authenticate with API keys or OAuth for subscription-based access. * [Patching API Requests](./docs/clients/PATCHES.md): Learn how to patch API requests for advanced customization. * [Custom Themes](./docs/THEMES.md): Change the look and feel of Loki to your preferences with custom themes. * [History](#history): A history of how Loki came to be. @@ -150,6 +151,25 @@ guide you through the process when you first attempt to access the vault. So, to loki --list-secrets ``` +### Authentication +Each client in your configuration needs authentication (with a few exceptions; e.g. ollama). Most clients use an API key +(set via `api_key` in the config or through the [vault](./docs/VAULT.md)). For providers that support OAuth (e.g. Claude Pro/Max +subscribers), you can authenticate with your existing subscription instead: + +```yaml +# In your config.yaml +clients: + - type: claude + name: my-claude-oauth + auth: oauth # Indicate you want to authenticate with OAuth instead of an API key +``` + +```sh +loki --authenticate my-claude-oauth +``` + +For full details, see the [authentication documentation](./docs/clients/CLIENTS.md#authentication). + ### Tab-Completions You can also enable tab completions to make using Loki easier. To do so, add the following to your shell profile: ```shell diff --git a/config.example.yaml b/config.example.yaml index edeac0f..aadda92 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -210,6 +210,7 @@ clients: - type: claude api_base: https://api.anthropic.com/v1 # Optional api_key: '{{ANTHROPIC_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault + auth: null # When set to 'oauth', Loki will use OAuth instead of an API key; Authenticate with `loki --authenticate` # See https://docs.mistral.ai/ - type: openai-compatible diff --git a/docs/clients/CLIENTS.md b/docs/clients/CLIENTS.md index ab71cd0..2baac80 100644 --- a/docs/clients/CLIENTS.md +++ b/docs/clients/CLIENTS.md @@ -14,6 +14,7 @@ loki --info | grep 'config_file' | awk '{print $2}' - [Supported Clients](#supported-clients) - [Client Configuration](#client-configuration) +- [Authentication](#authentication) - [Extra Settings](#extra-settings) @@ -51,12 +52,13 @@ clients: The client metadata uniquely identifies the client in Loki so you can reference it across your configurations. The available settings are listed below: -| Setting | Description | -|----------|-----------------------------------------------------------------------------------------------| -| `name` | The name of the client (e.g. `openai`, `gemini`, etc.) | -| `models` | See the [model settings](#model-settings) documentation below | -| `patch` | See the [client patch configuration](./PATCHES.md#client-configuration-patches) documentation | -| `extra` | See the [extra settings](#extra-settings) documentation below | +| Setting | Description | +|----------|------------------------------------------------------------------------------------------------------------| +| `name` | The name of the client (e.g. `openai`, `gemini`, etc.) | +| `auth` | Authentication method: `oauth` for OAuth, or omit to use `api_key` (see [Authentication](#authentication)) | +| `models` | See the [model settings](#model-settings) documentation below | +| `patch` | See the [client patch configuration](./PATCHES.md#client-configuration-patches) documentation | +| `extra` | See the [extra settings](#extra-settings) documentation below | Be sure to also check provider-specific configurations for any extra fields that are added for authentication purposes. @@ -83,6 +85,73 @@ The `models` array lists the available models from the model client. Each one ha | `default_chunk_size` | | `embedding` | The default chunk size to use with the given model | | `max_batch_size` | | `embedding` | The maximum batch size that the given embedding model supports | +## Authentication + +Loki clients support two authentication methods: **API keys** and **OAuth**. Each client entry in your configuration +must use one or the other. + +### API Key Authentication + +Most clients authenticate using an API key. Simply set the `api_key` field directly or inject it from the +[Loki vault](../VAULT.md): + +```yaml +clients: + - type: claude + api_key: '{{ANTHROPIC_API_KEY}}' +``` + +API keys can also be provided via environment variables named `{CLIENT_NAME}_API_KEY` (e.g. `OPENAI_API_KEY`, +`GEMINI_API_KEY`). See the [environment variables documentation](../ENVIRONMENT-VARIABLES.md#client-related-variables) +for details. + +### OAuth Authentication + +For [providers that support OAuth](#providers-that-support-oauth), you can authenticate using your existing subscription instead of an API key. This uses +the OAuth 2.0 PKCE flow. + +**Step 1: Configure the client** + +Add a client entry with `auth: oauth` and no `api_key`: + +```yaml +clients: + - type: claude + name: my-claude-oauth + auth: oauth +``` + +**Step 2: Authenticate** + +Run the `--authenticate` flag with the client name: + +```sh +loki --authenticate my-claude-oauth +``` + +This opens your browser for the OAuth authorization flow. After authorizing, paste the authorization code back into +the terminal. Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes them when they expire. + +If you have only one OAuth-configured client, you can omit the name: + +```sh +loki --authenticate +``` + +**Step 3: Use normally** + +Once authenticated, the client works like any other. Loki uses the stored OAuth tokens automatically: + +```sh +loki -m my-claude-oauth:claude-sonnet-4-20250514 "Hello!" +``` + +> **Note:** You can have multiple clients for the same provider. For example: you can have one with an API key and +> another with OAuth. Use the `name` field to distinguish them. + +### Providers That Support OAuth +* Claude + ## Extra Settings Loki also lets you customize some extra settings for interacting with APIs: diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 6b6e101..2f4f8fb 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -127,6 +127,9 @@ pub struct Cli { /// List all secrets stored in the Loki vault #[arg(long, exclusive = true)] pub list_secrets: bool, + /// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name) + #[arg(long, exclusive = true, value_name = "CLIENT_NAME")] + pub authenticate: Option>, /// Generate static shell completion scripts #[arg(long, value_name = "SHELL", value_enum)] pub completions: Option, diff --git a/src/client/claude.rs b/src/client/claude.rs index 7ea2e77..aa526a3 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -1,9 +1,12 @@ +use super::access_token::get_access_token; +use super::claude_oauth::ClaudeOAuthProvider; +use super::oauth::{self, OAuthProvider}; use super::*; use crate::utils::strip_think_tag; use anyhow::{Context, Result, bail}; -use reqwest::RequestBuilder; +use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{Value, json}; @@ -14,6 +17,7 @@ pub struct ClaudeConfig { pub name: Option, pub api_key: Option, pub api_base: Option, + pub auth: Option, #[serde(default)] pub models: Vec, pub patch: Option, @@ -27,22 +31,37 @@ impl ClaudeClient { pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; } -impl_client_trait!( - ClaudeClient, - ( - prepare_chat_completions, - claude_chat_completions, - claude_chat_completions_streaming - ), - (noop_prepare_embeddings, noop_embeddings), - (noop_prepare_rerank, noop_rerank), -); +#[async_trait::async_trait] +impl Client for ClaudeClient { + client_common_fns!(); -fn prepare_chat_completions( + async fn chat_completions_inner( + &self, + client: &ReqwestClient, + data: ChatCompletionsData, + ) -> Result { + let request_data = prepare_chat_completions(self, client, data).await?; + let builder = self.request_builder(client, request_data); + claude_chat_completions(builder, self.model()).await + } + + async fn chat_completions_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut SseHandler, + data: ChatCompletionsData, + ) -> Result<()> { + let request_data = prepare_chat_completions(self, client, data).await?; + let builder = self.request_builder(client, request_data); + claude_chat_completions_streaming(builder, handler, self.model()).await + } +} + +async fn prepare_chat_completions( self_: &ClaudeClient, + client: &ReqwestClient, data: ChatCompletionsData, ) -> Result { - let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() .unwrap_or_else(|_| API_BASE.to_string()); @@ -53,7 +72,33 @@ fn prepare_chat_completions( let mut request_data = RequestData::new(url, body); request_data.header("anthropic-version", "2023-06-01"); - request_data.header("x-api-key", api_key); + + let uses_oauth = self_.config.auth.as_deref() == Some("oauth"); + + if uses_oauth { + let provider = ClaudeOAuthProvider; + let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?; + if !ready { + bail!( + "OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}", + self_.name(), + self_.name() + ); + } + let token = get_access_token(self_.name())?; + request_data.bearer_auth(token); + for (key, value) in provider.extra_request_headers() { + request_data.header(key, value); + } + } else if let Ok(api_key) = self_.get_api_key() { + request_data.header("x-api-key", api_key); + } else { + bail!( + "No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.", + self_.name(), + self_.name() + ); + } Ok(request_data) } diff --git a/src/client/claude_oauth.rs b/src/client/claude_oauth.rs new file mode 100644 index 0000000..fed4883 --- /dev/null +++ b/src/client/claude_oauth.rs @@ -0,0 +1,39 @@ +use super::oauth::OAuthProvider; + +pub const BETA_HEADER: &str = "oauth-2025-04-20"; + +pub struct ClaudeOAuthProvider; + +impl OAuthProvider for ClaudeOAuthProvider { + fn provider_name(&self) -> &str { + "claude" + } + + fn client_id(&self) -> &str { + "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + } + + fn authorize_url(&self) -> &str { + "https://claude.ai/oauth/authorize" + } + + fn token_url(&self) -> &str { + "https://console.anthropic.com/v1/oauth/token" + } + + fn redirect_uri(&self) -> &str { + "https://console.anthropic.com/oauth/code/callback" + } + + fn scopes(&self) -> &str { + "org:create_api_key user:profile user:inference" + } + + fn extra_token_headers(&self) -> Vec<(&str, &str)> { + vec![("anthropic-beta", BETA_HEADER)] + } + + fn extra_request_headers(&self) -> Vec<(&str, &str)> { + vec![("anthropic-beta", BETA_HEADER)] + } +} diff --git a/src/client/common.rs b/src/client/common.rs index dcc07d2..f5c8c2d 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -489,14 +489,6 @@ pub async fn call_chat_completions_streaming( } } -pub fn noop_prepare_embeddings(_client: &T, _data: &EmbeddingsData) -> Result { - bail!("The client doesn't support embeddings api") -} - -pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result { - bail!("The client doesn't support embeddings api") -} - pub fn noop_prepare_rerank(_client: &T, _data: &RerankData) -> Result { bail!("The client doesn't support rerank api") } diff --git a/src/client/mod.rs b/src/client/mod.rs index d0a4574..c1b4aa3 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,6 +1,8 @@ mod access_token; +mod claude_oauth; mod common; mod message; +pub mod oauth; #[macro_use] mod macros; mod model; diff --git a/src/client/oauth.rs b/src/client/oauth.rs new file mode 100644 index 0000000..8ee117c --- /dev/null +++ b/src/client/oauth.rs @@ -0,0 +1,279 @@ +use super::ClientConfig; +use super::access_token::{is_valid_access_token, set_access_token}; +use crate::config::Config; +use anyhow::{Result, bail}; +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use chrono::Utc; +use inquire::Text; +use reqwest::Client as ReqwestClient; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use sha2::{Digest, Sha256}; +use std::fs; +use uuid::Uuid; + +pub trait OAuthProvider: Send + Sync { + fn provider_name(&self) -> &str; + fn client_id(&self) -> &str; + fn authorize_url(&self) -> &str; + fn token_url(&self) -> &str; + fn redirect_uri(&self) -> &str; + fn scopes(&self) -> &str; + + fn extra_token_headers(&self) -> Vec<(&str, &str)> { + vec![] + } + + fn extra_request_headers(&self) -> Vec<(&str, &str)> { + vec![] + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OAuthTokens { + pub access_token: String, + pub refresh_token: String, + pub expires_at: i64, +} + +pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> Result<()> { + let random_bytes: [u8; 32] = rand::random::<[u8; 32]>(); + let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes); + + let mut hasher = Sha256::new(); + hasher.update(code_verifier.as_bytes()); + let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize()); + + let state = Uuid::new_v4().to_string(); + + let encoded_scopes = urlencoding::encode(provider.scopes()); + let encoded_redirect = urlencoding::encode(provider.redirect_uri()); + + let authorize_url = format!( + "{}?code=true&client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}", + provider.authorize_url(), + provider.client_id(), + encoded_scopes, + encoded_redirect, + code_challenge, + state + ); + + println!( + "\nOpen this URL to authenticate with {} (client '{}'):\n", + provider.provider_name(), + client_name + ); + println!(" {authorize_url}\n"); + + let _ = open::that(&authorize_url); + + let input = Text::new("Paste the authorization code:").prompt()?; + + let parts: Vec<&str> = input.splitn(2, '#').collect(); + if parts.len() != 2 { + bail!("Invalid authorization code format. Expected format: #"); + } + let code = parts[0]; + let returned_state = parts[1]; + + if returned_state != state { + bail!( + "OAuth state mismatch: expected '{state}', got '{returned_state}'. \ + This may indicate a CSRF attack or a stale authorization attempt." + ); + } + + let client = ReqwestClient::new(); + let mut request = client.post(provider.token_url()).json(&json!({ + "grant_type": "authorization_code", + "client_id": provider.client_id(), + "code": code, + "code_verifier": code_verifier, + "redirect_uri": provider.redirect_uri(), + "state": state, + })); + + for (key, value) in provider.extra_token_headers() { + request = request.header(key, value); + } + + let response: Value = request.send().await?.json().await?; + + let access_token = response["access_token"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing access_token in response: {response}"))? + .to_string(); + let refresh_token = response["refresh_token"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing refresh_token in response: {response}"))? + .to_string(); + let expires_in = response["expires_in"] + .as_i64() + .ok_or_else(|| anyhow::anyhow!("Missing expires_in in response: {response}"))?; + + let expires_at = Utc::now().timestamp() + expires_in; + + let tokens = OAuthTokens { + access_token, + refresh_token, + expires_at, + }; + + save_oauth_tokens(client_name, &tokens)?; + + println!( + "Successfully authenticated client '{}' with {} via OAuth. Tokens saved.", + client_name, + provider.provider_name() + ); + + Ok(()) +} + +pub fn load_oauth_tokens(client_name: &str) -> Option { + let path = Config::token_file(client_name); + let content = fs::read_to_string(path).ok()?; + serde_json::from_str(&content).ok() +} + +fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> { + let path = Config::token_file(client_name); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + let json = serde_json::to_string_pretty(tokens)?; + fs::write(path, json)?; + Ok(()) +} + +pub async fn refresh_oauth_token( + client: &ReqwestClient, + provider: &dyn OAuthProvider, + client_name: &str, + tokens: &OAuthTokens, +) -> Result { + let mut request = client.post(provider.token_url()).json(&json!({ + "grant_type": "refresh_token", + "client_id": provider.client_id(), + "refresh_token": tokens.refresh_token, + })); + + for (key, value) in provider.extra_token_headers() { + request = request.header(key, value); + } + + let response: Value = request.send().await?.json().await?; + + let access_token = response["access_token"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing access_token in refresh response: {response}"))? + .to_string(); + let refresh_token = response["refresh_token"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing refresh_token in refresh response: {response}"))? + .to_string(); + let expires_in = response["expires_in"] + .as_i64() + .ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?; + + let expires_at = Utc::now().timestamp() + expires_in; + + let new_tokens = OAuthTokens { + access_token, + refresh_token, + expires_at, + }; + + save_oauth_tokens(client_name, &new_tokens)?; + + Ok(new_tokens) +} + +pub async fn prepare_oauth_access_token( + client: &ReqwestClient, + provider: &impl OAuthProvider, + client_name: &str, +) -> Result { + if is_valid_access_token(client_name) { + return Ok(true); + } + + let tokens = match load_oauth_tokens(client_name) { + Some(t) => t, + None => return Ok(false), + }; + + let tokens = if Utc::now().timestamp() >= tokens.expires_at { + refresh_oauth_token(client, provider, client_name, &tokens).await? + } else { + tokens + }; + + set_access_token(client_name, tokens.access_token.clone(), tokens.expires_at); + + Ok(true) +} + +pub fn get_oauth_provider(provider_type: &str) -> Option { + match provider_type { + "claude" => Some(super::claude_oauth::ClaudeOAuthProvider), + _ => None, + } +} + +pub fn resolve_provider_type(client_name: &str, clients: &[ClientConfig]) -> Option<&'static str> { + for client_config in clients { + let (config_name, provider_type, auth) = client_config_info(client_config); + if config_name == client_name { + if auth == Some("oauth") && get_oauth_provider(provider_type).is_some() { + return Some(provider_type); + } + return None; + } + } + None +} + +pub fn list_oauth_capable_clients(clients: &[ClientConfig]) -> Vec { + clients + .iter() + .filter_map(|client_config| { + let (name, provider_type, auth) = client_config_info(client_config); + if auth == Some("oauth") && get_oauth_provider(provider_type).is_some() { + Some(name.to_string()) + } else { + None + } + }) + .collect() +} + +fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Option<&str>) { + match client_config { + ClientConfig::ClaudeConfig(c) => ( + c.name.as_deref().unwrap_or("claude"), + "claude", + c.auth.as_deref(), + ), + ClientConfig::OpenAIConfig(c) => (c.name.as_deref().unwrap_or("openai"), "openai", None), + ClientConfig::OpenAICompatibleConfig(c) => ( + c.name.as_deref().unwrap_or("openai-compatible"), + "openai-compatible", + None, + ), + ClientConfig::GeminiConfig(c) => (c.name.as_deref().unwrap_or("gemini"), "gemini", None), + ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None), + ClientConfig::AzureOpenAIConfig(c) => ( + c.name.as_deref().unwrap_or("azure-openai"), + "azure-openai", + None, + ), + ClientConfig::VertexAIConfig(c) => { + (c.name.as_deref().unwrap_or("vertexai"), "vertexai", None) + } + ClientConfig::BedrockConfig(c) => (c.name.as_deref().unwrap_or("bedrock"), "bedrock", None), + ClientConfig::Unknown => ("unknown", "unknown", None), + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 0cfe672..2121f5f 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -428,6 +428,14 @@ impl Config { base_dir.join(env!("CARGO_CRATE_NAME")) } + pub fn oauth_tokens_path() -> PathBuf { + Self::cache_path().join("oauth") + } + + pub fn token_file(client_name: &str) -> PathBuf { + Self::oauth_tokens_path().join(format!("{client_name}_oauth_tokens.json")) + } + pub fn log_path() -> PathBuf { Config::cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME"))) } diff --git a/src/main.rs b/src/main.rs index 03ea6a0..f788a90 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ mod vault; extern crate log; use crate::client::{ - ModelType, call_chat_completions, call_chat_completions_streaming, list_models, + ModelType, call_chat_completions, call_chat_completions_streaming, list_models, oauth, }; use crate::config::{ Agent, CODE_ROLE, Config, EXPLAIN_SHELL_ROLE, GlobalConfig, Input, SHELL_ROLE, @@ -29,15 +29,17 @@ use crate::utils::*; use crate::cli::Cli; use crate::vault::Vault; -use anyhow::{Result, bail}; +use anyhow::{Result, anyhow, bail}; use clap::{CommandFactory, Parser}; use clap_complete::CompleteEnv; -use inquire::Text; +use client::ClientConfig; +use inquire::{Select, Text}; use log::LevelFilter; use log4rs::append::console::ConsoleAppender; use log4rs::append::file::FileAppender; use log4rs::config::{Appender, Logger, Root}; use log4rs::encode::pattern::PatternEncoder; +use oauth::OAuthProvider; use parking_lot::RwLock; use std::path::PathBuf; use std::{env, mem, process, sync::Arc}; @@ -81,6 +83,13 @@ async fn main() -> Result<()> { let log_path = setup_logger()?; + if let Some(client_arg) = &cli.authenticate { + let config = Config::init_bare()?; + let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?; + oauth::run_oauth_flow(&provider, &client_name).await?; + return Ok(()); + } + if vault_flags { return Vault::handle_vault_flags(cli, Config::init_bare()?); } @@ -504,3 +513,33 @@ fn init_console_logger( .build(Root::builder().appender("console").build(root_log_level)) .unwrap() } + +fn resolve_oauth_client( + explicit: Option<&str>, + clients: &[ClientConfig], +) -> Result<(String, impl OAuthProvider)> { + if let Some(name) = explicit { + let provider_type = oauth::resolve_provider_type(name, clients) + .ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?; + let provider = oauth::get_oauth_provider(provider_type).unwrap(); + return Ok((name.to_string(), provider)); + } + + let candidates = oauth::list_oauth_capable_clients(clients); + match candidates.len() { + 0 => bail!("No OAuth-capable clients configured."), + 1 => { + let name = &candidates[0]; + let provider_type = oauth::resolve_provider_type(name, clients).unwrap(); + let provider = oauth::get_oauth_provider(provider_type).unwrap(); + Ok((name.clone(), provider)) + } + _ => { + let choice = + Select::new("Select a client to authenticate:", candidates.clone()).prompt()?; + let provider_type = oauth::resolve_provider_type(&choice, clients).unwrap(); + let provider = oauth::get_oauth_provider(provider_type).unwrap(); + Ok((choice, provider)) + } + } +}