1 Commits

12 changed files with 379 additions and 81 deletions
Generated
+1
View File
@@ -3243,6 +3243,7 @@ dependencies = [
"tokio-stream", "tokio-stream",
"unicode-segmentation", "unicode-segmentation",
"unicode-width 0.2.2", "unicode-width 0.2.2",
"url",
"urlencoding", "urlencoding",
"uuid", "uuid",
"which", "which",
+1
View File
@@ -98,6 +98,7 @@ gman = "0.3.0"
clap_complete_nushell = "4.5.9" clap_complete_nushell = "4.5.9"
open = "5" open = "5"
rand = "0.9.0" rand = "0.9.0"
url = "2.5.8"
[dependencies.reqwest] [dependencies.reqwest]
version = "0.12.0" version = "0.12.0"
+1 -1
View File
@@ -154,7 +154,7 @@ loki --list-secrets
### Authentication ### Authentication
Each client in your configuration needs authentication (with a few exceptions; e.g. ollama). Most clients use an API key Each client in your configuration needs authentication (with a few exceptions; e.g. ollama). Most clients use an API key
(set via `api_key` in the config or through the [vault](./docs/VAULT.md)). For providers that support OAuth (e.g. Claude Pro/Max (set via `api_key` in the config or through the [vault](./docs/VAULT.md)). For providers that support OAuth (e.g. Claude Pro/Max
subscribers), you can authenticate with your existing subscription instead: subscribers, Google Gemini), you can authenticate with your existing subscription instead:
```yaml ```yaml
# In your config.yaml # In your config.yaml
+2
View File
@@ -192,6 +192,8 @@ clients:
- type: gemini - type: gemini
api_base: https://generativelanguage.googleapis.com/v1beta api_base: https://generativelanguage.googleapis.com/v1beta
api_key: '{{GEMINI_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault api_key: '{{GEMINI_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault
auth: null # When set to 'oauth', Loki will use OAuth instead of an API key
# Authenticate with `loki --authenticate` or `.authenticate` in the REPL
patch: patch:
chat_completions: chat_completions:
'.*': '.*':
+5 -2
View File
@@ -137,8 +137,10 @@ loki --authenticate
Alternatively, you can use the REPL command `.authenticate`. Alternatively, you can use the REPL command `.authenticate`.
This opens your browser for the OAuth authorization flow. After authorizing, paste the authorization code back into This opens your browser for the OAuth authorization flow. Depending on the provider, Loki will either start a
the terminal. Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes them when they expire. temporary localhost server to capture the callback automatically (e.g. Gemini) or ask you to paste the authorization
code back into the terminal (e.g. Claude). Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes
them when they expire.
**Step 3: Use normally** **Step 3: Use normally**
@@ -153,6 +155,7 @@ loki -m my-claude-oauth:claude-sonnet-4-20250514 "Hello!"
### Providers That Support OAuth ### Providers That Support OAuth
* Claude * Claude
* Gemini
## Extra Settings ## Extra Settings
Loki also lets you customize some extra settings for interacting with APIs: Loki also lets you customize some extra settings for interacting with APIs:
+4
View File
@@ -29,6 +29,10 @@ impl OAuthProvider for ClaudeOAuthProvider {
"org:create_api_key user:profile user:inference" "org:create_api_key user:profile user:inference"
} }
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![("code", "true")]
}
fn extra_token_headers(&self) -> Vec<(&str, &str)> { fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![("anthropic-beta", BETA_HEADER)] vec![("anthropic-beta", BETA_HEADER)]
} }
+119 -34
View File
@@ -1,8 +1,11 @@
use super::access_token::get_access_token;
use super::gemini_oauth::GeminiOAuthProvider;
use super::oauth;
use super::vertexai::*; use super::vertexai::*;
use super::*; use super::*;
use anyhow::{Context, Result}; use anyhow::{Context, Result, bail};
use reqwest::RequestBuilder; use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize; use serde::Deserialize;
use serde_json::{Value, json}; use serde_json::{Value, json};
@@ -13,6 +16,7 @@ pub struct GeminiConfig {
pub name: Option<String>, pub name: Option<String>,
pub api_key: Option<String>, pub api_key: Option<String>,
pub api_base: Option<String>, pub api_base: Option<String>,
pub auth: Option<String>,
#[serde(default)] #[serde(default)]
pub models: Vec<ModelData>, pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>, pub patch: Option<RequestPatch>,
@@ -23,25 +27,64 @@ impl GeminiClient {
config_get_fn!(api_key, get_api_key); config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base); config_get_fn!(api_base, get_api_base);
create_client_config!([("api_key", "API Key", None, true)]); create_oauth_supported_client_config!();
} }
impl_client_trait!( #[async_trait::async_trait]
GeminiClient, impl Client for GeminiClient {
( client_common_fns!();
prepare_chat_completions,
gemini_chat_completions,
gemini_chat_completions_streaming
),
(prepare_embeddings, embeddings),
(noop_prepare_rerank, noop_rerank),
);
fn prepare_chat_completions( fn supports_oauth(&self) -> bool {
self.config.auth.as_deref() == Some("oauth")
}
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
gemini_chat_completions(builder, self.model()).await
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
gemini_chat_completions_streaming(builder, handler, self.model()).await
}
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
let request_data = prepare_embeddings(self, client, data).await?;
let builder = self.request_builder(client, request_data);
embeddings(builder, self.model()).await
}
async fn rerank_inner(
&self,
client: &ReqwestClient,
data: &RerankData,
) -> Result<RerankOutput> {
let request_data = noop_prepare_rerank(self, data)?;
let builder = self.request_builder(client, request_data);
noop_rerank(builder, self.model()).await
}
}
async fn prepare_chat_completions(
self_: &GeminiClient, self_: &GeminiClient,
client: &ReqwestClient,
data: ChatCompletionsData, data: ChatCompletionsData,
) -> Result<RequestData> { ) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_ let api_base = self_
.get_api_base() .get_api_base()
.unwrap_or_else(|_| API_BASE.to_string()); .unwrap_or_else(|_| API_BASE.to_string());
@@ -59,26 +102,61 @@ fn prepare_chat_completions(
); );
let body = gemini_build_chat_completions_body(data, &self_.model)?; let body = gemini_build_chat_completions_body(data, &self_.model)?;
let mut request_data = RequestData::new(url, body); let mut request_data = RequestData::new(url, body);
request_data.header("x-goog-api-key", api_key); let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
if uses_oauth {
let provider = GeminiOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
} else if let Ok(api_key) = self_.get_api_key() {
request_data.header("x-goog-api-key", api_key);
} else {
bail!(
"No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.",
self_.name(),
self_.name()
);
}
Ok(request_data) Ok(request_data)
} }
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> { async fn prepare_embeddings(
let api_key = self_.get_api_key()?; self_: &GeminiClient,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<RequestData> {
let api_base = self_ let api_base = self_
.get_api_base() .get_api_base()
.unwrap_or_else(|_| API_BASE.to_string()); .unwrap_or_else(|_| API_BASE.to_string());
let url = format!( let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
"{}/models/{}:batchEmbedContents?key={}",
api_base.trim_end_matches('/'), let url = if uses_oauth {
self_.model.real_name(), format!(
api_key "{}/models/{}:batchEmbedContents",
); api_base.trim_end_matches('/'),
self_.model.real_name(),
)
} else {
let api_key = self_.get_api_key()?;
format!(
"{}/models/{}:batchEmbedContents?key={}",
api_base.trim_end_matches('/'),
self_.model.real_name(),
api_key
)
};
let model_id = format!("models/{}", self_.model.real_name()); let model_id = format!("models/{}", self_.model.real_name());
@@ -89,21 +167,28 @@ fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<Req
json!({ json!({
"model": model_id, "model": model_id,
"content": { "content": {
"parts": [ "parts": [{ "text": text }]
{
"text": text
}
]
}, },
}) })
}) })
.collect(); .collect();
let body = json!({ let body = json!({ "requests": requests });
"requests": requests, let mut request_data = RequestData::new(url, body);
});
let request_data = RequestData::new(url, body); if uses_oauth {
let provider = GeminiOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
}
Ok(request_data) Ok(request_data)
} }
+50
View File
@@ -0,0 +1,50 @@
use super::oauth::{OAuthProvider, TokenRequestFormat};
pub struct GeminiOAuthProvider;
// TODO: Replace with real credentials after registering Loki with Google Cloud Console
const GEMINI_CLIENT_ID: &str =
"50826443741-upqcebrs4gctqht1f08ku46qlbirkdsj.apps.googleusercontent.com";
const GEMINI_CLIENT_SECRET: &str = "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";
impl OAuthProvider for GeminiOAuthProvider {
fn provider_name(&self) -> &str {
"gemini"
}
fn client_id(&self) -> &str {
GEMINI_CLIENT_ID
}
fn authorize_url(&self) -> &str {
"https://accounts.google.com/o/oauth2/v2/auth"
}
fn token_url(&self) -> &str {
"https://oauth2.googleapis.com/token"
}
fn redirect_uri(&self) -> &str {
""
}
fn scopes(&self) -> &str {
"https://www.googleapis.com/auth/cloud-platform.readonly https://www.googleapis.com/auth/userinfo.email"
}
fn client_secret(&self) -> Option<&str> {
Some(GEMINI_CLIENT_SECRET)
}
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![("access_type", "offline"), ("prompt", "consent")]
}
fn token_request_format(&self) -> TokenRequestFormat {
TokenRequestFormat::FormUrlEncoded
}
fn uses_localhost_redirect(&self) -> bool {
true
}
}
+1
View File
@@ -1,6 +1,7 @@
mod access_token; mod access_token;
mod claude_oauth; mod claude_oauth;
mod common; mod common;
mod gemini_oauth;
mod message; mod message;
pub mod oauth; pub mod oauth;
#[macro_use] #[macro_use]
+192 -41
View File
@@ -8,11 +8,20 @@ use chrono::Utc;
use inquire::Text; use inquire::Text;
use reqwest::Client as ReqwestClient; use reqwest::Client as ReqwestClient;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Value, json}; use serde_json::Value;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs; use std::fs;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use url::Url;
use uuid::Uuid; use uuid::Uuid;
pub enum TokenRequestFormat {
Json,
FormUrlEncoded,
}
pub trait OAuthProvider: Send + Sync { pub trait OAuthProvider: Send + Sync {
fn provider_name(&self) -> &str; fn provider_name(&self) -> &str;
fn client_id(&self) -> &str; fn client_id(&self) -> &str;
@@ -21,6 +30,22 @@ pub trait OAuthProvider: Send + Sync {
fn redirect_uri(&self) -> &str; fn redirect_uri(&self) -> &str;
fn scopes(&self) -> &str; fn scopes(&self) -> &str;
fn client_secret(&self) -> Option<&str> {
None
}
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![]
}
fn token_request_format(&self) -> TokenRequestFormat {
TokenRequestFormat::Json
}
fn uses_localhost_redirect(&self) -> bool {
false
}
fn extra_token_headers(&self) -> Vec<(&str, &str)> { fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![] vec![]
} }
@@ -37,7 +62,7 @@ pub struct OAuthTokens {
pub expires_at: i64, pub expires_at: i64,
} }
pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> Result<()> { pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> Result<()> {
let random_bytes: [u8; 32] = rand::random::<[u8; 32]>(); let random_bytes: [u8; 32] = rand::random::<[u8; 32]>();
let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes); let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes);
@@ -47,11 +72,22 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
let state = Uuid::new_v4().to_string(); let state = Uuid::new_v4().to_string();
let encoded_scopes = urlencoding::encode(provider.scopes()); let redirect_uri = if provider.uses_localhost_redirect() {
let encoded_redirect = urlencoding::encode(provider.redirect_uri()); let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
let uri = format!("http://127.0.0.1:{port}/callback");
// Drop the listener so run_oauth_flow can re-bind below
drop(listener);
uri
} else {
provider.redirect_uri().to_string()
};
let authorize_url = format!( let encoded_scopes = urlencoding::encode(provider.scopes());
"{}?code=true&client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}", let encoded_redirect = urlencoding::encode(&redirect_uri);
let mut authorize_url = format!(
"{}?client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
provider.authorize_url(), provider.authorize_url(),
provider.client_id(), provider.client_id(),
encoded_scopes, encoded_scopes,
@@ -60,6 +96,14 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
state state
); );
for (key, value) in provider.extra_authorize_params() {
authorize_url.push_str(&format!(
"&{}={}",
urlencoding::encode(key),
urlencoding::encode(value)
));
}
println!( println!(
"\nOpen this URL to authenticate with {} (client '{}'):\n", "\nOpen this URL to authenticate with {} (client '{}'):\n",
provider.provider_name(), provider.provider_name(),
@@ -69,14 +113,16 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
let _ = open::that(&authorize_url); let _ = open::that(&authorize_url);
let input = Text::new("Paste the authorization code:").prompt()?; let (code, returned_state) = if provider.uses_localhost_redirect() {
listen_for_oauth_callback(&redirect_uri)?
let parts: Vec<&str> = input.splitn(2, '#').collect(); } else {
if parts.len() != 2 { let input = Text::new("Paste the authorization code:").prompt()?;
bail!("Invalid authorization code format. Expected format: <code>#<state>"); let parts: Vec<&str> = input.splitn(2, '#').collect();
} if parts.len() != 2 {
let code = parts[0]; bail!("Invalid authorization code format. Expected format: <code>#<state>");
let returned_state = parts[1]; }
(parts[0].to_string(), parts[1].to_string())
};
if returned_state != state { if returned_state != state {
bail!( bail!(
@@ -86,18 +132,18 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
} }
let client = ReqwestClient::new(); let client = ReqwestClient::new();
let mut request = client.post(provider.token_url()).json(&json!({ let request = build_token_request(
"grant_type": "authorization_code", &client,
"client_id": provider.client_id(), provider,
"code": code, &[
"code_verifier": code_verifier, ("grant_type", "authorization_code"),
"redirect_uri": provider.redirect_uri(), ("client_id", provider.client_id()),
"state": state, ("code", &code),
})); ("code_verifier", &code_verifier),
("redirect_uri", &redirect_uri),
for (key, value) in provider.extra_token_headers() { ("state", &state),
request = request.header(key, value); ],
} );
let response: Value = request.send().await?.json().await?; let response: Value = request.send().await?.json().await?;
@@ -150,19 +196,19 @@ fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
pub async fn refresh_oauth_token( pub async fn refresh_oauth_token(
client: &ReqwestClient, client: &ReqwestClient,
provider: &dyn OAuthProvider, provider: &impl OAuthProvider,
client_name: &str, client_name: &str,
tokens: &OAuthTokens, tokens: &OAuthTokens,
) -> Result<OAuthTokens> { ) -> Result<OAuthTokens> {
let mut request = client.post(provider.token_url()).json(&json!({ let request = build_token_request(
"grant_type": "refresh_token", client,
"client_id": provider.client_id(), provider,
"refresh_token": tokens.refresh_token, &[
})); ("grant_type", "refresh_token"),
("client_id", provider.client_id()),
for (key, value) in provider.extra_token_headers() { ("refresh_token", &tokens.refresh_token),
request = request.header(key, value); ],
} );
let response: Value = request.send().await?.json().await?; let response: Value = request.send().await?.json().await?;
@@ -172,8 +218,8 @@ pub async fn refresh_oauth_token(
.to_string(); .to_string();
let refresh_token = response["refresh_token"] let refresh_token = response["refresh_token"]
.as_str() .as_str()
.ok_or_else(|| anyhow::anyhow!("Missing refresh_token in refresh response: {response}"))? .map(|s| s.to_string())
.to_string(); .unwrap_or_else(|| tokens.refresh_token.clone());
let expires_in = response["expires_in"] let expires_in = response["expires_in"]
.as_i64() .as_i64()
.ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?; .ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?;
@@ -216,9 +262,110 @@ pub async fn prepare_oauth_access_token(
Ok(true) Ok(true)
} }
pub fn get_oauth_provider(provider_type: &str) -> Option<impl OAuthProvider> { fn build_token_request(
client: &ReqwestClient,
provider: &(impl OAuthProvider + ?Sized),
params: &[(&str, &str)],
) -> reqwest::RequestBuilder {
let mut request = match provider.token_request_format() {
TokenRequestFormat::Json => {
let body: serde_json::Map<String, Value> = params
.iter()
.map(|(k, v)| (k.to_string(), Value::String(v.to_string())))
.collect();
if let Some(secret) = provider.client_secret() {
let mut body = body;
body.insert(
"client_secret".to_string(),
Value::String(secret.to_string()),
);
client.post(provider.token_url()).json(&body)
} else {
client.post(provider.token_url()).json(&body)
}
}
TokenRequestFormat::FormUrlEncoded => {
let mut form: HashMap<String, String> = params
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
if let Some(secret) = provider.client_secret() {
form.insert("client_secret".to_string(), secret.to_string());
}
client.post(provider.token_url()).form(&form)
}
};
for (key, value) in provider.extra_token_headers() {
request = request.header(key, value);
}
request
}
fn listen_for_oauth_callback(redirect_uri: &str) -> Result<(String, String)> {
let url: Url = redirect_uri.parse()?;
let host = url.host_str().unwrap_or("127.0.0.1");
let port = url
.port()
.ok_or_else(|| anyhow::anyhow!("No port in redirect URI"))?;
let path = url.path();
println!("Waiting for OAuth callback on {redirect_uri} ...\n");
let listener = TcpListener::bind(format!("{host}:{port}"))?;
let (mut stream, _) = listener.accept()?;
let mut reader = BufReader::new(&stream);
let mut request_line = String::new();
reader.read_line(&mut request_line)?;
let request_path = request_line
.split_whitespace()
.nth(1)
.ok_or_else(|| anyhow::anyhow!("Malformed HTTP request from OAuth callback"))?;
let full_url = format!("http://{host}:{port}{request_path}");
let parsed: Url = full_url.parse()?;
let response_body = "<html><body><h2>Authentication successful!</h2><p>You can close this tab and return to your terminal.</p></body></html>";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
response_body.len(),
response_body
);
stream.write_all(response.as_bytes())?;
if !parsed.path().starts_with(path) {
bail!("Unexpected callback path: {}", parsed.path());
}
let code = parsed
.query_pairs()
.find(|(k, _)| k == "code")
.map(|(_, v)| v.to_string())
.ok_or_else(|| {
let error = parsed
.query_pairs()
.find(|(k, _)| k == "error")
.map(|(_, v)| v.to_string())
.unwrap_or_else(|| "unknown".to_string());
anyhow::anyhow!("OAuth callback returned error: {error}")
})?;
let returned_state = parsed
.query_pairs()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow::anyhow!("Missing state parameter in OAuth callback"))?;
Ok((code, returned_state))
}
pub fn get_oauth_provider(provider_type: &str) -> Option<Box<dyn OAuthProvider>> {
match provider_type { match provider_type {
"claude" => Some(super::claude_oauth::ClaudeOAuthProvider), "claude" => Some(Box::new(super::claude_oauth::ClaudeOAuthProvider)),
"gemini" => Some(Box::new(super::gemini_oauth::GeminiOAuthProvider)),
_ => None, _ => None,
} }
} }
@@ -263,7 +410,11 @@ fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Opti
"openai-compatible", "openai-compatible",
None, None,
), ),
ClientConfig::GeminiConfig(c) => (c.name.as_deref().unwrap_or("gemini"), "gemini", None), ClientConfig::GeminiConfig(c) => (
c.name.as_deref().unwrap_or("gemini"),
"gemini",
c.auth.as_deref(),
),
ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None), ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None),
ClientConfig::AzureOpenAIConfig(c) => ( ClientConfig::AzureOpenAIConfig(c) => (
c.name.as_deref().unwrap_or("azure-openai"), c.name.as_deref().unwrap_or("azure-openai"),
+2 -2
View File
@@ -86,7 +86,7 @@ async fn main() -> Result<()> {
if let Some(client_arg) = &cli.authenticate { if let Some(client_arg) = &cli.authenticate {
let config = Config::init_bare()?; let config = Config::init_bare()?;
let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?; let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?;
oauth::run_oauth_flow(&provider, &client_name).await?; oauth::run_oauth_flow(&*provider, &client_name).await?;
return Ok(()); return Ok(());
} }
@@ -517,7 +517,7 @@ fn init_console_logger(
fn resolve_oauth_client( fn resolve_oauth_client(
explicit: Option<&str>, explicit: Option<&str>,
clients: &[ClientConfig], clients: &[ClientConfig],
) -> Result<(String, impl OAuthProvider)> { ) -> Result<(String, Box<dyn OAuthProvider>)> {
if let Some(name) = explicit { if let Some(name) = explicit {
let provider_type = oauth::resolve_provider_type(name, clients) let provider_type = oauth::resolve_provider_type(name, clients)
.ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?; .ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?;
+1 -1
View File
@@ -437,7 +437,7 @@ pub async fn run_repl_command(
} }
let clients = config.read().clients.clone(); let clients = config.read().clients.clone();
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?; let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
oauth::run_oauth_flow(&provider, &client_name).await?; oauth::run_oauth_flow(&*provider, &client_name).await?;
} }
".prompt" => match args { ".prompt" => match args {
Some(text) => { Some(text) => {