8 Commits

17 changed files with 482 additions and 103 deletions
Generated
+1
View File
@@ -3243,6 +3243,7 @@ dependencies = [
"tokio-stream",
"unicode-segmentation",
"unicode-width 0.2.2",
"url",
"urlencoding",
"uuid",
"which",
+1
View File
@@ -98,6 +98,7 @@ gman = "0.3.0"
clap_complete_nushell = "4.5.9"
open = "5"
rand = "0.9.0"
url = "2.5.8"
[dependencies.reqwest]
version = "0.12.0"
+1 -1
View File
@@ -154,7 +154,7 @@ loki --list-secrets
### Authentication
Each client in your configuration needs authentication (with a few exceptions; e.g. ollama). Most clients use an API key
(set via `api_key` in the config or through the [vault](./docs/VAULT.md)). For providers that support OAuth (e.g. Claude Pro/Max
subscribers), you can authenticate with your existing subscription instead:
subscribers, Google Gemini), you can authenticate with your existing subscription instead:
```yaml
# In your config.yaml
+26 -6
View File
@@ -14,11 +14,28 @@ _project_dir() {
(cd "${dir}" 2>/dev/null && pwd) || echo "${dir}"
}
# Normalize a path to be relative to project root.
# Strips the project_dir prefix if the LLM passes an absolute path.
# Usage: local rel_path; rel_path=$(_normalize_path "/abs/or/rel/path")
_normalize_path() {
local input_path="$1"
local project_dir
project_dir=$(_project_dir)
if [[ "${input_path}" == /* ]]; then
input_path="${input_path#"${project_dir}"/}"
fi
input_path="${input_path#./}"
echo "${input_path}"
}
# @cmd Read a file's contents before modifying
# @option --path! Path to the file (relative to project root)
read_file() {
local file_path
# shellcheck disable=SC2154
local file_path="${argc_path}"
file_path=$(_normalize_path "${argc_path}")
local project_dir
project_dir=$(_project_dir)
local full_path="${project_dir}/${file_path}"
@@ -39,7 +56,8 @@ read_file() {
# @option --path! Path for the file (relative to project root)
# @option --content! Complete file contents to write
write_file() {
local file_path="${argc_path}"
local file_path
file_path=$(_normalize_path "${argc_path}")
# shellcheck disable=SC2154
local content="${argc_content}"
local project_dir
@@ -47,7 +65,7 @@ write_file() {
local full_path="${project_dir}/${file_path}"
mkdir -p "$(dirname "${full_path}")"
echo "${content}" > "${full_path}"
printf '%s' "${content}" > "${full_path}"
green "Wrote: ${file_path}" >> "$LLM_OUTPUT"
}
@@ -55,7 +73,8 @@ write_file() {
# @cmd Find files similar to a given path (for pattern matching)
# @option --path! Path to find similar files for
find_similar_files() {
local file_path="${argc_path}"
local file_path
file_path=$(_normalize_path "${argc_path}")
local project_dir
project_dir=$(_project_dir)
@@ -71,14 +90,14 @@ find_similar_files() {
! -name "$(basename "${file_path}")" \
! -name "*test*" \
! -name "*spec*" \
2>/dev/null | head -3)
2>/dev/null | sed "s|^${project_dir}/||" | head -3)
if [[ -z "${results}" ]]; then
results=$(find "${project_dir}/src" -type f -name "*.${ext}" \
! -name "*test*" \
! -name "*spec*" \
-not -path '*/target/*' \
2>/dev/null | head -3)
2>/dev/null | sed "s|^${project_dir}/||" | head -3)
fi
if [[ -n "${results}" ]]; then
@@ -186,6 +205,7 @@ search_code() {
grep -v '/target/' | \
grep -v '/node_modules/' | \
grep -v '/.git/' | \
sed "s|^${project_dir}/||" | \
head -20) || true
if [[ -n "${results}" ]]; then
+22 -4
View File
@@ -14,6 +14,21 @@ _project_dir() {
(cd "${dir}" 2>/dev/null && pwd) || echo "${dir}"
}
# Normalize a path to be relative to project root.
# Strips the project_dir prefix if the LLM passes an absolute path.
_normalize_path() {
local input_path="$1"
local project_dir
project_dir=$(_project_dir)
if [[ "${input_path}" == /* ]]; then
input_path="${input_path#"${project_dir}"/}"
fi
input_path="${input_path#./}"
echo "${input_path}"
}
# @cmd Get project structure and layout
get_structure() {
local project_dir
@@ -78,6 +93,7 @@ search_content() {
grep -v '/node_modules/' | \
grep -v '/.git/' | \
grep -v '/dist/' | \
sed "s|^${project_dir}/||" | \
head -30) || true
if [[ -n "${results}" ]]; then
@@ -91,8 +107,9 @@ search_content() {
# @option --path! Path to the file (relative to project root)
# @option --lines Maximum lines to read (default: 200)
read_file() {
local file_path
# shellcheck disable=SC2154
local file_path="${argc_path}"
file_path=$(_normalize_path "${argc_path}")
local max_lines="${argc_lines:-200}"
local project_dir
project_dir=$(_project_dir)
@@ -122,7 +139,8 @@ read_file() {
# @cmd Find similar files to a given file (for pattern matching)
# @option --path! Path to the reference file
find_similar() {
local file_path="${argc_path}"
local file_path
file_path=$(_normalize_path "${argc_path}")
local project_dir
project_dir=$(_project_dir)
@@ -138,7 +156,7 @@ find_similar() {
! -name "$(basename "${file_path}")" \
! -name "*test*" \
! -name "*spec*" \
2>/dev/null | head -5)
2>/dev/null | sed "s|^${project_dir}/||" | head -5)
if [[ -n "${results}" ]]; then
echo "${results}" >> "$LLM_OUTPUT"
@@ -147,7 +165,7 @@ find_similar() {
! -name "$(basename "${file_path}")" \
! -name "*test*" \
-not -path '*/target/*' \
2>/dev/null | head -5)
2>/dev/null | sed "s|^${project_dir}/||" | head -5)
if [[ -n "${results}" ]]; then
echo "${results}" >> "$LLM_OUTPUT"
else
+23 -4
View File
@@ -14,21 +14,38 @@ _project_dir() {
(cd "${dir}" 2>/dev/null && pwd) || echo "${dir}"
}
# Normalize a path to be relative to project root.
# Strips the project_dir prefix if the LLM passes an absolute path.
_normalize_path() {
local input_path="$1"
local project_dir
project_dir=$(_project_dir)
if [[ "${input_path}" == /* ]]; then
input_path="${input_path#"${project_dir}"/}"
fi
input_path="${input_path#./}"
echo "${input_path}"
}
# @cmd Read a file for analysis
# @option --path! Path to the file (relative to project root)
read_file() {
local project_dir
project_dir=$(_project_dir)
local file_path
# shellcheck disable=SC2154
local full_path="${project_dir}/${argc_path}"
file_path=$(_normalize_path "${argc_path}")
local full_path="${project_dir}/${file_path}"
if [[ ! -f "${full_path}" ]]; then
error "File not found: ${argc_path}" >> "$LLM_OUTPUT"
error "File not found: ${file_path}" >> "$LLM_OUTPUT"
return 1
fi
{
info "Reading: ${argc_path}"
info "Reading: ${file_path}"
echo ""
cat "${full_path}"
} >> "$LLM_OUTPUT"
@@ -80,6 +97,7 @@ search_code() {
grep -v '/target/' | \
grep -v '/node_modules/' | \
grep -v '/.git/' | \
sed "s|^${project_dir}/||" | \
head -30) || true
if [[ -n "${results}" ]]; then
@@ -113,7 +131,8 @@ analyze_with_command() {
# @cmd List directory contents
# @option --path Path to list (default: project root)
list_directory() {
local dir_path="${argc_path:-.}"
local dir_path
dir_path=$(_normalize_path "${argc_path:-.}")
local project_dir
project_dir=$(_project_dir)
local full_path="${project_dir}/${dir_path}"
+1 -1
View File
@@ -16,7 +16,7 @@
},
"atlassian": {
"command": "npx",
"args": ["-y", "mcp-remote@0.1.13", "https://mcp.atlassian.com/v1/sse"]
"args": ["-y", "mcp-remote@0.1.13", "https://mcp.atlassian.com/v1/mcp"]
},
"docker": {
"command": "uvx",
+2
View File
@@ -192,6 +192,8 @@ clients:
- type: gemini
api_base: https://generativelanguage.googleapis.com/v1beta
api_key: '{{GEMINI_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault
auth: null # When set to 'oauth', Loki will use OAuth instead of an API key
# Authenticate with `loki --authenticate` or `.authenticate` in the REPL
patch:
chat_completions:
'.*':
+24 -2
View File
@@ -137,8 +137,29 @@ loki --authenticate
Alternatively, you can use the REPL command `.authenticate`.
This opens your browser for the OAuth authorization flow. After authorizing, paste the authorization code back into
the terminal. Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes them when they expire.
This opens your browser for the OAuth authorization flow. Depending on the provider, Loki will either start a
temporary localhost server to capture the callback automatically (e.g. Gemini) or ask you to paste the authorization
code back into the terminal (e.g. Claude). Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes
them when they expire.
#### Gemini OAuth Note
Loki uses the following scopes for OAuth with Gemini:
* https://www.googleapis.com/auth/generative-language.peruserquota
* https://www.googleapis.com/auth/userinfo.email
* https://www.googleapis.com/auth/generative-language.retriever (Sensitive)
Since the `generative-language.retriever` scope is a sensitive scope, Google needs to verify Loki, which requires full
branding (logo, official website, privacy policy, terms of service, etc.). The Loki app is open-source and is designed
to be used as a simple CLI. As such, there's no terms of service or privacy policy associated with it, and thus Google
cannot verify Loki.
So, when you kick off OAuth with Gemini, you may see a page similar to the following:
![](../images/clients/gemini-oauth-page.png)
Simply click the `Advanced` link and click `Go to Loki (unsafe)` to continue the OAuth flow.
![](../images/clients/gemini-oauth-unverified.png)
![](../images/clients/gemini-oauth-unverified-allow.png)
**Step 3: Use normally**
@@ -153,6 +174,7 @@ loki -m my-claude-oauth:claude-sonnet-4-20250514 "Hello!"
### Providers That Support OAuth
* Claude
* Gemini
## Extra Settings
Loki also lets you customize some extra settings for interacting with APIs:
+7
View File
@@ -3,6 +3,13 @@
# - https://platform.openai.com/docs/api-reference/chat
- provider: openai
models:
- name: gpt-5.2
max_input_tokens: 400000
max_output_tokens: 128000
input_price: 1.75
output_price: 14
supports_vision: true
supports_function_calling: true
- name: gpt-5.1
max_input_tokens: 400000
max_output_tokens: 128000
+4
View File
@@ -29,6 +29,10 @@ impl OAuthProvider for ClaudeOAuthProvider {
"org:create_api_key user:profile user:inference"
}
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![("code", "true")]
}
fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![("anthropic-beta", BETA_HEADER)]
}
+119 -34
View File
@@ -1,8 +1,11 @@
use super::access_token::get_access_token;
use super::gemini_oauth::GeminiOAuthProvider;
use super::oauth;
use super::vertexai::*;
use super::*;
use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use anyhow::{Context, Result, bail};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{Value, json};
@@ -13,6 +16,7 @@ pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
pub auth: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
@@ -23,25 +27,64 @@ impl GeminiClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
create_client_config!([("api_key", "API Key", None, true)]);
create_oauth_supported_client_config!();
}
impl_client_trait!(
GeminiClient,
(
prepare_chat_completions,
gemini_chat_completions,
gemini_chat_completions_streaming
),
(prepare_embeddings, embeddings),
(noop_prepare_rerank, noop_rerank),
);
#[async_trait::async_trait]
impl Client for GeminiClient {
client_common_fns!();
fn prepare_chat_completions(
fn supports_oauth(&self) -> bool {
self.config.auth.as_deref() == Some("oauth")
}
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
gemini_chat_completions(builder, self.model()).await
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
gemini_chat_completions_streaming(builder, handler, self.model()).await
}
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
let request_data = prepare_embeddings(self, client, data).await?;
let builder = self.request_builder(client, request_data);
embeddings(builder, self.model()).await
}
async fn rerank_inner(
&self,
client: &ReqwestClient,
data: &RerankData,
) -> Result<RerankOutput> {
let request_data = noop_prepare_rerank(self, data)?;
let builder = self.request_builder(client, request_data);
noop_rerank(builder, self.model()).await
}
}
async fn prepare_chat_completions(
self_: &GeminiClient,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
@@ -59,26 +102,61 @@ fn prepare_chat_completions(
);
let body = gemini_build_chat_completions_body(data, &self_.model)?;
let mut request_data = RequestData::new(url, body);
request_data.header("x-goog-api-key", api_key);
let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
if uses_oauth {
let provider = GeminiOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
} else if let Ok(api_key) = self_.get_api_key() {
request_data.header("x-goog-api-key", api_key);
} else {
bail!(
"No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.",
self_.name(),
self_.name()
);
}
Ok(request_data)
}
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
async fn prepare_embeddings(
self_: &GeminiClient,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<RequestData> {
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!(
"{}/models/{}:batchEmbedContents?key={}",
api_base.trim_end_matches('/'),
self_.model.real_name(),
api_key
);
let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
let url = if uses_oauth {
format!(
"{}/models/{}:batchEmbedContents",
api_base.trim_end_matches('/'),
self_.model.real_name(),
)
} else {
let api_key = self_.get_api_key()?;
format!(
"{}/models/{}:batchEmbedContents?key={}",
api_base.trim_end_matches('/'),
self_.model.real_name(),
api_key
)
};
let model_id = format!("models/{}", self_.model.real_name());
@@ -89,21 +167,28 @@ fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<Req
json!({
"model": model_id,
"content": {
"parts": [
{
"text": text
}
]
"parts": [{ "text": text }]
},
})
})
.collect();
let body = json!({
"requests": requests,
});
let body = json!({ "requests": requests });
let mut request_data = RequestData::new(url, body);
let request_data = RequestData::new(url, body);
if uses_oauth {
let provider = GeminiOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
}
Ok(request_data)
}
+49
View File
@@ -0,0 +1,49 @@
use super::oauth::{OAuthProvider, TokenRequestFormat};
pub struct GeminiOAuthProvider;
const GEMINI_CLIENT_ID: &str =
"50826443741-upqcebrs4gctqht1f08ku46qlbirkdsj.apps.googleusercontent.com";
const GEMINI_CLIENT_SECRET: &str = "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";
impl OAuthProvider for GeminiOAuthProvider {
fn provider_name(&self) -> &str {
"gemini"
}
fn client_id(&self) -> &str {
GEMINI_CLIENT_ID
}
fn authorize_url(&self) -> &str {
"https://accounts.google.com/o/oauth2/v2/auth"
}
fn token_url(&self) -> &str {
"https://oauth2.googleapis.com/token"
}
fn redirect_uri(&self) -> &str {
""
}
fn scopes(&self) -> &str {
"https://www.googleapis.com/auth/generative-language.peruserquota https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/userinfo.email"
}
fn client_secret(&self) -> Option<&str> {
Some(GEMINI_CLIENT_SECRET)
}
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![("access_type", "offline"), ("prompt", "consent")]
}
fn token_request_format(&self) -> TokenRequestFormat {
TokenRequestFormat::FormUrlEncoded
}
fn uses_localhost_redirect(&self) -> bool {
true
}
}
+1
View File
@@ -1,6 +1,7 @@
mod access_token;
mod claude_oauth;
mod common;
mod gemini_oauth;
mod message;
pub mod oauth;
#[macro_use]
+198 -48
View File
@@ -1,18 +1,27 @@
use super::ClientConfig;
use super::access_token::{is_valid_access_token, set_access_token};
use crate::config::Config;
use anyhow::{Result, bail};
use anyhow::{Result, anyhow, bail};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::Utc;
use inquire::Text;
use reqwest::Client as ReqwestClient;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use url::Url;
use uuid::Uuid;
pub enum TokenRequestFormat {
Json,
FormUrlEncoded,
}
pub trait OAuthProvider: Send + Sync {
fn provider_name(&self) -> &str;
fn client_id(&self) -> &str;
@@ -21,6 +30,22 @@ pub trait OAuthProvider: Send + Sync {
fn redirect_uri(&self) -> &str;
fn scopes(&self) -> &str;
fn client_secret(&self) -> Option<&str> {
None
}
fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
vec![]
}
fn token_request_format(&self) -> TokenRequestFormat {
TokenRequestFormat::Json
}
fn uses_localhost_redirect(&self) -> bool {
false
}
fn extra_token_headers(&self) -> Vec<(&str, &str)> {
vec![]
}
@@ -37,7 +62,7 @@ pub struct OAuthTokens {
pub expires_at: i64,
}
pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) -> Result<()> {
pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> Result<()> {
let random_bytes: [u8; 32] = rand::random::<[u8; 32]>();
let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes);
@@ -47,11 +72,21 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
let state = Uuid::new_v4().to_string();
let encoded_scopes = urlencoding::encode(provider.scopes());
let encoded_redirect = urlencoding::encode(provider.redirect_uri());
let redirect_uri = if provider.uses_localhost_redirect() {
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
let uri = format!("http://127.0.0.1:{port}/callback");
drop(listener);
uri
} else {
provider.redirect_uri().to_string()
};
let authorize_url = format!(
"{}?code=true&client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
let encoded_scopes = urlencoding::encode(provider.scopes());
let encoded_redirect = urlencoding::encode(&redirect_uri);
let mut authorize_url = format!(
"{}?client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
provider.authorize_url(),
provider.client_id(),
encoded_scopes,
@@ -60,6 +95,14 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
state
);
for (key, value) in provider.extra_authorize_params() {
authorize_url.push_str(&format!(
"&{}={}",
urlencoding::encode(key),
urlencoding::encode(value)
));
}
println!(
"\nOpen this URL to authenticate with {} (client '{}'):\n",
provider.provider_name(),
@@ -69,14 +112,16 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
let _ = open::that(&authorize_url);
let input = Text::new("Paste the authorization code:").prompt()?;
let parts: Vec<&str> = input.splitn(2, '#').collect();
if parts.len() != 2 {
bail!("Invalid authorization code format. Expected format: <code>#<state>");
}
let code = parts[0];
let returned_state = parts[1];
let (code, returned_state) = if provider.uses_localhost_redirect() {
listen_for_oauth_callback(&redirect_uri)?
} else {
let input = Text::new("Paste the authorization code:").prompt()?;
let parts: Vec<&str> = input.splitn(2, '#').collect();
if parts.len() != 2 {
bail!("Invalid authorization code format. Expected format: <code>#<state>");
}
(parts[0].to_string(), parts[1].to_string())
};
if returned_state != state {
bail!(
@@ -86,32 +131,32 @@ pub async fn run_oauth_flow(provider: &impl OAuthProvider, client_name: &str) ->
}
let client = ReqwestClient::new();
let mut request = client.post(provider.token_url()).json(&json!({
"grant_type": "authorization_code",
"client_id": provider.client_id(),
"code": code,
"code_verifier": code_verifier,
"redirect_uri": provider.redirect_uri(),
"state": state,
}));
for (key, value) in provider.extra_token_headers() {
request = request.header(key, value);
}
let request = build_token_request(
&client,
provider,
&[
("grant_type", "authorization_code"),
("client_id", provider.client_id()),
("code", &code),
("code_verifier", &code_verifier),
("redirect_uri", &redirect_uri),
("state", &state),
],
);
let response: Value = request.send().await?.json().await?;
let access_token = response["access_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing access_token in response: {response}"))?
.ok_or_else(|| anyhow!("Missing access_token in response: {response}"))?
.to_string();
let refresh_token = response["refresh_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing refresh_token in response: {response}"))?
.ok_or_else(|| anyhow!("Missing refresh_token in response: {response}"))?
.to_string();
let expires_in = response["expires_in"]
.as_i64()
.ok_or_else(|| anyhow::anyhow!("Missing expires_in in response: {response}"))?;
.ok_or_else(|| anyhow!("Missing expires_in in response: {response}"))?;
let expires_at = Utc::now().timestamp() + expires_in;
@@ -150,33 +195,33 @@ fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
pub async fn refresh_oauth_token(
client: &ReqwestClient,
provider: &dyn OAuthProvider,
provider: &impl OAuthProvider,
client_name: &str,
tokens: &OAuthTokens,
) -> Result<OAuthTokens> {
let mut request = client.post(provider.token_url()).json(&json!({
"grant_type": "refresh_token",
"client_id": provider.client_id(),
"refresh_token": tokens.refresh_token,
}));
for (key, value) in provider.extra_token_headers() {
request = request.header(key, value);
}
let request = build_token_request(
client,
provider,
&[
("grant_type", "refresh_token"),
("client_id", provider.client_id()),
("refresh_token", &tokens.refresh_token),
],
);
let response: Value = request.send().await?.json().await?;
let access_token = response["access_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing access_token in refresh response: {response}"))?
.ok_or_else(|| anyhow!("Missing access_token in refresh response: {response}"))?
.to_string();
let refresh_token = response["refresh_token"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing refresh_token in refresh response: {response}"))?
.to_string();
.map(|s| s.to_string())
.unwrap_or_else(|| tokens.refresh_token.clone());
let expires_in = response["expires_in"]
.as_i64()
.ok_or_else(|| anyhow::anyhow!("Missing expires_in in refresh response: {response}"))?;
.ok_or_else(|| anyhow!("Missing expires_in in refresh response: {response}"))?;
let expires_at = Utc::now().timestamp() + expires_in;
@@ -216,9 +261,110 @@ pub async fn prepare_oauth_access_token(
Ok(true)
}
pub fn get_oauth_provider(provider_type: &str) -> Option<impl OAuthProvider> {
fn build_token_request(
client: &ReqwestClient,
provider: &(impl OAuthProvider + ?Sized),
params: &[(&str, &str)],
) -> RequestBuilder {
let mut request = match provider.token_request_format() {
TokenRequestFormat::Json => {
let body: serde_json::Map<String, Value> = params
.iter()
.map(|(k, v)| (k.to_string(), Value::String(v.to_string())))
.collect();
if let Some(secret) = provider.client_secret() {
let mut body = body;
body.insert(
"client_secret".to_string(),
Value::String(secret.to_string()),
);
client.post(provider.token_url()).json(&body)
} else {
client.post(provider.token_url()).json(&body)
}
}
TokenRequestFormat::FormUrlEncoded => {
let mut form: HashMap<String, String> = params
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
if let Some(secret) = provider.client_secret() {
form.insert("client_secret".to_string(), secret.to_string());
}
client.post(provider.token_url()).form(&form)
}
};
for (key, value) in provider.extra_token_headers() {
request = request.header(key, value);
}
request
}
fn listen_for_oauth_callback(redirect_uri: &str) -> Result<(String, String)> {
let url: Url = redirect_uri.parse()?;
let host = url.host_str().unwrap_or("127.0.0.1");
let port = url
.port()
.ok_or_else(|| anyhow!("No port in redirect URI"))?;
let path = url.path();
println!("Waiting for OAuth callback on {redirect_uri} ...\n");
let listener = TcpListener::bind(format!("{host}:{port}"))?;
let (mut stream, _) = listener.accept()?;
let mut reader = BufReader::new(&stream);
let mut request_line = String::new();
reader.read_line(&mut request_line)?;
let request_path = request_line
.split_whitespace()
.nth(1)
.ok_or_else(|| anyhow!("Malformed HTTP request from OAuth callback"))?;
let full_url = format!("http://{host}:{port}{request_path}");
let parsed: Url = full_url.parse()?;
if !parsed.path().starts_with(path) {
bail!("Unexpected callback path: {}", parsed.path());
}
let code = parsed
.query_pairs()
.find(|(k, _)| k == "code")
.map(|(_, v)| v.to_string())
.ok_or_else(|| {
let error = parsed
.query_pairs()
.find(|(k, _)| k == "error")
.map(|(_, v)| v.to_string())
.unwrap_or_else(|| "unknown".to_string());
anyhow!("OAuth callback returned error: {error}")
})?;
let returned_state = parsed
.query_pairs()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow!("Missing state parameter in OAuth callback"))?;
let response_body = "<html><body><h2>Authentication successful!</h2><p>You can close this tab and return to your terminal.</p></body></html>";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
response_body.len(),
response_body
);
stream.write_all(response.as_bytes())?;
Ok((code, returned_state))
}
pub fn get_oauth_provider(provider_type: &str) -> Option<Box<dyn OAuthProvider>> {
match provider_type {
"claude" => Some(super::claude_oauth::ClaudeOAuthProvider),
"claude" => Some(Box::new(super::claude_oauth::ClaudeOAuthProvider)),
"gemini" => Some(Box::new(super::gemini_oauth::GeminiOAuthProvider)),
_ => None,
}
}
@@ -263,7 +409,11 @@ fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Opti
"openai-compatible",
None,
),
ClientConfig::GeminiConfig(c) => (c.name.as_deref().unwrap_or("gemini"), "gemini", None),
ClientConfig::GeminiConfig(c) => (
c.name.as_deref().unwrap_or("gemini"),
"gemini",
c.auth.as_deref(),
),
ClientConfig::CohereConfig(c) => (c.name.as_deref().unwrap_or("cohere"), "cohere", None),
ClientConfig::AzureOpenAIConfig(c) => (
c.name.as_deref().unwrap_or("azure-openai"),
+2 -2
View File
@@ -86,7 +86,7 @@ async fn main() -> Result<()> {
if let Some(client_arg) = &cli.authenticate {
let config = Config::init_bare()?;
let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?;
oauth::run_oauth_flow(&provider, &client_name).await?;
oauth::run_oauth_flow(&*provider, &client_name).await?;
return Ok(());
}
@@ -517,7 +517,7 @@ fn init_console_logger(
fn resolve_oauth_client(
explicit: Option<&str>,
clients: &[ClientConfig],
) -> Result<(String, impl OAuthProvider)> {
) -> Result<(String, Box<dyn OAuthProvider>)> {
if let Some(name) = explicit {
let provider_type = oauth::resolve_provider_type(name, clients)
.ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?;
+1 -1
View File
@@ -438,7 +438,7 @@ pub async fn run_repl_command(
}
let clients = config.read().clients.clone();
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
oauth::run_oauth_flow(&provider, &client_name).await?;
oauth::run_oauth_flow(&*provider, &client_name).await?;
}
".prompt" => match args {
Some(text) => {