feat: Allow first-runs to select OAuth for supported providers
This commit is contained in:
@@ -18,16 +18,16 @@ pub struct AzureOpenAIConfig {
|
|||||||
impl AzureOpenAIClient {
|
impl AzureOpenAIClient {
|
||||||
config_get_fn!(api_base, get_api_base);
|
config_get_fn!(api_base, get_api_base);
|
||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 2] = [
|
create_client_config!([
|
||||||
(
|
(
|
||||||
"api_base",
|
"api_base",
|
||||||
"API Base",
|
"API Base",
|
||||||
Some("e.g. https://{RESOURCE}.openai.azure.com"),
|
Some("e.g. https://{RESOURCE}.openai.azure.com"),
|
||||||
false
|
false,
|
||||||
),
|
),
|
||||||
("api_key", "API Key", None, true),
|
("api_key", "API Key", None, true),
|
||||||
];
|
]);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_client_trait!(
|
impl_client_trait!(
|
||||||
|
|||||||
@@ -32,11 +32,11 @@ impl BedrockClient {
|
|||||||
config_get_fn!(region, get_region);
|
config_get_fn!(region, get_region);
|
||||||
config_get_fn!(session_token, get_session_token);
|
config_get_fn!(session_token, get_session_token);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 3] = [
|
create_client_config!([
|
||||||
("access_key_id", "AWS Access Key ID", None, true),
|
("access_key_id", "AWS Access Key ID", None, true),
|
||||||
("secret_access_key", "AWS Secret Access Key", None, true),
|
("secret_access_key", "AWS Secret Access Key", None, true),
|
||||||
("region", "AWS Region", None, false),
|
("region", "AWS Region", None, false),
|
||||||
];
|
]);
|
||||||
|
|
||||||
fn chat_completions_builder(
|
fn chat_completions_builder(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ pub struct ClaudeConfig {
|
|||||||
impl ClaudeClient {
|
impl ClaudeClient {
|
||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
config_get_fn!(api_base, get_api_base);
|
config_get_fn!(api_base, get_api_base);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
|
create_oauth_supported_client_config!();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ impl CohereClient {
|
|||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
config_get_fn!(api_base, get_api_base);
|
config_get_fn!(api_base, get_api_base);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
|
create_client_config!([("api_key", "API Key", None, true)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_client_trait!(
|
impl_client_trait!(
|
||||||
|
|||||||
@@ -546,7 +546,7 @@ pub fn json_str_from_map<'a>(
|
|||||||
map.get(field_name).and_then(|v| v.as_str())
|
map.get(field_name).and_then(|v| v.as_str())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
|
pub async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
|
||||||
if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
|
if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
|
||||||
let models: Vec<String> = provider
|
let models: Vec<String> = provider
|
||||||
.models
|
.models
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use super::*;
|
|||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use reqwest::RequestBuilder;
|
use reqwest::RequestBuilder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
|
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ impl GeminiClient {
|
|||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
config_get_fn!(api_base, get_api_base);
|
config_get_fn!(api_base, get_api_base);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
|
create_client_config!([("api_key", "API Key", None, true)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_client_trait!(
|
impl_client_trait!(
|
||||||
|
|||||||
+39
-1
@@ -90,7 +90,7 @@ macro_rules! register_client {
|
|||||||
pub async fn create_client_config(client: &str, vault: &$crate::vault::Vault) -> anyhow::Result<(String, serde_json::Value)> {
|
pub async fn create_client_config(client: &str, vault: &$crate::vault::Vault) -> anyhow::Result<(String, serde_json::Value)> {
|
||||||
$(
|
$(
|
||||||
if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
|
if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
|
||||||
return create_config(&$client::PROMPTS, $client::NAME, vault).await
|
return $client::create_client_config(vault).await
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
if let Some(ret) = create_openai_compatible_client_config(client).await? {
|
if let Some(ret) = create_openai_compatible_client_config(client).await? {
|
||||||
@@ -218,6 +218,44 @@ macro_rules! impl_client_trait {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! create_client_config {
|
||||||
|
($prompts:expr) => {
|
||||||
|
pub async fn create_client_config(
|
||||||
|
vault: &$crate::vault::Vault,
|
||||||
|
) -> anyhow::Result<(String, serde_json::Value)> {
|
||||||
|
$crate::client::create_config(&$prompts, Self::NAME, vault).await
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! create_oauth_supported_client_config {
|
||||||
|
() => {
|
||||||
|
pub async fn create_client_config(vault: &$crate::vault::Vault) -> anyhow::Result<(String, serde_json::Value)> {
|
||||||
|
let mut config = serde_json::json!({ "type": Self::NAME });
|
||||||
|
|
||||||
|
let auth_method = inquire::Select::new(
|
||||||
|
"Authentication method:",
|
||||||
|
vec!["API Key", "OAuth"],
|
||||||
|
)
|
||||||
|
.prompt()?;
|
||||||
|
|
||||||
|
if auth_method == "API Key" {
|
||||||
|
let env_name = format!("{}_API_KEY", Self::NAME).to_ascii_uppercase();
|
||||||
|
vault.add_secret(&env_name)?;
|
||||||
|
config["api_key"] = format!("{{{{{env_name}}}}}").into();
|
||||||
|
} else {
|
||||||
|
config["auth"] = "oauth".into();
|
||||||
|
}
|
||||||
|
|
||||||
|
let model = $crate::client::set_client_models_config(&mut config, Self::NAME).await?;
|
||||||
|
let clients = json!(vec![config]);
|
||||||
|
Ok((model, clients))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! config_get_fn {
|
macro_rules! config_get_fn {
|
||||||
($field_name:ident, $fn_name:ident) => {
|
($field_name:ident, $fn_name:ident) => {
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ use super::*;
|
|||||||
|
|
||||||
use crate::utils::strip_think_tag;
|
use crate::utils::strip_think_tag;
|
||||||
|
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{Context, Result, bail};
|
||||||
use reqwest::RequestBuilder;
|
use reqwest::RequestBuilder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
const API_BASE: &str = "https://api.openai.com/v1";
|
const API_BASE: &str = "https://api.openai.com/v1";
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ impl OpenAIClient {
|
|||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
config_get_fn!(api_base, get_api_base);
|
config_get_fn!(api_base, get_api_base);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None, true)];
|
create_client_config!([("api_key", "API Key", None, true)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_client_trait!(
|
impl_client_trait!(
|
||||||
@@ -114,7 +114,9 @@ pub async fn openai_chat_completions_streaming(
|
|||||||
function_arguments = String::from("{}");
|
function_arguments = String::from("{}");
|
||||||
}
|
}
|
||||||
let arguments: Value = function_arguments.parse().with_context(|| {
|
let arguments: Value = function_arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
|
format!(
|
||||||
|
"Tool call '{function_name}' has non-JSON arguments '{function_arguments}'"
|
||||||
|
)
|
||||||
})?;
|
})?;
|
||||||
handler.tool_call(ToolCall::new(
|
handler.tool_call(ToolCall::new(
|
||||||
function_name.clone(),
|
function_name.clone(),
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ impl OpenAICompatibleClient {
|
|||||||
config_get_fn!(api_base, get_api_base);
|
config_get_fn!(api_base, get_api_base);
|
||||||
config_get_fn!(api_key, get_api_key);
|
config_get_fn!(api_key, get_api_key);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 0] = [];
|
create_client_config!([]);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_client_trait!(
|
impl_client_trait!(
|
||||||
|
|||||||
+24
-16
@@ -3,11 +3,11 @@ use super::claude::*;
|
|||||||
use super::openai::*;
|
use super::openai::*;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
use anyhow::{anyhow, bail, Context, Result};
|
use anyhow::{Context, Result, anyhow, bail};
|
||||||
use chrono::{Duration, Utc};
|
use chrono::{Duration, Utc};
|
||||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
use std::{path::PathBuf, str::FromStr};
|
use std::{path::PathBuf, str::FromStr};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Default)]
|
#[derive(Debug, Clone, Deserialize, Default)]
|
||||||
@@ -26,10 +26,10 @@ impl VertexAIClient {
|
|||||||
config_get_fn!(project_id, get_project_id);
|
config_get_fn!(project_id, get_project_id);
|
||||||
config_get_fn!(location, get_location);
|
config_get_fn!(location, get_location);
|
||||||
|
|
||||||
pub const PROMPTS: [PromptAction<'static>; 2] = [
|
create_client_config!([
|
||||||
("project_id", "Project ID", None, false),
|
("project_id", "Project ID", None, false),
|
||||||
("location", "Location", None, false),
|
("location", "Location", None, false),
|
||||||
];
|
]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
@@ -99,9 +99,13 @@ fn prepare_chat_completions(
|
|||||||
let access_token = get_access_token(self_.name())?;
|
let access_token = get_access_token(self_.name())?;
|
||||||
|
|
||||||
let base_url = if location == "global" {
|
let base_url = if location == "global" {
|
||||||
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
|
format!(
|
||||||
|
"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
|
format!(
|
||||||
|
"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let model_name = self_.model.real_name();
|
let model_name = self_.model.real_name();
|
||||||
@@ -158,9 +162,13 @@ fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<R
|
|||||||
let access_token = get_access_token(self_.name())?;
|
let access_token = get_access_token(self_.name())?;
|
||||||
|
|
||||||
let base_url = if location == "global" {
|
let base_url = if location == "global" {
|
||||||
format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
|
format!(
|
||||||
|
"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers"
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
|
format!(
|
||||||
|
"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers"
|
||||||
|
)
|
||||||
};
|
};
|
||||||
let url = format!(
|
let url = format!(
|
||||||
"{base_url}/google/models/{}:predict",
|
"{base_url}/google/models/{}:predict",
|
||||||
@@ -220,12 +228,12 @@ pub async fn gemini_chat_completions_streaming(
|
|||||||
part["functionCall"]["args"].as_object(),
|
part["functionCall"]["args"].as_object(),
|
||||||
) {
|
) {
|
||||||
let thought_signature = part["thoughtSignature"]
|
let thought_signature = part["thoughtSignature"]
|
||||||
.as_str()
|
.as_str()
|
||||||
.or_else(|| part["thought_signature"].as_str())
|
.or_else(|| part["thought_signature"].as_str())
|
||||||
.map(|s| s.to_string());
|
.map(|s| s.to_string());
|
||||||
handler.tool_call(
|
handler.tool_call(
|
||||||
ToolCall::new(name.to_string(), json!(args), None)
|
ToolCall::new(name.to_string(), json!(args), None)
|
||||||
.with_thought_signature(thought_signature),
|
.with_thought_signature(thought_signature),
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -288,12 +296,12 @@ fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsO
|
|||||||
part["functionCall"]["args"].as_object(),
|
part["functionCall"]["args"].as_object(),
|
||||||
) {
|
) {
|
||||||
let thought_signature = part["thoughtSignature"]
|
let thought_signature = part["thoughtSignature"]
|
||||||
.as_str()
|
.as_str()
|
||||||
.or_else(|| part["thought_signature"].as_str())
|
.or_else(|| part["thought_signature"].as_str())
|
||||||
.map(|s| s.to_string());
|
.map(|s| s.to_string());
|
||||||
tool_calls.push(
|
tool_calls.push(
|
||||||
ToolCall::new(name.to_string(), json!(args), None)
|
ToolCall::new(name.to_string(), json!(args), None)
|
||||||
.with_thought_signature(thought_signature),
|
.with_thought_signature(thought_signature),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user