feat: Allow first-runs to select OAuth for supported providers

This commit is contained in:
2026-03-11 12:01:17 -06:00
parent 3fa0eb832c
commit 03b9cc70b9
10 changed files with 82 additions and 34 deletions
+24 -16
View File
@@ -3,11 +3,11 @@ use super::claude::*;
use super::openai::*;
use super::*;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{Context, Result, anyhow, bail};
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::{path::PathBuf, str::FromStr};
#[derive(Debug, Clone, Deserialize, Default)]
@@ -26,10 +26,10 @@ impl VertexAIClient {
config_get_fn!(project_id, get_project_id);
config_get_fn!(location, get_location);
pub const PROMPTS: [PromptAction<'static>; 2] = [
create_client_config!([
("project_id", "Project ID", None, false),
("location", "Location", None, false),
];
]);
}
#[async_trait::async_trait]
@@ -99,9 +99,13 @@ fn prepare_chat_completions(
let access_token = get_access_token(self_.name())?;
let base_url = if location == "global" {
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
format!(
"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
)
} else {
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
format!(
"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
)
};
let model_name = self_.model.real_name();
@@ -158,9 +162,13 @@ fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<R
let access_token = get_access_token(self_.name())?;
let base_url = if location == "global" {
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
format!(
"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
)
} else {
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
format!(
"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
)
};
let url = format!(
"{base_url}/google/models/{}:predict",
@@ -220,12 +228,12 @@ pub async fn gemini_chat_completions_streaming(
part["functionCall"]["args"].as_object(),
) {
let thought_signature = part["thoughtSignature"]
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
handler.tool_call(
ToolCall::new(name.to_string(), json!(args), None)
.with_thought_signature(thought_signature),
.with_thought_signature(thought_signature),
)?;
}
}
@@ -288,12 +296,12 @@ fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsO
part["functionCall"]["args"].as_object(),
) {
let thought_signature = part["thoughtSignature"]
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
.as_str()
.or_else(|| part["thought_signature"].as_str())
.map(|s| s.to_string());
tool_calls.push(
ToolCall::new(name.to_string(), json!(args), None)
.with_thought_signature(thought_signature),
.with_thought_signature(thought_signature),
);
}
}