12 Commits

Author SHA1 Message Date
83581d9d18 feat: Initial scaffolding work for Gemini OAuth support (Claude generated, Alex updated) 2026-03-11 14:18:55 -06:00
eb4d1c02f4 feat: Support authenticating or refreshing OAuth for supported clients from within the REPL
CI / All (macos-latest) (push) Has been cancelled
CI / All (ubuntu-latest) (push) Has been cancelled
CI / All (windows-latest) (push) Has been cancelled
2026-03-11 13:07:27 -06:00
c428990900 fix: the updated regex for secrets injection broke MCP server secrets interpolation because the regex greedily matched on new lines, replacing too much content. This fix just ignores commented out lines in YAML files by skipping commented out lines. 2026-03-11 12:55:28 -06:00
03b9cc70b9 feat: Allow first-runs to select OAuth for supported providers 2026-03-11 12:01:17 -06:00
3fa0eb832c fix: Don't try to inject secrets into commented-out lines in the config 2026-03-11 11:11:09 -06:00
83f66e1061 feat: Support OAuth authentication flows for Claude 2026-03-11 11:10:48 -06:00
741b9c364c chore: Added support for Claude 4.6 gen models
CI / All (macos-latest) (push) Has been cancelled
CI / All (ubuntu-latest) (push) Has been cancelled
CI / All (windows-latest) (push) Has been cancelled
2026-03-10 14:55:30 -06:00
b6f6f456db fix: Removed top_p parameter from some agents so they can work across model providers
CI / All (macos-latest) (push) Has been cancelled
CI / All (ubuntu-latest) (push) Has been cancelled
CI / All (windows-latest) (push) Has been cancelled
2026-03-10 10:18:38 -06:00
00a6cf74d7 Merge branch 'main' of github.com:Dark-Alex-17/loki
CI / All (macos-latest) (push) Has been cancelled
CI / All (ubuntu-latest) (push) Has been cancelled
CI / All (windows-latest) (push) Has been cancelled
2026-03-09 14:58:23 -06:00
d35ca352ca chore: Added the new gemini-3.1-pro-preview model to gemini and vertex models
CI / All (macos-latest) (push) Has been cancelled
CI / All (ubuntu-latest) (push) Has been cancelled
CI / All (windows-latest) (push) Has been cancelled
2026-03-09 14:57:39 -06:00
57dc1cb252 docs: created an authorship policy and PR template that requires disclosure of AI assistance in contributions 2026-02-24 17:46:07 -07:00
101a9cdd6e style: Applied formatting to MCP module
CI / All (macos-latest) (push) Has been cancelled
CI / All (ubuntu-latest) (push) Has been cancelled
CI / All (windows-latest) (push) Has been cancelled
2026-02-20 15:28:21 -07:00
36 changed files with 1146 additions and 144 deletions
@@ -0,0 +1,11 @@
### AI assistance (if any):
- List tools here and files touched by them
### Authorship & Understanding
- [ ] I wrote or heavily modified this code myself
- [ ] I understand how it works end-to-end
- [ ] I can maintain this code in the future
- [ ] No undisclosed AI-generated code was used
- [ ] If AI assistance was used, it is documented above
+7
View File
@@ -76,6 +76,13 @@ Then, you can run workflows locally without having to commit and see if the GitH
act -W .github/workflows/release.yml --input_type bump=minor
```
## Authorship Policy
All code in this repository is written and reviewed by humans. AI-generated code (e.g., from Copilot, ChatGPT,
or Claude) is not permitted unless explicitly disclosed and approved.
Submissions must certify that the contributor understands and can maintain the code they submit.
## Questions? Reach out to me!
If you encounter any questions while developing Loki, please don't hesitate to reach out to me at
alex.j.tusa@gmail.com. I'm happy to help contributors in any way I can, regardless of whether they're new or experienced!
Generated
+38
View File
@@ -2869,6 +2869,15 @@ dependencies = [
"serde",
]
[[package]]
name = "is-docker"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3"
dependencies = [
"once_cell",
]
[[package]]
name = "is-macro"
version = "0.3.7"
@@ -2892,6 +2901,16 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "is-wsl"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5"
dependencies = [
"is-docker",
"once_cell",
]
[[package]]
name = "is_executable"
version = "1.0.5"
@@ -3193,6 +3212,7 @@ dependencies = [
"log4rs",
"nu-ansi-term",
"num_cpus",
"open",
"os_info",
"parking_lot",
"path-absolutize",
@@ -3223,6 +3243,7 @@ dependencies = [
"tokio-stream",
"unicode-segmentation",
"unicode-width 0.2.2",
"url",
"urlencoding",
"uuid",
"which",
@@ -3869,6 +3890,17 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "open"
version = "5.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43bb73a7fa3799b198970490a51174027ba0d4ec504b03cd08caf513d40024bc"
dependencies = [
"is-wsl",
"libc",
"pathdiff",
]
[[package]]
name = "openssl"
version = "0.10.75"
@@ -4039,6 +4071,12 @@ dependencies = [
"once_cell",
]
[[package]]
name = "pathdiff"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "pem"
version = "3.0.6"
+3 -1
View File
@@ -96,6 +96,9 @@ colored = "3.0.0"
clap_complete = { version = "4.5.58", features = ["unstable-dynamic"] }
gman = "0.3.0"
clap_complete_nushell = "4.5.9"
open = "5"
rand = "0.9.0"
url = "2.5.8"
[dependencies.reqwest]
version = "0.12.0"
@@ -126,7 +129,6 @@ arboard = { version = "3.3.0", default-features = false }
[dev-dependencies]
pretty_assertions = "1.4.0"
rand = "0.9.0"
[[bin]]
name = "loki"
+21
View File
@@ -39,6 +39,7 @@ Coming from [AIChat](https://github.com/sigoden/aichat)? Follow the [migration g
* [Todo System](./docs/TODO-SYSTEM.md): Built-in task tracking for improved agent reliability with smaller models.
* [Environment Variables](./docs/ENVIRONMENT-VARIABLES.md): Override and customize your Loki configuration at runtime with environment variables.
* [Client Configurations](./docs/clients/CLIENTS.md): Configuration instructions for various LLM providers.
* [Authentication (API Key & OAuth)](./docs/clients/CLIENTS.md#authentication): Authenticate with API keys or OAuth for subscription-based access.
* [Patching API Requests](./docs/clients/PATCHES.md): Learn how to patch API requests for advanced customization.
* [Custom Themes](./docs/THEMES.md): Change the look and feel of Loki to your preferences with custom themes.
* [History](#history): A history of how Loki came to be.
@@ -150,6 +151,26 @@ guide you through the process when you first attempt to access the vault. So, to
loki --list-secrets
```
### Authentication
Each client in your configuration needs authentication (with a few exceptions; e.g. ollama). Most clients use an API key
(set via `api_key` in the config or through the [vault](./docs/VAULT.md)). For providers that support OAuth (e.g. Claude Pro/Max
subscribers, Google Gemini), you can authenticate with your existing subscription instead:
```yaml
# In your config.yaml
clients:
- type: claude
name: my-claude-oauth
auth: oauth # Indicate you want to authenticate with OAuth instead of an API key
```
```sh
loki --authenticate my-claude-oauth
# Or via the REPL: .authenticate
```
For full details, see the [authentication documentation](./docs/clients/CLIENTS.md#authentication).
### Tab-Completions
You can also enable tab completions to make using Loki easier. To do so, add the following to your shell profile:
```shell
-1
View File
@@ -2,7 +2,6 @@ name: code-reviewer
description: CodeRabbit-style code reviewer - spawns per-file reviewers, synthesizes findings
version: 1.0.0
temperature: 0.1
top_p: 0.95
auto_continue: true
max_auto_continues: 20
-1
View File
@@ -2,7 +2,6 @@ name: coder
description: Implementation agent - writes code, follows patterns, verifies with builds
version: 1.0.0
temperature: 0.1
top_p: 0.95
auto_continue: true
max_auto_continues: 15
-1
View File
@@ -2,7 +2,6 @@ name: explore
description: Fast codebase exploration agent - finds patterns, structures, and relevant files
version: 1.0.0
temperature: 0.1
top_p: 0.95
variables:
- name: project_dir
-1
View File
@@ -2,7 +2,6 @@ name: file-reviewer
description: Reviews a single file's diff for bugs, style issues, and cross-cutting concerns
version: 1.0.0
temperature: 0.1
top_p: 0.95
variables:
- name: project_dir
-1
View File
@@ -2,7 +2,6 @@ name: oracle
description: High-IQ advisor for architecture, debugging, and complex decisions
version: 1.0.0
temperature: 0.2
top_p: 0.95
variables:
- name: project_dir
-1
View File
@@ -2,7 +2,6 @@ name: sisyphus
description: OpenCode-style orchestrator - classifies intent, delegates to specialists, tracks progress with todos
version: 2.0.0
temperature: 0.1
top_p: 0.95
agent_session: temp
auto_continue: true
+4
View File
@@ -192,6 +192,8 @@ clients:
- type: gemini
api_base: https://generativelanguage.googleapis.com/v1beta
api_key: '{{GEMINI_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault
auth: null # When set to 'oauth', Loki will use OAuth instead of an API key
# Authenticate with `loki --authenticate` or `.authenticate` in the REPL
patch:
chat_completions:
'.*':
@@ -210,6 +212,8 @@ clients:
- type: claude
api_base: https://api.anthropic.com/v1 # Optional
api_key: '{{ANTHROPIC_API_KEY}}' # You can either hard-code or inject secrets from the Loki vault
auth: null # When set to 'oauth', Loki will use OAuth instead of an API key
# Authenticate with `loki --authenticate` or `.authenticate` in the REPL
# See https://docs.mistral.ai/
- type: openai-compatible
+6
View File
@@ -23,6 +23,7 @@ You can enter the REPL by simply typing `loki` without any follow-up flags or ar
- [`.edit` - Modify configuration files](#edit---modify-configuration-files)
- [`.delete` - Delete configurations from Loki](#delete---delete-configurations-from-loki)
- [`.info` - Display information about the current mode](#info---display-information-about-the-current-mode)
- [`.authenticate` - Authenticate the current model client via OAuth](#authenticate---authenticate-the-current-model-client-via-oauth)
- [`.exit` - Exit an agent/role/session/rag or the Loki REPL itself](#exit---exit-an-agentrolesessionrag-or-the-loki-repl-itself)
- [`.help` - Show the help guide](#help---show-the-help-guide)
<!--toc:end-->
@@ -237,6 +238,11 @@ The following entities are supported:
| `.info agent` | Display information about the active agent |
| `.info rag` | Display information about the active RAG |
### `.authenticate` - Authenticate the current model client via OAuth
The `.authenticate` command will start the OAuth flow for the current model client if both of the following are true:
* The client supports OAuth (See the [clients documentation](./clients/CLIENTS.md#providers-that-support-oauth) for supported clients)
* The client is configured in your Loki configuration to use OAuth via the `auth: oauth` property
### `.exit` - Exit an agent/role/session/rag or the Loki REPL itself
The `.exit` command is used to move between modes in the Loki REPL.
+80 -6
View File
@@ -14,6 +14,7 @@ loki --info | grep 'config_file' | awk '{print $2}'
<!--toc:start-->
- [Supported Clients](#supported-clients)
- [Client Configuration](#client-configuration)
- [Authentication](#authentication)
- [Extra Settings](#extra-settings)
<!--toc:end-->
@@ -51,12 +52,13 @@ clients:
The client metadata uniquely identifies the client in Loki so you can reference it across your configurations. The
available settings are listed below:
| Setting | Description |
|----------|-----------------------------------------------------------------------------------------------|
| `name` | The name of the client (e.g. `openai`, `gemini`, etc.) |
| `models` | See the [model settings](#model-settings) documentation below |
| `patch` | See the [client patch configuration](./PATCHES.md#client-configuration-patches) documentation |
| `extra` | See the [extra settings](#extra-settings) documentation below |
| Setting | Description |
|----------|------------------------------------------------------------------------------------------------------------|
| `name` | The name of the client (e.g. `openai`, `gemini`, etc.) |
| `auth` | Authentication method: `oauth` for OAuth, or omit to use `api_key` (see [Authentication](#authentication)) |
| `models` | See the [model settings](#model-settings) documentation below |
| `patch` | See the [client patch configuration](./PATCHES.md#client-configuration-patches) documentation |
| `extra` | See the [extra settings](#extra-settings) documentation below |
Be sure to also check provider-specific configurations for any extra fields that are added for authentication purposes.
@@ -83,6 +85,78 @@ The `models` array lists the available models from the model client. Each one ha
| `default_chunk_size` | | `embedding` | The default chunk size to use with the given model |
| `max_batch_size` | | `embedding` | The maximum batch size that the given embedding model supports |
## Authentication
Loki clients support two authentication methods: **API keys** and **OAuth**. Each client entry in your configuration
must use one or the other.
### API Key Authentication
Most clients authenticate using an API key. Simply set the `api_key` field directly or inject it from the
[Loki vault](../VAULT.md):
```yaml
clients:
- type: claude
api_key: '{{ANTHROPIC_API_KEY}}'
```
API keys can also be provided via environment variables named `{CLIENT_NAME}_API_KEY` (e.g. `OPENAI_API_KEY`,
`GEMINI_API_KEY`). See the [environment variables documentation](../ENVIRONMENT-VARIABLES.md#client-related-variables)
for details.
### OAuth Authentication
For [providers that support OAuth](#providers-that-support-oauth), you can authenticate using your existing subscription instead of an API key. This uses
the OAuth 2.0 PKCE flow.
**Step 1: Configure the client**
Add a client entry with `auth: oauth` and no `api_key`:
```yaml
clients:
- type: claude
name: my-claude-oauth
auth: oauth
```
**Step 2: Authenticate**
Run Loki with the `--authenticate` flag and the client name:
```sh
loki --authenticate my-claude-oauth
```
Or if you have only one OAuth-configured client, you can omit the name:
```sh
loki --authenticate
```
Alternatively, you can use the REPL command `.authenticate`.
This opens your browser for the OAuth authorization flow. Depending on the provider, Loki will either start a
temporary localhost server to capture the callback automatically (e.g. Gemini) or ask you to paste the authorization
code back into the terminal (e.g. Claude). Loki stores the tokens in `~/.cache/loki/oauth` and automatically refreshes
them when they expire.
**Step 3: Use normally**
Once authenticated, the client works like any other. Loki uses the stored OAuth tokens automatically:
```sh
loki -m my-claude-oauth:claude-sonnet-4-20250514 "Hello!"
```
> **Note:** You can have multiple clients for the same provider. For example, you can have one with an API key and
> another with OAuth. Use the `name` field to distinguish them.
### Providers That Support OAuth
* Claude
* Gemini
## Extra Settings
Loki also lets you customize some extra settings for interacting with APIs:
+58 -5
View File
@@ -195,6 +195,13 @@
# - https://ai.google.dev/api/rest/v1beta/models/streamGenerateContent
- provider: gemini
models:
- name: gemini-3.1-pro-preview
max_input_tokens: 1048576
max_output_tokens: 65535
input_price: 0.3
output_price: 2.5
supports_vision: true
supports_function_calling: true
- name: gemini-2.5-flash
max_input_tokens: 1048576
max_output_tokens: 65535
@@ -248,6 +255,54 @@
# - https://docs.anthropic.com/en/api/messages
- provider: claude
models:
- name: claude-opus-4-6
max_input_tokens: 200000
max_output_tokens: 8192
require_max_tokens: true
input_price: 5
output_price: 25
supports_vision: true
supports_function_calling: true
- name: claude-opus-4-6:thinking
real_name: claude-opus-4-6
max_input_tokens: 200000
max_output_tokens: 24000
require_max_tokens: true
input_price: 5
output_price: 25
supports_vision: true
supports_function_calling: true
patch:
body:
temperature: null
top_p: null
thinking:
type: enabled
budget_tokens: 16000
- name: claude-sonnet-4-6
max_input_tokens: 200000
max_output_tokens: 8192
require_max_tokens: true
input_price: 3
output_price: 15
supports_vision: true
supports_function_calling: true
- name: claude-sonnet-4-6:thinking
real_name: claude-sonnet-4-6
max_input_tokens: 200000
max_output_tokens: 24000
require_max_tokens: true
input_price: 3
output_price: 15
supports_vision: true
supports_function_calling: true
patch:
body:
temperature: null
top_p: null
thinking:
type: enabled
budget_tokens: 16000
- name: claude-sonnet-4-5-20250929
max_input_tokens: 200000
max_output_tokens: 8192
@@ -670,8 +725,7 @@
# - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
- provider: vertexai
models:
- name: gemini-3-pro-preview
hipaa_safe: true
- name: gemini-3.1-pro-preview
max_input_tokens: 1048576
max_output_tokens: 65536
input_price: 2
@@ -1198,7 +1252,6 @@
max_input_tokens: 1024
input_price: 0.07
# Links:
# - https://help.aliyun.com/zh/model-studio/getting-started/models
# - https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api
@@ -1881,7 +1934,7 @@
input_price: 0.3
output_price: 1.5
supports_function_calling: true
- name: qwen/qwen3-coder # Qwen3 Coder 480B A35B
- name: qwen/qwen3-coder # Qwen3 Coder 480B A35B
max_input_tokens: 262144
input_price: 0.22
output_price: 0.95
@@ -2361,4 +2414,4 @@
- name: rerank-2-lite
type: reranker
max_input_tokens: 8000
input_price: 0.02
input_price: 0.02
+3
View File
@@ -127,6 +127,9 @@ pub struct Cli {
/// List all secrets stored in the Loki vault
#[arg(long, exclusive = true)]
pub list_secrets: bool,
/// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name)
#[arg(long, exclusive = true, value_name = "CLIENT_NAME")]
pub authenticate: Option<Option<String>>,
/// Generate static shell completion scripts
#[arg(long, value_name = "SHELL", value_enum)]
pub completions: Option<ShellCompletion>,
+4 -4
View File
@@ -18,16 +18,16 @@ pub struct AzureOpenAIConfig {
impl AzureOpenAIClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptAction<'static>; 2] = [
create_client_config!([
(
"api_base",
"API Base",
Some("e.g. https://{RESOURCE}.openai.azure.com"),
false
false,
),
("api_key", "API Key", None, true),
];
]);
}
impl_client_trait!(
+2 -2
View File
@@ -32,11 +32,11 @@ impl BedrockClient {
config_get_fn!(region, get_region);
config_get_fn!(session_token, get_session_token);
pub const PROMPTS: [PromptAction<'static>; 3] = [
create_client_config!([
("access_key_id", "AWS Access Key ID", None, true),
("secret_access_key", "AWS Secret Access Key", None, true),
("region", "AWS Region", None, false),
];
]);
fn chat_completions_builder(
&self,
+64 -15
View File
@@ -1,9 +1,12 @@
use super::access_token::get_access_token;
use super::claude_oauth::ClaudeOAuthProvider;
use super::oauth::{self, OAuthProvider};
use super::*;
use crate::utils::strip_think_tag;
use anyhow::{Context, Result, bail};
use reqwest::RequestBuilder;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{Value, json};
@@ -14,6 +17,7 @@ pub struct ClaudeConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
pub auth: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
@@ -24,25 +28,44 @@ impl ClaudeClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
create_oauth_supported_client_config!();
}
impl_client_trait!(
ClaudeClient,
(
prepare_chat_completions,
claude_chat_completions,
claude_chat_completions_streaming
),
(noop_prepare_embeddings, noop_embeddings),
(noop_prepare_rerank, noop_rerank),
);
#[async_trait::async_trait]
impl Client for ClaudeClient {
client_common_fns!();
fn prepare_chat_completions(
fn supports_oauth(&self) -> bool {
self.config.auth.as_deref() == Some("oauth")
}
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
claude_chat_completions(builder, self.model()).await
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
claude_chat_completions_streaming(builder, handler, self.model()).await
}
}
async fn prepare_chat_completions(
self_: &ClaudeClient,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
@@ -53,7 +76,33 @@ fn prepare_chat_completions(
let mut request_data = RequestData::new(url, body);
request_data.header("anthropic-version", "2023-06-01");
request_data.header("x-api-key", api_key);
let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
if uses_oauth {
let provider = ClaudeOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
for (key, value) in provider.extra_request_headers() {
request_data.header(key, value);
}
} else if let Ok(api_key) = self_.get_api_key() {
request_data.header("x-api-key", api_key);
} else {
bail!(
"No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.",
self_.name(),
self_.name()
);
}
Ok(request_data)
}
+43
View File
@@ -0,0 +1,43 @@
use super::oauth::OAuthProvider;
/// Value sent in the `anthropic-beta` header for both token and API requests.
pub const BETA_HEADER: &str = "oauth-2025-04-20";

/// Zero-sized marker type carrying the OAuth settings for the Claude provider.
pub struct ClaudeOAuthProvider;

impl ClaudeOAuthProvider {
    const NAME: &'static str = "claude";
    const CLIENT_ID: &'static str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
    const AUTHORIZE_URL: &'static str = "https://claude.ai/oauth/authorize";
    const TOKEN_URL: &'static str = "https://console.anthropic.com/v1/oauth/token";
    const REDIRECT_URI: &'static str = "https://console.anthropic.com/oauth/code/callback";
    const SCOPES: &'static str = "org:create_api_key user:profile user:inference";

    /// Header pair attached to token requests and to authenticated API requests.
    const BETA_PAIR: (&'static str, &'static str) = ("anthropic-beta", BETA_HEADER);
}

impl OAuthProvider for ClaudeOAuthProvider {
    fn provider_name(&self) -> &str {
        Self::NAME
    }

    fn client_id(&self) -> &str {
        Self::CLIENT_ID
    }

    fn authorize_url(&self) -> &str {
        Self::AUTHORIZE_URL
    }

    fn token_url(&self) -> &str {
        Self::TOKEN_URL
    }

    fn redirect_uri(&self) -> &str {
        Self::REDIRECT_URI
    }

    fn scopes(&self) -> &str {
        Self::SCOPES
    }

    /// Adds `code=true` to the authorization URL parameters.
    fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
        Vec::from([("code", "true")])
    }

    fn extra_token_headers(&self) -> Vec<(&str, &str)> {
        Vec::from([Self::BETA_PAIR])
    }

    fn extra_request_headers(&self) -> Vec<(&str, &str)> {
        Vec::from([Self::BETA_PAIR])
    }
}
+1 -1
View File
@@ -24,7 +24,7 @@ impl CohereClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
create_client_config!([("api_key", "API Key", None, true)]);
}
impl_client_trait!(
+5 -9
View File
@@ -47,6 +47,10 @@ pub trait Client: Sync + Send {
fn model(&self) -> &Model;
fn supports_oauth(&self) -> bool {
false
}
fn build_client(&self) -> Result<ReqwestClient> {
let mut builder = ReqwestClient::builder();
let extra = self.extra_config();
@@ -489,14 +493,6 @@ pub async fn call_chat_completions_streaming(
}
}
pub fn noop_prepare_embeddings<T>(_client: &T, _data: &EmbeddingsData) -> Result<RequestData> {
bail!("The client doesn't support embeddings api")
}
pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
bail!("The client doesn't support embeddings api")
}
pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
bail!("The client doesn't support rerank api")
}
@@ -554,7 +550,7 @@ pub fn json_str_from_map<'a>(
map.get(field_name).and_then(|v| v.as_str())
}
async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
pub async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
let models: Vec<String> = provider
.models
+120 -35
View File
@@ -1,10 +1,13 @@
use super::access_token::get_access_token;
use super::gemini_oauth::GeminiOAuthProvider;
use super::oauth;
use super::vertexai::*;
use super::*;
use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use anyhow::{Context, Result, bail};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
@@ -13,6 +16,7 @@ pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
pub auth: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
@@ -23,25 +27,64 @@ impl GeminiClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
create_oauth_supported_client_config!();
}
impl_client_trait!(
GeminiClient,
(
prepare_chat_completions,
gemini_chat_completions,
gemini_chat_completions_streaming
),
(prepare_embeddings, embeddings),
(noop_prepare_rerank, noop_rerank),
);
#[async_trait::async_trait]
impl Client for GeminiClient {
client_common_fns!();
fn prepare_chat_completions(
fn supports_oauth(&self) -> bool {
self.config.auth.as_deref() == Some("oauth")
}
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
gemini_chat_completions(builder, self.model()).await
}
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let request_data = prepare_chat_completions(self, client, data).await?;
let builder = self.request_builder(client, request_data);
gemini_chat_completions_streaming(builder, handler, self.model()).await
}
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
let request_data = prepare_embeddings(self, client, data).await?;
let builder = self.request_builder(client, request_data);
embeddings(builder, self.model()).await
}
async fn rerank_inner(
&self,
client: &ReqwestClient,
data: &RerankData,
) -> Result<RerankOutput> {
let request_data = noop_prepare_rerank(self, data)?;
let builder = self.request_builder(client, request_data);
noop_rerank(builder, self.model()).await
}
}
async fn prepare_chat_completions(
self_: &GeminiClient,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
@@ -59,26 +102,61 @@ fn prepare_chat_completions(
);
let body = gemini_build_chat_completions_body(data, &self_.model)?;
let mut request_data = RequestData::new(url, body);
request_data.header("x-goog-api-key", api_key);
let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
if uses_oauth {
let provider = GeminiOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
} else if let Ok(api_key) = self_.get_api_key() {
request_data.header("x-goog-api-key", api_key);
} else {
bail!(
"No authentication configured for '{}'. Set `api_key` or use `auth: oauth` with `loki --authenticate {}`.",
self_.name(),
self_.name()
);
}
Ok(request_data)
}
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
async fn prepare_embeddings(
self_: &GeminiClient,
client: &ReqwestClient,
data: &EmbeddingsData,
) -> Result<RequestData> {
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());
let url = format!(
"{}/models/{}:batchEmbedContents?key={}",
api_base.trim_end_matches('/'),
self_.model.real_name(),
api_key
);
let uses_oauth = self_.config.auth.as_deref() == Some("oauth");
let url = if uses_oauth {
format!(
"{}/models/{}:batchEmbedContents",
api_base.trim_end_matches('/'),
self_.model.real_name(),
)
} else {
let api_key = self_.get_api_key()?;
format!(
"{}/models/{}:batchEmbedContents?key={}",
api_base.trim_end_matches('/'),
self_.model.real_name(),
api_key
)
};
let model_id = format!("models/{}", self_.model.real_name());
@@ -89,21 +167,28 @@ fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<Req
json!({
"model": model_id,
"content": {
"parts": [
{
"text": text
}
]
"parts": [{ "text": text }]
},
})
})
.collect();
let body = json!({
"requests": requests,
});
let body = json!({ "requests": requests });
let mut request_data = RequestData::new(url, body);
let request_data = RequestData::new(url, body);
if uses_oauth {
let provider = GeminiOAuthProvider;
let ready = oauth::prepare_oauth_access_token(client, &provider, self_.name()).await?;
if !ready {
bail!(
"OAuth configured but no tokens found for '{}'. Run: loki --authenticate {}",
self_.name(),
self_.name()
);
}
let token = get_access_token(self_.name())?;
request_data.bearer_auth(token);
}
Ok(request_data)
}
+50
View File
@@ -0,0 +1,50 @@
use super::oauth::{OAuthProvider, TokenRequestFormat};
/// Zero-sized marker type carrying the OAuth settings for the Gemini provider.
pub struct GeminiOAuthProvider;

// TODO: Replace with real credentials after registering Loki with Google Cloud Console
const GEMINI_CLIENT_ID: &str =
    "50826443741-upqcebrs4gctqht1f08ku46qlbirkdsj.apps.googleusercontent.com";
const GEMINI_CLIENT_SECRET: &str = "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";

const GEMINI_AUTHORIZE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
const GEMINI_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// Space-separated scope list requested during authorization.
const GEMINI_SCOPES: &str = "https://www.googleapis.com/auth/cloud-platform.readonly https://www.googleapis.com/auth/userinfo.email";

impl OAuthProvider for GeminiOAuthProvider {
    fn provider_name(&self) -> &str {
        "gemini"
    }

    fn client_id(&self) -> &str {
        GEMINI_CLIENT_ID
    }

    fn client_secret(&self) -> Option<&str> {
        Some(GEMINI_CLIENT_SECRET)
    }

    fn authorize_url(&self) -> &str {
        GEMINI_AUTHORIZE_URL
    }

    fn token_url(&self) -> &str {
        GEMINI_TOKEN_URL
    }

    /// Intentionally empty.
    // NOTE(review): since `uses_localhost_redirect` is true, the real redirect
    // URI is presumably built at runtime from the local listener — confirm
    // against the `oauth` module.
    fn redirect_uri(&self) -> &str {
        ""
    }

    fn scopes(&self) -> &str {
        GEMINI_SCOPES
    }

    /// Adds `access_type=offline` and `prompt=consent` to the authorization URL.
    fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
        [("access_type", "offline"), ("prompt", "consent")].to_vec()
    }

    fn token_request_format(&self) -> TokenRequestFormat {
        TokenRequestFormat::FormUrlEncoded
    }

    fn uses_localhost_redirect(&self) -> bool {
        true
    }
}
+39 -1
View File
@@ -90,7 +90,7 @@ macro_rules! register_client {
pub async fn create_client_config(client: &str, vault: &$crate::vault::Vault) -> anyhow::Result<(String, serde_json::Value)> {
$(
if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
return create_config(&$client::PROMPTS, $client::NAME, vault).await
return $client::create_client_config(vault).await
}
)+
if let Some(ret) = create_openai_compatible_client_config(client).await? {
@@ -218,6 +218,44 @@ macro_rules! impl_client_trait {
};
}
/// Generates an associated `create_client_config` function for a client type.
///
/// The macro argument is the expression holding the client's prompt list; the
/// generated function passes a reference to it, together with `Self::NAME` and
/// the vault, to `$crate::client::create_config`.
#[macro_export]
macro_rules! create_client_config {
    ($prompt_actions:expr) => {
        /// Builds this client's configuration by delegating to
        /// `create_config` with the client's prompt list.
        pub async fn create_client_config(
            vault: &$crate::vault::Vault,
        ) -> anyhow::Result<(String, serde_json::Value)> {
            let client_name = Self::NAME;
            $crate::client::create_config(&$prompt_actions, client_name, vault).await
        }
    };
}
/// Generates an associated `create_client_config` function for a client type
/// that can authenticate with either an API key or OAuth.
///
/// Every path used in the expansion is written out in full (`serde_json::…`,
/// `inquire::…`, `$crate::…`) so the macro does not depend on which names the
/// invoking module happens to have imported.
#[macro_export]
macro_rules! create_oauth_supported_client_config {
    () => {
        /// Interactively builds this client's configuration.
        ///
        /// Prompts for the authentication method, then either registers an
        /// API-key secret in the vault and references it from the config, or
        /// sets `auth` to `"oauth"`. Finally lets
        /// `set_client_models_config` fill in the model settings.
        ///
        /// Returns the model string produced by `set_client_models_config`
        /// and a JSON array containing the single client config.
        ///
        /// # Errors
        /// Propagates failures from the prompt, from `vault.add_secret`, and
        /// from `set_client_models_config`.
        pub async fn create_client_config(
            vault: &$crate::vault::Vault,
        ) -> anyhow::Result<(String, serde_json::Value)> {
            // Single source of truth for the menu labels so the comparison
            // below cannot drift from the options shown to the user.
            const API_KEY_OPTION: &str = "API Key";
            const OAUTH_OPTION: &str = "OAuth";

            let mut config = serde_json::json!({ "type": Self::NAME });

            let auth_method = inquire::Select::new(
                "Authentication method:",
                vec![API_KEY_OPTION, OAUTH_OPTION],
            )
            .prompt()?;

            if auth_method == API_KEY_OPTION {
                let env_name = format!("{}_API_KEY", Self::NAME).to_ascii_uppercase();
                vault.add_secret(&env_name)?;
                // Escaped braces: this yields the literal `{{ENV_NAME}}` placeholder.
                config["api_key"] = format!("{{{{{env_name}}}}}").into();
            } else {
                config["auth"] = "oauth".into();
            }

            let model = $crate::client::set_client_models_config(&mut config, Self::NAME).await?;
            // Build the array directly instead of re-serializing through `json!`.
            let clients = serde_json::Value::Array(vec![config]);
            Ok((model, clients))
        }
    };
}
#[macro_export]
macro_rules! config_get_fn {
($field_name:ident, $fn_name:ident) => {
+3
View File
@@ -1,6 +1,9 @@
mod access_token;
mod claude_oauth;
mod common;
mod gemini_oauth;
mod message;
pub mod oauth;
#[macro_use]
mod macros;
mod model;
+430
View File
@@ -0,0 +1,430 @@
use super::ClientConfig;
use super::access_token::{is_valid_access_token, set_access_token};
use crate::config::Config;
use anyhow::{Result, bail};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use chrono::Utc;
use inquire::Text;
use reqwest::Client as ReqwestClient;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use url::Url;
use uuid::Uuid;
/// Body encoding used when calling a provider's token endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenRequestFormat {
    /// Send the token parameters as a JSON object.
    Json,
    /// Send the token parameters as `application/x-www-form-urlencoded`.
    FormUrlEncoded,
}
/// Static description of an OAuth provider: its endpoints, client identity
/// and the quirks of its authorization and token requests.
pub trait OAuthProvider: Send + Sync {
    // ----- Required ---------------------------------------------------------

    /// Human-readable provider name, shown in the messages printed to the user.
    fn provider_name(&self) -> &str;

    /// OAuth client identifier sent in the authorization URL and token requests.
    fn client_id(&self) -> &str;

    /// Authorization endpoint the user is sent to.
    fn authorize_url(&self) -> &str;

    /// Token endpoint used for code exchange and refresh.
    fn token_url(&self) -> &str;

    /// Redirect URI used when [`uses_localhost_redirect`](Self::uses_localhost_redirect)
    /// is `false`.
    fn redirect_uri(&self) -> &str;

    /// Scope string placed (URL-encoded) in the authorization URL.
    fn scopes(&self) -> &str;

    // ----- Optional, with defaults ------------------------------------------

    /// Whether the authorization code is delivered to a temporary local
    /// listener rather than pasted in by the user. Defaults to `false`.
    fn uses_localhost_redirect(&self) -> bool {
        false
    }

    /// Client secret added to token requests, if any. Defaults to `None`.
    fn client_secret(&self) -> Option<&str> {
        None
    }

    /// Encoding of the token request body. Defaults to JSON.
    fn token_request_format(&self) -> TokenRequestFormat {
        TokenRequestFormat::Json
    }

    /// Additional query parameters appended to the authorization URL.
    fn extra_authorize_params(&self) -> Vec<(&str, &str)> {
        vec![]
    }

    /// Additional HTTP headers set on token requests.
    fn extra_token_headers(&self) -> Vec<(&str, &str)> {
        vec![]
    }

    /// Additional HTTP headers a provider can declare; defaults to none.
    // NOTE(review): presumably applied to authenticated API requests by the
    // individual clients — confirm against the callers.
    fn extra_request_headers(&self) -> Vec<(&str, &str)> {
        vec![]
    }
}
/// Token set obtained from an OAuth provider; serialized to and from JSON for
/// the per-client token file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthTokens {
/// Access token taken from the provider's token response.
pub access_token: String,
/// Refresh token used to obtain a new access token.
pub refresh_token: String,
/// Expiry as a Unix timestamp: the time of issue plus the response's `expires_in`.
pub expires_at: i64,
}
/// Runs the interactive OAuth authorization-code flow (with PKCE) for
/// `client_name` and stores the resulting tokens.
///
/// Steps:
/// 1. Generate a PKCE verifier/challenge pair and a random `state`.
/// 2. Print the authorization URL and try to open it in the browser.
/// 3. Obtain the authorization code, either from a local callback listener
///    (when the provider uses a localhost redirect) or from a
///    `<code>#<state>` string pasted by the user.
/// 4. Check the returned `state`, exchange the code at the token endpoint and
///    save the tokens.
///
/// # Errors
///
/// Fails if the local listener cannot be set up, the user input is malformed,
/// the returned state does not match, the token request fails, or the token
/// response lacks `access_token`, `refresh_token` or `expires_in`.
pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> Result<()> {
    // PKCE: the verifier stays local; only its SHA-256 digest is sent up front.
    let random_bytes: [u8; 32] = rand::random::<[u8; 32]>();
    let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes);

    let mut hasher = Sha256::new();
    hasher.update(code_verifier.as_bytes());
    let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());

    let state = Uuid::new_v4().to_string();

    let redirect_uri = if provider.uses_localhost_redirect() {
        // Bind to port 0 so the OS picks a free port, then release it again:
        // listen_for_oauth_callback re-binds the same port once the URL has
        // been shown to the user.
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let port = listener.local_addr()?.port();
        drop(listener);
        format!("http://127.0.0.1:{port}/callback")
    } else {
        provider.redirect_uri().to_string()
    };

    // `code_challenge` (base64url) and `state` (UUID) consist of URL-safe
    // characters only; everything provider-supplied is percent-encoded.
    let encoded_client_id = urlencoding::encode(provider.client_id());
    let encoded_scopes = urlencoding::encode(provider.scopes());
    let encoded_redirect = urlencoding::encode(&redirect_uri);

    let mut authorize_url = format!(
        "{}?client_id={}&response_type=code&scope={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
        provider.authorize_url(),
        encoded_client_id,
        encoded_scopes,
        encoded_redirect,
        code_challenge,
        state
    );

    for (key, value) in provider.extra_authorize_params() {
        authorize_url.push_str(&format!(
            "&{}={}",
            urlencoding::encode(key),
            urlencoding::encode(value)
        ));
    }

    println!(
        "\nOpen this URL to authenticate with {} (client '{}'):\n",
        provider.provider_name(),
        client_name
    );
    println!("  {authorize_url}\n");

    // Best effort: the URL is printed above if no browser can be opened.
    let _ = open::that(&authorize_url);

    let (code, returned_state) = if provider.uses_localhost_redirect() {
        listen_for_oauth_callback(&redirect_uri)?
    } else {
        let input = Text::new("Paste the authorization code:").prompt()?;
        // Trim first: stray whitespace around a pasted value would otherwise
        // end up in the state and trigger a false mismatch below.
        let Some((code, pasted_state)) = input.trim().split_once('#') else {
            bail!("Invalid authorization code format. Expected format: <code>#<state>");
        };
        (code.to_string(), pasted_state.to_string())
    };

    if returned_state != state {
        bail!(
            "OAuth state mismatch: expected '{state}', got '{returned_state}'. \
             This may indicate a CSRF attack or a stale authorization attempt."
        );
    }

    let client = ReqwestClient::new();
    let request = build_token_request(
        &client,
        provider,
        &[
            ("grant_type", "authorization_code"),
            ("client_id", provider.client_id()),
            ("code", &code),
            ("code_verifier", &code_verifier),
            ("redirect_uri", &redirect_uri),
            ("state", &state),
        ],
    );

    let response: Value = request.send().await?.json().await?;

    let access_token = response["access_token"]
        .as_str()
        .ok_or_else(|| anyhow::anyhow!("Missing access_token in response: {response}"))?
        .to_string();
    let refresh_token = response["refresh_token"]
        .as_str()
        .ok_or_else(|| anyhow::anyhow!("Missing refresh_token in response: {response}"))?
        .to_string();
    let expires_in = response["expires_in"]
        .as_i64()
        .ok_or_else(|| anyhow::anyhow!("Missing expires_in in response: {response}"))?;

    let expires_at = Utc::now().timestamp() + expires_in;

    let tokens = OAuthTokens {
        access_token,
        refresh_token,
        expires_at,
    };
    save_oauth_tokens(client_name, &tokens)?;

    println!(
        "Successfully authenticated client '{}' with {} via OAuth. Tokens saved.",
        client_name,
        provider.provider_name()
    );

    Ok(())
}
/// Reads the stored OAuth tokens for `client_name`.
///
/// Returns `None` when the token file cannot be read or does not contain
/// valid token JSON.
pub fn load_oauth_tokens(client_name: &str) -> Option<OAuthTokens> {
    fs::read_to_string(Config::token_file(client_name))
        .ok()
        .and_then(|raw| serde_json::from_str::<OAuthTokens>(&raw).ok())
}
fn save_oauth_tokens(client_name: &str, tokens: &OAuthTokens) -> Result<()> {
let path = Config::token_file(client_name);
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let json = serde_json::to_string_pretty(tokens)?;
fs::write(path, json)?;
Ok(())
}
/// Exchanges the stored refresh token for a new access token, saves the
/// updated token set and returns it.
///
/// If the response carries no new refresh token, the previous one is kept.
///
/// # Errors
///
/// Fails if the request fails, the response lacks `access_token` or
/// `expires_in`, or the tokens cannot be saved.
pub async fn refresh_oauth_token(
    client: &ReqwestClient,
    provider: &impl OAuthProvider,
    client_name: &str,
    tokens: &OAuthTokens,
) -> Result<OAuthTokens> {
    let params = [
        ("grant_type", "refresh_token"),
        ("client_id", provider.client_id()),
        ("refresh_token", tokens.refresh_token.as_str()),
    ];

    let response: Value = build_token_request(client, provider, &params)
        .send()
        .await?
        .json()
        .await?;

    let Some(access_token) = response["access_token"].as_str() else {
        return Err(anyhow::anyhow!(
            "Missing access_token in refresh response: {response}"
        ));
    };

    // Providers may omit the refresh token on refresh; keep the old one then.
    let refresh_token = response["refresh_token"]
        .as_str()
        .map_or_else(|| tokens.refresh_token.clone(), str::to_string);

    let Some(expires_in) = response["expires_in"].as_i64() else {
        return Err(anyhow::anyhow!(
            "Missing expires_in in refresh response: {response}"
        ));
    };

    let refreshed = OAuthTokens {
        access_token: access_token.to_string(),
        refresh_token,
        expires_at: Utc::now().timestamp() + expires_in,
    };
    save_oauth_tokens(client_name, &refreshed)?;

    Ok(refreshed)
}
pub async fn prepare_oauth_access_token(
client: &ReqwestClient,
provider: &impl OAuthProvider,
client_name: &str,
) -> Result<bool> {
if is_valid_access_token(client_name) {
return Ok(true);
}
let tokens = match load_oauth_tokens(client_name) {
Some(t) => t,
None => return Ok(false),
};
let tokens = if Utc::now().timestamp() >= tokens.expires_at {
refresh_oauth_token(client, provider, client_name, &tokens).await?
} else {
tokens
};
set_access_token(client_name, tokens.access_token.clone(), tokens.expires_at);
Ok(true)
}
/// Builds a POST request to the provider's token endpoint.
///
/// `params` become the request body, encoded as JSON or as a URL-encoded form
/// depending on the provider's token request format. The provider's client
/// secret (if any) is added as `client_secret`, and its extra token headers
/// are set on the request.
fn build_token_request(
    client: &ReqwestClient,
    provider: &(impl OAuthProvider + ?Sized),
    params: &[(&str, &str)],
) -> reqwest::RequestBuilder {
    let endpoint = client.post(provider.token_url());
    let secret = provider.client_secret();

    let builder = match provider.token_request_format() {
        TokenRequestFormat::Json => {
            let mut payload: serde_json::Map<String, Value> = params
                .iter()
                .map(|(key, value)| (key.to_string(), Value::String(value.to_string())))
                .collect();
            if let Some(secret) = secret {
                payload.insert(
                    "client_secret".to_string(),
                    Value::String(secret.to_string()),
                );
            }
            endpoint.json(&payload)
        }
        TokenRequestFormat::FormUrlEncoded => {
            let mut fields: HashMap<String, String> = params
                .iter()
                .map(|(key, value)| (key.to_string(), value.to_string()))
                .collect();
            if let Some(secret) = secret {
                fields.insert("client_secret".to_string(), secret.to_string());
            }
            endpoint.form(&fields)
        }
    };

    provider
        .extra_token_headers()
        .into_iter()
        .fold(builder, |request, (name, value)| request.header(name, value))
}
fn listen_for_oauth_callback(redirect_uri: &str) -> Result<(String, String)> {
let url: Url = redirect_uri.parse()?;
let host = url.host_str().unwrap_or("127.0.0.1");
let port = url
.port()
.ok_or_else(|| anyhow::anyhow!("No port in redirect URI"))?;
let path = url.path();
println!("Waiting for OAuth callback on {redirect_uri} ...\n");
let listener = TcpListener::bind(format!("{host}:{port}"))?;
let (mut stream, _) = listener.accept()?;
let mut reader = BufReader::new(&stream);
let mut request_line = String::new();
reader.read_line(&mut request_line)?;
let request_path = request_line
.split_whitespace()
.nth(1)
.ok_or_else(|| anyhow::anyhow!("Malformed HTTP request from OAuth callback"))?;
let full_url = format!("http://{host}:{port}{request_path}");
let parsed: Url = full_url.parse()?;
let response_body = "<html><body><h2>Authentication successful!</h2><p>You can close this tab and return to your terminal.</p></body></html>";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
response_body.len(),
response_body
);
stream.write_all(response.as_bytes())?;
if !parsed.path().starts_with(path) {
bail!("Unexpected callback path: {}", parsed.path());
}
let code = parsed
.query_pairs()
.find(|(k, _)| k == "code")
.map(|(_, v)| v.to_string())
.ok_or_else(|| {
let error = parsed
.query_pairs()
.find(|(k, _)| k == "error")
.map(|(_, v)| v.to_string())
.unwrap_or_else(|| "unknown".to_string());
anyhow::anyhow!("OAuth callback returned error: {error}")
})?;
let returned_state = parsed
.query_pairs()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.to_string())
.ok_or_else(|| anyhow::anyhow!("Missing state parameter in OAuth callback"))?;
Ok((code, returned_state))
}
/// Returns the OAuth provider implementation for `provider_type`, or `None`
/// when that provider type has no OAuth support.
pub fn get_oauth_provider(provider_type: &str) -> Option<Box<dyn OAuthProvider>> {
    let provider: Box<dyn OAuthProvider> = match provider_type {
        "claude" => Box::new(super::claude_oauth::ClaudeOAuthProvider),
        "gemini" => Box::new(super::gemini_oauth::GeminiOAuthProvider),
        _ => return None,
    };
    Some(provider)
}
pub fn resolve_provider_type(client_name: &str, clients: &[ClientConfig]) -> Option<&'static str> {
for client_config in clients {
let (config_name, provider_type, auth) = client_config_info(client_config);
if config_name == client_name {
if auth == Some("oauth") && get_oauth_provider(provider_type).is_some() {
return Some(provider_type);
}
return None;
}
}
None
}
pub fn list_oauth_capable_clients(clients: &[ClientConfig]) -> Vec<String> {
clients
.iter()
.filter_map(|client_config| {
let (name, provider_type, auth) = client_config_info(client_config);
if auth == Some("oauth") && get_oauth_provider(provider_type).is_some() {
Some(name.to_string())
} else {
None
}
})
.collect()
}
/// Extracts `(name, provider type, auth method)` from a client config.
///
/// The name falls back to the provider type when the config does not set one.
/// The auth method is only read for the Claude and Gemini configs; it is
/// `None` for every other variant.
fn client_config_info(client_config: &ClientConfig) -> (&str, &'static str, Option<&str>) {
    let (configured_name, provider_type, auth): (Option<&str>, &'static str, Option<&str>) =
        match client_config {
            ClientConfig::ClaudeConfig(c) => (c.name.as_deref(), "claude", c.auth.as_deref()),
            ClientConfig::OpenAIConfig(c) => (c.name.as_deref(), "openai", None),
            ClientConfig::OpenAICompatibleConfig(c) => {
                (c.name.as_deref(), "openai-compatible", None)
            }
            ClientConfig::GeminiConfig(c) => (c.name.as_deref(), "gemini", c.auth.as_deref()),
            ClientConfig::CohereConfig(c) => (c.name.as_deref(), "cohere", None),
            ClientConfig::AzureOpenAIConfig(c) => (c.name.as_deref(), "azure-openai", None),
            ClientConfig::VertexAIConfig(c) => (c.name.as_deref(), "vertexai", None),
            ClientConfig::BedrockConfig(c) => (c.name.as_deref(), "bedrock", None),
            ClientConfig::Unknown => (None, "unknown", None),
        };

    // The default name is the provider type itself in every variant.
    (configured_name.unwrap_or(provider_type), provider_type, auth)
}
+6 -4
View File
@@ -2,10 +2,10 @@ use super::*;
use crate::utils::strip_think_tag;
use anyhow::{bail, Context, Result};
use anyhow::{Context, Result, bail};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
const API_BASE: &str = "https://api.openai.com/v1";
@@ -25,7 +25,7 @@ impl OpenAIClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
create_client_config!([("api_key", "API Key", None, true)]);
}
impl_client_trait!(
@@ -114,7 +114,9 @@ pub async fn openai_chat_completions_streaming(
function_arguments = String::from("{}");
}
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
format!(
"Tool call '{function_name}' has non-JSON arguments '{function_arguments}'"
)
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
+1 -1
View File
@@ -21,7 +21,7 @@ impl OpenAICompatibleClient {
config_get_fn!(api_base, get_api_base);
config_get_fn!(api_key, get_api_key);
pub const PROMPTS: [PromptAction<'static>; 0] = [];
create_client_config!([]);
}
impl_client_trait!(
+24 -16
View File
@@ -3,11 +3,11 @@ use super::claude::*;
use super::openai::*;
use super::*;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{Context, Result, anyhow, bail};
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::{path::PathBuf, str::FromStr};
#[derive(Debug, Clone, Deserialize, Default)]
@@ -26,10 +26,10 @@ impl VertexAIClient {
config_get_fn!(project_id, get_project_id);
config_get_fn!(location, get_location);
pub const PROMPTS: [PromptAction<'static>; 2] = [
create_client_config!([
("project_id", "Project ID", None, false),
("location", "Location", None, false),
];
]);
}
#[async_trait::async_trait]
@@ -99,9 +99,13 @@ fn prepare_chat_completions(
let access_token = get_access_token(self_.name())?;
let base_url = if location == "global" {
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
format!(
"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
)
} else {
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
format!(
"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
)
};
let model_name = self_.model.real_name();
@@ -158,9 +162,13 @@ fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<R
let access_token = get_access_token(self_.name())?;
let base_url = if location == "global" {
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
format!(
"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
)
} else {
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
format!(
"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
)
};
let url = format!(
"{base_url}/google/models/{}:predict",
@@ -220,12 +228,12 @@ pub async fn gemini_chat_completions_streaming(
part["functionCall"]["args"].as_object(),
) {
let thought_signature = part["thoughtSignature"]
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
handler.tool_call(
ToolCall::new(name.to_string(), json!(args), None)
.with_thought_signature(thought_signature),
.with_thought_signature(thought_signature),
)?;
}
}
@@ -288,12 +296,12 @@ fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsO
part["functionCall"]["args"].as_object(),
) {
let thought_signature = part["thoughtSignature"]
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
tool_calls.push(
ToolCall::new(name.to_string(), json!(args), None)
.with_thought_signature(thought_signature),
.with_thought_signature(thought_signature),
);
}
}
+8
View File
@@ -428,6 +428,14 @@ impl Config {
base_dir.join(env!("CARGO_CRATE_NAME"))
}
/// Directory holding the OAuth token files: `oauth` under the cache path.
pub fn oauth_tokens_path() -> PathBuf {
    let cache_dir = Self::cache_path();
    cache_dir.join("oauth")
}
/// Path of the OAuth token file for `client_name`:
/// `<client_name>_oauth_tokens.json` inside the OAuth tokens directory.
pub fn token_file(client_name: &str) -> PathBuf {
    let file_name = format!("{client_name}_oauth_tokens.json");
    Self::oauth_tokens_path().join(file_name)
}
pub fn log_path() -> PathBuf {
Config::cache_path().join(format!("{}.log", env!("CARGO_CRATE_NAME")))
}
+12 -8
View File
@@ -1121,12 +1121,12 @@ pub fn run_llm_function(
envs.insert("FORCE_COLOR".into(), "1".into());
let mut child = Command::new(&cmd_name)
.args(&cmd_args)
.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|err| anyhow!("Unable to run {command_name}, {err}"))?;
.args(&cmd_args)
.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|err| anyhow!("Unable to run {command_name}, {err}"))?;
let stdout = child.stdout.take().expect("Failed to capture stdout");
let mut stderr = child.stderr.take().expect("Failed to capture stderr");
@@ -1136,7 +1136,9 @@ pub fn run_llm_function(
let mut reader = stdout;
let mut out = io::stdout();
while let Ok(n) = reader.read(&mut buffer) {
if n == 0 { break; }
if n == 0 {
break;
}
let chunk = &buffer[0..n];
let mut last_pos = 0;
for (i, &byte) in chunk.iter().enumerate() {
@@ -1159,7 +1161,9 @@ pub fn run_llm_function(
buf
});
let status = child.wait().map_err(|err| anyhow!("Unable to run {command_name}, {err}"))?;
let status = child
.wait()
.map_err(|err| anyhow!("Unable to run {command_name}, {err}"))?;
let _ = stdout_thread.join();
let stderr_bytes = stderr_thread.join().unwrap_or_default();
+42 -3
View File
@@ -16,7 +16,7 @@ mod vault;
extern crate log;
use crate::client::{
ModelType, call_chat_completions, call_chat_completions_streaming, list_models,
ModelType, call_chat_completions, call_chat_completions_streaming, list_models, oauth,
};
use crate::config::{
Agent, CODE_ROLE, Config, EXPLAIN_SHELL_ROLE, GlobalConfig, Input, SHELL_ROLE,
@@ -29,15 +29,17 @@ use crate::utils::*;
use crate::cli::Cli;
use crate::vault::Vault;
use anyhow::{Result, bail};
use anyhow::{Result, anyhow, bail};
use clap::{CommandFactory, Parser};
use clap_complete::CompleteEnv;
use inquire::Text;
use client::ClientConfig;
use inquire::{Select, Text};
use log::LevelFilter;
use log4rs::append::console::ConsoleAppender;
use log4rs::append::file::FileAppender;
use log4rs::config::{Appender, Logger, Root};
use log4rs::encode::pattern::PatternEncoder;
use oauth::OAuthProvider;
use parking_lot::RwLock;
use std::path::PathBuf;
use std::{env, mem, process, sync::Arc};
@@ -81,6 +83,13 @@ async fn main() -> Result<()> {
let log_path = setup_logger()?;
if let Some(client_arg) = &cli.authenticate {
let config = Config::init_bare()?;
let (client_name, provider) = resolve_oauth_client(client_arg.as_deref(), &config.clients)?;
oauth::run_oauth_flow(&*provider, &client_name).await?;
return Ok(());
}
if vault_flags {
return Vault::handle_vault_flags(cli, Config::init_bare()?);
}
@@ -504,3 +513,33 @@ fn init_console_logger(
.build(Root::builder().appender("console").build(root_log_level))
.unwrap()
}
/// Picks the client to authenticate and returns its name together with its
/// OAuth provider.
///
/// * With `explicit` set, that client is used.
/// * Otherwise the OAuth-capable clients are listed: a single candidate is
///   used directly, and several candidates are offered in a selection prompt.
///
/// # Errors
///
/// Fails if no OAuth-capable client is configured, the prompt is cancelled,
/// or the chosen client is unknown or not configured for OAuth.
fn resolve_oauth_client(
    explicit: Option<&str>,
    clients: &[ClientConfig],
) -> Result<(String, Box<dyn OAuthProvider>)> {
    let name = match explicit {
        Some(name) => name.to_string(),
        None => {
            let mut candidates = oauth::list_oauth_capable_clients(clients);
            match candidates.len() {
                0 => bail!("No OAuth-capable clients configured."),
                1 => candidates.remove(0),
                _ => Select::new("Select a client to authenticate:", candidates).prompt()?,
            }
        }
    };

    // One resolution path for every case, reporting an error rather than
    // panicking if the client cannot be resolved.
    let provider = oauth::resolve_provider_type(&name, clients)
        .and_then(oauth::get_oauth_provider)
        .ok_or_else(|| anyhow!("Client '{name}' not found or doesn't support OAuth"))?;

    Ok((name, provider))
}
+19 -14
View File
@@ -197,7 +197,8 @@ impl McpRegistry {
}
let desired_ids = self.resolve_server_ids(enabled_mcp_servers);
let ids_to_start: Vec<String> = desired_ids.into_iter()
let ids_to_start: Vec<String> = desired_ids
.into_iter()
.filter(|id| !self.servers.contains_key(id))
.collect();
@@ -301,18 +302,20 @@ impl McpRegistry {
fn resolve_server_ids(&self, enabled_mcp_servers: Option<String>) -> Vec<String> {
if let Some(config) = &self.config
&& let Some(servers) = enabled_mcp_servers {
if servers == "all" {
config.mcp_servers.keys().cloned().collect()
} else {
let enabled_servers: HashSet<String> =
servers.split(',').map(|s| s.trim().to_string()).collect();
config.mcp_servers
.keys()
.filter(|id| enabled_servers.contains(*id))
.cloned()
.collect()
}
&& let Some(servers) = enabled_mcp_servers
{
if servers == "all" {
config.mcp_servers.keys().cloned().collect()
} else {
let enabled_servers: HashSet<String> =
servers.split(',').map(|s| s.trim().to_string()).collect();
config
.mcp_servers
.keys()
.filter(|id| enabled_servers.contains(*id))
.cloned()
.collect()
}
} else {
vec![]
}
@@ -330,7 +333,9 @@ impl McpRegistry {
if let Some(server) = self.servers.remove(&id) {
match Arc::try_unwrap(server) {
Ok(server_inner) => {
server_inner.cancel().await
server_inner
.cancel()
.await
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
info!("Stopped MCP server: {id}");
}
+20 -2
View File
@@ -6,7 +6,7 @@ use self::completer::ReplCompleter;
use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt;
use crate::client::{call_chat_completions, call_chat_completions_streaming};
use crate::client::{call_chat_completions, call_chat_completions_streaming, init_client, oauth};
use crate::config::{
AgentVariables, AssertState, Config, GlobalConfig, Input, LastMessage, StateFlags,
macro_execute,
@@ -17,6 +17,7 @@ use crate::utils::{
};
use crate::mcp::McpRegistry;
use crate::resolve_oauth_client;
use anyhow::{Context, Result, bail};
use crossterm::cursor::SetCursorStyle;
use fancy_regex::Regex;
@@ -32,10 +33,15 @@ use std::{env, mem, process};
const MENU_NAME: &str = "completion_menu";
static REPL_COMMANDS: LazyLock<[ReplCommand; 37]> = LazyLock::new(|| {
static REPL_COMMANDS: LazyLock<[ReplCommand; 38]> = LazyLock::new(|| {
[
ReplCommand::new(".help", "Show this help guide", AssertState::pass()),
ReplCommand::new(".info", "Show system info", AssertState::pass()),
ReplCommand::new(
".authenticate",
"Authenticate the current model client via OAuth (if configured)",
AssertState::pass(),
),
ReplCommand::new(
".edit config",
"Modify configuration file",
@@ -421,6 +427,18 @@ pub async fn run_repl_command(
}
None => println!("Usage: .model <name>"),
},
".authenticate" => {
let client = init_client(config, None)?;
if !client.supports_oauth() {
bail!(
"Client '{}' doesn't either support OAuth or isn't configured to use it (i.e. uses an API key instead)",
client.name()
);
}
let clients = config.read().clients.clone();
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
oauth::run_oauth_flow(&*provider, &client_name).await?;
}
".prompt" => match args {
Some(text) => {
config.write().use_prompt(text)?;
+22 -11
View File
@@ -6,7 +6,6 @@ use gman::providers::local::LocalProvider;
use indoc::formatdoc;
use inquire::validator::Validation;
use inquire::{Confirm, Password, PasswordDisplayMode, Text, min_length, required};
use std::borrow::Cow;
use std::path::PathBuf;
pub fn ensure_password_file_initialized(local_provider: &mut LocalProvider) -> Result<()> {
@@ -166,18 +165,30 @@ pub fn create_vault_password_file(vault: &mut Vault) -> Result<()> {
Ok(())
}
pub fn interpolate_secrets<'a>(content: &'a str, vault: &Vault) -> (Cow<'a, str>, Vec<String>) {
pub fn interpolate_secrets(content: &str, vault: &Vault) -> (String, Vec<String>) {
let mut missing_secrets = vec![];
let parsed_content = SECRET_RE.replace_all(content, |caps: &fancy_regex::Captures<'_>| {
let secret = vault.get_secret(caps[1].trim(), false);
match secret {
Ok(s) => s,
Err(_) => {
missing_secrets.push(caps[1].to_string());
"".to_string()
let parsed_content: String = content
.lines()
.map(|line| {
if line.trim_start().starts_with('#') {
return line.to_string();
}
}
});
SECRET_RE
.replace_all(line, |caps: &fancy_regex::Captures<'_>| {
let secret = vault.get_secret(caps[1].trim(), false);
match secret {
Ok(s) => s,
Err(_) => {
missing_secrets.push(caps[1].to_string());
"".to_string()
}
}
})
.to_string()
})
.collect::<Vec<_>>()
.join("\n");
(parsed_content, missing_secrets)
}