feat: Support OAuth authentication flows for Claude

This commit is contained in:
2026-03-11 11:10:48 -06:00
parent 741b9c364c
commit 83f66e1061
13 changed files with 567 additions and 32 deletions
Generated
+37
View File
@@ -2869,6 +2869,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "is-docker"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3"
dependencies = [
"once_cell",
]
[[package]] [[package]]
name = "is-macro" name = "is-macro"
version = "0.3.7" version = "0.3.7"
@@ -2892,6 +2901,16 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "is-wsl"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5"
dependencies = [
"is-docker",
"once_cell",
]
[[package]] [[package]]
name = "is_executable" name = "is_executable"
version = "1.0.5" version = "1.0.5"
@@ -3193,6 +3212,7 @@ dependencies = [
"log4rs", "log4rs",
"nu-ansi-term", "nu-ansi-term",
"num_cpus", "num_cpus",
"open",
"os_info", "os_info",
"parking_lot", "parking_lot",
"path-absolutize", "path-absolutize",
@@ -3869,6 +3889,17 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "open"
version = "5.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43bb73a7fa3799b198970490a51174027ba0d4ec504b03cd08caf513d40024bc"
dependencies = [
"is-wsl",
"libc",
"pathdiff",
]
[[package]] [[package]]
name = "openssl" name = "openssl"
version = "0.10.75" version = "0.10.75"
@@ -4039,6 +4070,12 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "pathdiff"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]] [[package]]
name = "pem" name = "pem"
version = "3.0.6" version = "3.0.6"
+2 -1
View File
@@ -96,6 +96,8 @@ colored = "3.0.0"
clap_complete = { version = "4.5.58", features = ["unstable-dynamic"] } clap_complete = { version = "4.5.58", features = ["unstable-dynamic"] }
gman = "0.3.0" gman = "0.3.0"
clap_complete_nushell = "4.5.9" clap_complete_nushell = "4.5.9"
open = "5"
rand = "0.9.0"
[dependencies.reqwest] [dependencies.reqwest]
version = "0.12.0" version = "0.12.0"
@@ -126,7 +128,6 @@ arboard = { version = "3.3.0", default-features = false }
[dev-dependencies] [dev-dependencies]
pretty_assertions = "1.4.0" pretty_assertions = "1.4.0"
rand = "0.9.0"
[[bin]] [[bin]]
name = "loki" name = "loki"
+20
View File
@@ -39,6 +39,7 @@ Coming from [AIChat](https://github.com/sigoden/aichat)? Follow the [migration g
* [Todo System](./docs/TODO-SYSTEM.md): Built-in task tracking for improved agent reliability with smaller models. * [Todo System](./docs/TODO-SYSTEM.md): Built-in task tracking for improved agent reliability with smaller models.
* [Environment Variables](./docs/ENVIRONMENT-VARIABLES.md): Override and customize your Loki configuration at runtime with environment variables. * [Environment Variables](./docs/ENVIRONMENT-VARIABLES.md): Override and customize your Loki configuration at runtime with environment variables.
* [Client Configurations](./docs/clients/CLIENTS.md): Configuration instructions for various LLM providers. * [Client Configurations](./docs/clients/CLIENTS.md): Configuration instructions for various LLM providers.
* [Authentication (API Key & OAuth)](./docs/clients/CLIENTS.md#authentication): Authenticate with API keys or OAuth for subscription-based access.
* [Patching API Requests](./docs/clients/PATCHES.md): Learn how to patch API requests for advanced customization. * [Patching API Requests](./docs/clients/PATCHES.md): Learn how to patch API requests for advanced customization.
* [Custom Themes](./docs/THEMES.md): Change the look and feel of Loki to your preferences with custom themes. * [Custom Themes](./docs/THEMES.md): Change the look and feel of Loki to your preferences with custom themes.
* [History](#history): A history of how Loki came to be. * [History](#history): A history of how Loki came to be.
@@ -150,6 +151,25 @@ guide you through the process when you first attempt to access the vault. So, to
loki --list-secrets loki --list-secrets
``` ```
### Authentication
Each client in your configuration needs authentication (with a few exceptions; e.g. ollama). Most clients use an API key
(set via `api_key` in the config or through the [vault](./docs/VAULT.md)). For providers that support OAuth (e.g. Claude Pro/Max
subscribers), you can authenticate with your existing subscription instead:
```yaml
# In your config.yaml
clients:
- type: claude
name: my-claude-oauth
auth: oauth # Indicate you want to authenticate with OAuth instead of an API key
```
```sh
loki --authenticate my-claude-oauth
```
For full details, see the [authentication documentation](./docs/clients/CLIENTS.md#authentication).
### Tab-Completions ### Tab-Completions
You can also enable tab completions to make using Loki easier. To do so, add the following to your shell profile: You can also enable tab completions to make using Loki easier. To do so, add the following to your shell profile:
```shell ```shell
+1
View File
@@ -210,6 +210,7 @@ clients:
- type: claude - type: claude
api_base: https://api.anthropic.com/v1 # Optional api_base: https://api.anthropic.com/v1 # Optional
api_key: '{{ANTHROPIC_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault api_key: '{{ANTHROPIC_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault
auth: null # When set to 'oauth', Loki will use OAuth instead of an API key; Authenticate with `loki --authenticate`
# See https://docs.mistral.ai/ # See https://docs.mistral.ai/
- type: openai-compatible - type: openai-compatible
+75 -6
View File
@@ -14,6 +14,7 @@ loki --info | grep 'config_file' | awk '{print $2}'
<!--toc:start--> <!--toc:start-->
- [Supported Clients](#supported-clients) - [Supported Clients](#supported-clients)
- [Client Configuration](#client-configuration) - [Client Configuration](#client-configuration)
- [Authentication](#authentication)
- [Extra Settings](#extra-settings) - [Extra Settings](#extra-settings)
<!--toc:end--> <!--toc:end-->
@@ -51,12 +52,13 @@ clients:
The client metadata uniquely identifies the client in Loki so you can reference it across your configurations. The The client metadata uniquely identifies the client in Loki so you can reference it across your configurations. The
available settings are listed below: available settings are listed below:
| Setting | Description | | Setting | Description |
|----------|-----------------------------------------------------------------------------------------------| |----------|------------------------------------------------------------------------------------------------------------|
| `name` | The name of the client (e.g. `openai`, `gemini`, etc.) | | `name` | The name of the client (e.g. `openai`, `gemini`, etc.) |
| `models` | See the [model settings](#model-settings) documentation below | | `auth` | Authentication method: `oauth` for OAuth, or omit to use `api_key` (see [Authentication](#authentication)) |
| `patch` | See the [client patch configuration](./PATCHES.md#client-configuration-patches) documentation | | `models` | See the [model settings](#model-settings) documentation below |
| `extra` | See the [extra settings](#extra-settings) documentation below | | `patch` | See the [client patch configuration](./PATCHES.md#client-configuration-patches) documentation |
| `extra` | See the [extra settings](#extra-settings) documentation below |
Be sure to also check provider-specific configurations for any extra fields that are added for authentication purposes. Be sure to also check provider-specific configurations for any extra fields that are added for authentication purposes.
@@ -83,6 +85,73 @@ The `models` array lists the available models from the model client. Each one ha
| `default_chunk_size` | | `embedding` | The default chunk size to use with the given model | | `default_chunk_size` | | `embedding` | The default chunk size to use with the given model |
| `max_batch_size` | | `embedding` | The maximum batch size that the given embedding model supports | | `max_batch_size` | | `embedding` | The maximum batch size that the given embedding model supports |
## Authentication
Loki clients support two authentication methods: **API keys** and **OAuth**. Each client entry in your configuration
must use one or the other.
### API Key Authentication
Most clients authenticate using an API key. Simply set the `api_key` field directly or inject it from the
[Loki vault](../VAULT.md):
```yaml
clients:
- type: claude
api_key: '{{ANTHROPIC_API_KEY}}'
```
API keys can also be provided via environment variables named `{CLIENT_NAME}_API_KEY` (e.g. `OPENAI_API_KEY`,
`GEMINI_API_KEY`). See the [environment variables documentation](../ENVIRONMENT-VARIABLES.md#client-related-variables)
for details.
### OAuth Authentication
For [providers that support OAuth](#providers-that-support-oauth), you can authenticate using your existing subscription instead of an API key. This uses
the OAuth 2.0 PKCE flow.
**Step 1: Configure the client**
Add a client entry with `auth: oauth` and no `api_key`:
```yaml
clients:
- type: claude
name: my-claude-oauth
auth: oauth
```
**Step 2: Authenticate**
Run the `--authenticate` flag with the client name:
```sh
loki --authenticate my-claude-oauth
```
This opens your browser for the OAuth authorization flow. After authorizing, paste the authorization code back into
the terminal. Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes them when they expire.
If you have only one OAuth-configured client, you can omit the name:
```sh
loki --authenticate
```
**Step 3: Use normally**
Once authenticated, the client works like any other. Loki uses the stored OAuth tokens automatically:
```sh
loki -m my-claude-oauth:claude-sonnet-4-20250514 "Hello!"
```
> **Note:** You can have multiple clients for the same provider. For example: you can have one with an API key and
> another with OAuth. Use the `name` field to distinguish them.
### Providers That Support OAuth
* Claude
## Extra Settings ## Extra Settings
Loki also lets you customize some extra settings for interacting with APIs: Loki also lets you customize some extra settings for interacting with APIs:
+3
View File
@@ -127,6 +127,9 @@ pub struct Cli {
/// List all secrets stored in the Loki vault /// List all secrets stored in the Loki vault
#[arg(long, exclusive = true)] #[arg(long, exclusive = true)]
pub list_secrets: bool, pub list_secrets: bool,
/// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name)
#[arg(long, exclusive = true, value_name = "CLIENT_NAME")]
pub authenticate: Option<Option<String>>,
/// Generate static shell completion scripts /// Generate static shell completion scripts
#[arg(long, value_name = "SHELL", value_enum)] #[arg(long, value_name = "SHELL", value_enum)]
pub completions: Option<ShellCompletion>, pub completions: Option<ShellCompletion>,
+59 -14
View File
@@ -1,9 +1,12 @@
use super::access_token::get_access_token;
use super::claude_oauth::ClaudeOAuthProvider;
use super::oauth::{self, OAuthProvider};
use super::*; use super::*;
use crate::utils::strip_think_tag; use crate::utils::strip_think_tag;
use anyhow::{Context, Result, bail}; use anyhow::{Context, Result, bail};
use reqwest::RequestBuilder; use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize; use serde::Deserialize;
use serde_json::{Value, json}; use serde_json::{Value, json};
@@ -14,6 +17,7 @@ pub struct ClaudeConfig {
pub name: Option<String>, pub name: Option<String>,
pub api_key: Option<String>, pub api_key: Option<String>,
pub api_base: Option<String>, pub api_base: Option<String>,
pub auth: Option<String>,
#[serde(default)] #[serde(default)]
pub models: Vec<ModelData>, pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>, pub patch: Option<RequestPatch>,
@@ -27,22 +31,37 @@ impl ClaudeClient {
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)]; pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
} }
impl_client_trait!( #[async_trait::async_trait]
ClaudeClient, impl Client for ClaudeClient {
( client_common_fns!();
prepare_chat_completions,
claude_chat_completions,
claude_chat_completions_streaming
),
(noop_prepare_embeddings, noop_embeddings),
(noop_prepare_rerank, noop_rerank),
);
fn prepare_chat_completions( async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
claude_chat_completions(builder, self.model()).await
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
claude_chat_completions_streaming(builder, handler, self.model()).await
}
}
async fn prepare_chat_completions(
self_: &ClaudeClient, self_: &ClaudeClient,
client: &ReqwestClient,
data: ChatCompletionsData, data: ChatCompletionsData,
) -> Result<RequestData> { ) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_ let api_base = self_
.get_api_base() .get_api_base()
.unwrap_or_else(|_| API_BASE.to_string()); .unwrap_or_else(|_| API_BASE.to_string());
@@ -53,7 +72,33 @@ fn prepare_chat_completions(
let mut request_data = RequestData::new(url, body); let mut request_data = RequestData::new(url, body);
request_data.header("anthropic-version", "2023-06-01"); request_data.header("anthropic-version", "2023-06-01");
request_data.header("x-api-key", api_key);
let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
if uses_oauth {
let provider = ClaudeOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
for (key, value) in provider.extra_request_headers() {
request_data.header(key, value);
}
} else if let Ok(api_key) = self_.get_api_key() {
request_data.header("x-api-key", api_key);
} else {
bail!(
"No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.",
self_.name(),
self_.name()
);
}
Ok(request_data) Ok(request_data)
} }
+39
View File
@@ -0,0 +1,39 @@
use super::oauth::OAuthProvider;
pub const BETA_HEADER: &str = "oauth-2025-04-20";
pub struct ClaudeOAuthProvider;
impl OAuthProvider for ClaudeOAuthProvider {
fn provider_name(&self) -> &str {
"claude"
}
fn client_id(&self) -> &str {
"9d1c250a-e61b-44d9-88ed-5944d1962f5e"
}
fn authorize_url(&self) -> &str {
"https://claude.ai/oauth/authorize"
}
fn token_url(&self) -> &str {
"https://console.anthropic.com/v1/oauth/token"
}
fn redirect_uri(&self) -> &str {
"https://console.anthropic.com/oauth/code/callback"
}
fn scopes(&self) -> &str {
"org:create_api_key user:profile user:inference"
}
fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![("anthropic-beta", BETA_HEADER)]
}
fn extra_request_headers(&self) -> Vec<(&str, &str)> {
vec![("anthropic-beta", BETA_HEADER)]
}
}
-8
View File
@@ -489,14 +489,6 @@ pub async fn call_chat_completions_streaming(
} }
} }
pub fn noop_prepare_embeddings<T>(_client: &T, _data: &EmbeddingsData) -> Result<RequestData> {
bail!("The client doesn't support embeddings api")
}
pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
bail!("The client doesn't support embeddings api")
}
pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> { pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
bail!("The client doesn't support rerank api") bail!("The client doesn't support rerank api")
} }
+2
View File
@@ -1,6 +1,8 @@
mod access_token; mod access_token;
mod claude_oauth;
mod common; mod common;
mod message; mod message;
pub mod oauth;
#[macro_use] #[macro_use]
mod macros; mod macros;
mod model; mod model;
+279
View File
@@ -0,0 +1,279 @@
use super::ClientConfig;
use super::access_token::{is_valid_access_token, set_access_token};
use crate::config::Config;
use anyhow::{Result, bail};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::Utc;
use inquire::Text;
use reqwest::Client as ReqwestClient;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use sha2::{Digest, Sha256};
use std::fs;
use uuid::Uuid;
pub trait OAuthProvider: Send + Sync {
fn provider_name(&self) -> &str;
fn client_id(&self) -> &str;
fn authorize_url(&self) -> &str;
fn token_url(&self) -> &str;
fn redirect_uri(&self) -> &str;
fn scopes(&self) -> &str;
fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![]
}
fn extra_request_headers(&self) -> Vec<(&str, &str)> {
vec![]
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthTokens {
pub access_token: String,
pub refresh_token: String,
pub expires_at: i64,
}
pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> Result<()> {
let random_bytes: [u8; 32] = rand::random::<[u8; 32]>();
let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes);
let mut hasher = Sha256::new();
hasher.update(code_verifier.as_bytes());
let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
let state = Uuid::new_v4().to_string();
let encoded_scopes = urlencoding::encode(provider.scopes());
let encoded_redirect = urlencoding::encode(provider.redirect_uri());
let authorize_url = format!(
"{}?code=true&client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
provider.authorize_url(),
provider.client_id(),
encoded_scopes,
encoded_redirect,
code_challenge,
state
);
println!(
"\nOpen this URL to authenticate with {} (client '{}'):\n",
provider.provider_name(),
client_name
);
println!(" {authorize_url}\n");
let _ = open::that(&authorize_url);
let input = Text::new("Paste the authorization code:").prompt()?;
let parts: Vec<&str> = input.splitn(2, '#').collect();
if parts.len() != 2 {
bail!("Invalid authorization code format. Expected format: <code>#<state>");
}
let code = parts[0];
let returned_state = parts[1];
if returned_state != state {
bail!(
"OAuth state mismatch: expected '{state}', got '{returned_state}'. \
This may indicate a CSRF attack or a stale authorization attempt."
);
}
let client = ReqwestClient::new();
let mut request = client.post(provider.token_url()).json(&json!({
"grant_type": "authorization_code",
"client_id": provider.client_id(),
"code": code,
"code_verifier": code_verifier,
"redirect_uri": provider.redirect_uri(),
"state": state,
}));
for (key, value) in provider.extra_token_headers() {
request = request.header(key, value);
}
let response: Value = request.send().await?.json().await?;
let access_token = response["access_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing access_token in response: {response}"))?
.to_string();
let refresh_token = response["refresh_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing refresh_token in response: {response}"))?
.to_string();
let expires_in = response["expires_in"]
.as_i64()
.ok_or_else(|| anyhow::anyhow!("Missing expires_in in response: {response}"))?;
let expires_at = Utc::now().timestamp() + expires_in;
let tokens = OAuthTokens {
access_token,
refresh_token,
expires_at,
};
save_oauth_tokens(client_name, &tokens)?;
println!(
"Successfully authenticated client '{}' with {} via OAuth. Tokens saved.",
client_name,
provider.provider_name()
);
Ok(())
}
pub fn load_oauth_tokens(client_name: &str) -> Option<OAuthTokens> {
let path = Config::token_file(client_name);
let content = fs::read_to_string(path).ok()?;
serde_json::from_str(&content).ok()
}
fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
let path = Config::token_file(client_name);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let json = serde_json::to_string_pretty(tokens)?;
fs::write(path, json)?;
Ok(())
}
pub async fn refresh_oauth_token(
client: &ReqwestClient,
provider: &dyn OAuthProvider,
client_name: &str,
tokens: &OAuthTokens,
) -> Result<OAuthTokens> {
let mut request = client.post(provider.token_url()).json(&json!({
"grant_type": "refresh_token",
"client_id": provider.client_id(),
"refresh_token": tokens.refresh_token,
}));
for (key, value) in provider.extra_token_headers() {
request = request.header(key, value);
}
let response: Value = request.send().await?.json().await?;
let access_token = response["access_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing access_token in refresh response: {response}"))?
.to_string();
let refresh_token = response["refresh_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing refresh_token in refresh response: {response}"))?
.to_string();
let expires_in = response["expires_in"]
.as_i64()
.ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?;
let expires_at = Utc::now().timestamp() + expires_in;
let new_tokens = OAuthTokens {
access_token,
refresh_token,
expires_at,
};
save_oauth_tokens(client_name, &new_tokens)?;
Ok(new_tokens)
}
pub async fn prepare_oauth_access_token(
client: &ReqwestClient,
provider: &impl OAuthProvider,
client_name: &str,
) -> Result<bool> {
if is_valid_access_token(client_name) {
return Ok(true);
}
let tokens = match load_oauth_tokens(client_name) {
Some(t) => t,
None => return Ok(false),
};
let tokens = if Utc::now().timestamp() >= tokens.expires_at {
refresh_oauth_token(client, provider, client_name, &tokens).await?
} else {
tokens
};
set_access_token(client_name, tokens.access_token.clone(), tokens.expires_at);
Ok(true)
}
pub fn get_oauth_provider(provider_type: &str) -> Option<impl OAuthProvider> {
match provider_type {
"claude" => Some(super::claude_oauth::ClaudeOAuthProvider),
_ => None,
}
}
pub fn resolve_provider_type(client_name: &str, clients: &[ClientConfig]) -> Option<&'static str> {
for client_config in clients {
let (config_name, provider_type, auth) = client_config_info(client_config);
if config_name == client_name {
if auth == Some("oauth") && get_oauth_provider(provider_type).is_some() {
return Some(provider_type);
}
return None;
}
}
None
}
pub fn list_oauth_capable_clients(clients: &[ClientConfig]) -> Vec<String> {
clients
.iter()
.filter_map(|client_config| {
let (name, provider_type, auth) = client_config_info(client_config);
if auth == Some("oauth") && get_oauth_provider(provider_type).is_some() {
Some(name.to_string())
} else {
None
}
})
.collect()
}
fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Option<&str>) {
match client_config {
ClientConfig::ClaudeConfig(c) => (
c.name.as_deref().unwrap_or("claude"),
"claude",
c.auth.as_deref(),
),
ClientConfig::OpenAIConfig(c) => (c.name.as_deref().unwrap_or("openai"), "openai", None),
ClientConfig::OpenAICompatibleConfig(c) => (
c.name.as_deref().unwrap_or("openai-compatible"),
"openai-compatible",
None,
),
ClientConfig::GeminiConfig(c) => (c.name.as_deref().unwrap_or("gemini"), "gemini", None),
ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None),
ClientConfig::AzureOpenAIConfig(c) => (
c.name.as_deref().unwrap_or("azure-openai"),
"azure-openai",
None,
),
ClientConfig::VertexAIConfig(c) => {
(c.name.as_deref().unwrap_or("vertexai"), "vertexai", None)
}
ClientConfig::BedrockConfig(c) => (c.name.as_deref().unwrap_or("bedrock"), "bedrock", None),
ClientConfig::Unknown => ("unknown", "unknown", None),
}
}
+8
View File
@@ -428,6 +428,14 @@ impl Config {
base_dir.join(env!("CARGO_CRATE_NAME")) base_dir.join(env!("CARGO_CRATE_NAME"))
} }
pub fn oauth_tokens_path() -> PathBuf {
Self::cache_path().join("oauth")
}
pub fn token_file(client_name: &str) -> PathBuf {
Self::oauth_tokens_path().join(format!("{client_name}_oauth_tokens.json"))
}
pub fn log_path() -> PathBuf { pub fn log_path() -> PathBuf {
Config::cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME"))) Config::cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME")))
} }
+42 -3
View File
@@ -16,7 +16,7 @@ mod vault;
extern crate log; extern crate log;
use crate::client::{ use crate::client::{
ModelType, call_chat_completions, call_chat_completions_streaming, list_models, ModelType, call_chat_completions, call_chat_completions_streaming, list_models, oauth,
}; };
use crate::config::{ use crate::config::{
Agent, CODE_ROLE, Config, EXPLAIN_SHELL_ROLE, GlobalConfig, Input, SHELL_ROLE, Agent, CODE_ROLE, Config, EXPLAIN_SHELL_ROLE, GlobalConfig, Input, SHELL_ROLE,
@@ -29,15 +29,17 @@ use crate::utils::*;
use crate::cli::Cli; use crate::cli::Cli;
use crate::vault::Vault; use crate::vault::Vault;
use anyhow::{Result, bail}; use anyhow::{Result, anyhow, bail};
use clap::{CommandFactory, Parser}; use clap::{CommandFactory, Parser};
use clap_complete::CompleteEnv; use clap_complete::CompleteEnv;
use inquire::Text; use client::ClientConfig;
use inquire::{Select, Text};
use log::LevelFilter; use log::LevelFilter;
use log4rs::append::console::ConsoleAppender; use log4rs::append::console::ConsoleAppender;
use log4rs::append::file::FileAppender; use log4rs::append::file::FileAppender;
use log4rs::config::{Appender, Logger, Root}; use log4rs::config::{Appender, Logger, Root};
use log4rs::encode::pattern::PatternEncoder; use log4rs::encode::pattern::PatternEncoder;
use oauth::OAuthProvider;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::path::PathBuf; use std::path::PathBuf;
use std::{env, mem, process, sync::Arc}; use std::{env, mem, process, sync::Arc};
@@ -81,6 +83,13 @@ async fn main() -> Result<()> {
let log_path = setup_logger()?; let log_path = setup_logger()?;
if let Some(client_arg) = &cli.authenticate {
let config = Config::init_bare()?;
let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?;
oauth::run_oauth_flow(&provider, &client_name).await?;
return Ok(());
}
if vault_flags { if vault_flags {
return Vault::handle_vault_flags(cli, Config::init_bare()?); return Vault::handle_vault_flags(cli, Config::init_bare()?);
} }
@@ -504,3 +513,33 @@ fn init_console_logger(
.build(Root::builder().appender("console").build(root_log_level)) .build(Root::builder().appender("console").build(root_log_level))
.unwrap() .unwrap()
} }
fn resolve_oauth_client(
explicit: Option<&str>,
clients: &[ClientConfig],
) -> Result<(String, impl OAuthProvider)> {
if let Some(name) = explicit {
let provider_type = oauth::resolve_provider_type(name, clients)
.ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?;
let provider = oauth::get_oauth_provider(provider_type).unwrap();
return Ok((name.to_string(), provider));
}
let candidates = oauth::list_oauth_capable_clients(clients);
match candidates.len() {
0 => bail!("No OAuth-capable clients configured."),
1 => {
let name = &candidates[0];
let provider_type = oauth::resolve_provider_type(name, clients).unwrap();
let provider = oauth::get_oauth_provider(provider_type).unwrap();
Ok((name.clone(), provider))
}
_ => {
let choice =
Select::new("Select a client to authenticate:", candidates.clone()).prompt()?;
let provider_type = oauth::resolve_provider_type(&choice, clients).unwrap();
let provider = oauth::get_oauth_provider(provider_type).unwrap();
Ok((choice, provider))
}
}
}