Baseline project
This commit is contained in:
+219
@@ -0,0 +1,219 @@
|
||||
use crate::client::{list_models, ModelType};
|
||||
use crate::config::{list_agents, Config};
|
||||
use anyhow::{Context, Result};
|
||||
use clap::ValueHint;
|
||||
use clap::{crate_authors, crate_description, crate_name, crate_version, Parser};
|
||||
use clap_complete::ArgValueCompleter;
|
||||
use clap_complete::CompletionCandidate;
|
||||
use is_terminal::IsTerminal;
|
||||
use std::ffi::OsStr;
|
||||
use std::io::{stdin, Read};
|
||||
|
||||
// Command-line interface, parsed via clap's derive API.
// NOTE: the field-level `///` comments below double as the generated --help
// text, so they are user-facing strings and are left untouched here.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[command(
    name = crate_name!(),
    author = crate_authors!(),
    version = crate_version!(),
    about = crate_description!(),
    help_template = "\
{before-help}{name} {version}
{author-with-newline}
{about-with-newline}
{usage-heading} {usage}

{all-args}{after-help}
"
)]
pub struct Cli {
    /// Select a LLM model
    #[arg(short, long, add = ArgValueCompleter::new(model_completer))]
    pub model: Option<String>,
    /// Use the system prompt
    #[arg(long)]
    pub prompt: Option<String>,
    /// Select a role
    #[arg(short, long, add = ArgValueCompleter::new(role_completer))]
    pub role: Option<String>,
    /// Start or join a session
    // Option<Option<String>>: `-s` with no value is Some(None) (default
    // session), `-s NAME` is Some(Some(NAME)), absent flag is None.
    #[arg(short = 's', long, add = ArgValueCompleter::new(session_completer))]
    pub session: Option<Option<String>>,
    /// Ensure the session is empty
    #[arg(long)]
    pub empty_session: bool,
    /// Ensure the new conversation is saved to the session
    #[arg(long)]
    pub save_session: bool,
    /// Start an agent
    #[arg(short = 'a', long, add = ArgValueCompleter::new(agent_completer))]
    pub agent: Option<String>,
    /// Set agent variables
    // Consumed pairwise: every `--agent-variable NAME VALUE` appends two items.
    #[arg(long, value_names = ["NAME", "VALUE"], num_args = 2)]
    pub agent_variable: Vec<String>,
    /// Start a RAG
    #[arg(long, add = ArgValueCompleter::new(rag_completer))]
    pub rag: Option<String>,
    /// Rebuild the RAG to sync document changes
    #[arg(long)]
    pub rebuild_rag: bool,
    /// Execute a macro
    // The flag is `--macro`; `macro` is a Rust keyword, hence `macro_name`.
    #[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))]
    pub macro_name: Option<String>,
    /// Serve the LLM API and WebAPP
    #[arg(long, value_name = "PORT|IP|IP:PORT")]
    pub serve: Option<Option<String>>,
    /// Execute commands in natural language
    #[arg(short = 'e', long)]
    pub execute: bool,
    /// Output code only
    #[arg(short = 'c', long)]
    pub code: bool,
    /// Include files, directories, or URLs
    #[arg(short = 'f', long, value_name = "FILE|URL", value_hint = ValueHint::AnyPath)]
    pub file: Vec<String>,
    /// Turn off stream mode
    #[arg(short = 'S', long)]
    pub no_stream: bool,
    /// Display the message without sending it
    #[arg(long)]
    pub dry_run: bool,
    /// Display information
    #[arg(long)]
    pub info: bool,
    /// Build all configured Bash tool scripts
    #[arg(long)]
    pub build_tools: bool,
    /// Sync models updates
    #[arg(long)]
    pub sync_models: bool,
    /// List all available chat models
    #[arg(long)]
    pub list_models: bool,
    /// List all roles
    #[arg(long)]
    pub list_roles: bool,
    /// List all sessions
    #[arg(long)]
    pub list_sessions: bool,
    /// List all agents
    #[arg(long)]
    pub list_agents: bool,
    /// List all RAGs
    #[arg(long)]
    pub list_rags: bool,
    /// List all macros
    #[arg(long)]
    pub list_macros: bool,
    /// Input text
    // Private on purpose: callers go through `Cli::text()`, which merges
    // these positional args with piped stdin.
    #[arg(trailing_var_arg = true)]
    text: Vec<String>,
    /// Tail logs
    #[arg(long)]
    pub tail_logs: bool,
    /// Disable colored log output
    #[arg(long, requires = "tail_logs")]
    pub disable_log_colors: bool,
}
|
||||
|
||||
impl Cli {
|
||||
pub fn text(&self) -> Result<Option<String>> {
|
||||
let mut stdin_text = String::new();
|
||||
if !stdin().is_terminal() {
|
||||
let _ = stdin()
|
||||
.read_to_string(&mut stdin_text)
|
||||
.context("Invalid stdin pipe")?;
|
||||
};
|
||||
match self.text.is_empty() {
|
||||
true => {
|
||||
if stdin_text.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(stdin_text))
|
||||
}
|
||||
}
|
||||
false => {
|
||||
if self.macro_name.is_some() {
|
||||
let text = self
|
||||
.text
|
||||
.iter()
|
||||
.map(|v| shell_words::quote(v))
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
if stdin_text.is_empty() {
|
||||
Ok(Some(text))
|
||||
} else {
|
||||
Ok(Some(format!("{text} -- {stdin_text}")))
|
||||
}
|
||||
} else {
|
||||
let text = self.text.join(" ");
|
||||
if stdin_text.is_empty() {
|
||||
Ok(Some(text))
|
||||
} else {
|
||||
Ok(Some(format!("{text}\n{stdin_text}")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn model_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
match Config::init_bare() {
|
||||
Ok(config) => list_models(&config, ModelType::Chat)
|
||||
.into_iter()
|
||||
.filter(|&m| m.id().starts_with(&*cur))
|
||||
.map(|m| CompletionCandidate::new(m.id()))
|
||||
.collect(),
|
||||
Err(_) => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn role_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
Config::list_roles(true)
|
||||
.into_iter()
|
||||
.filter(|r| r.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn agent_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
list_agents()
|
||||
.into_iter()
|
||||
.filter(|a| a.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn rag_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
Config::list_rags()
|
||||
.into_iter()
|
||||
.filter(|r| r.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn macro_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
Config::list_macros()
|
||||
.into_iter()
|
||||
.filter(|m| m.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn session_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
match Config::init_bare() {
|
||||
Ok(config) => config
|
||||
.list_sessions()
|
||||
.into_iter()
|
||||
.filter(|s| s.starts_with(&*cur))
|
||||
.map(CompletionCandidate::new)
|
||||
.collect(),
|
||||
Err(_) => vec![],
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Utc;
|
||||
use indexmap::IndexMap;
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
// Process-wide, in-memory cache of access tokens, keyed by client name.
// Each entry is (token, expires_at) where expires_at is a unix timestamp in
// seconds (compared against `Utc::now().timestamp()` in
// `is_valid_access_token`).
static ACCESS_TOKENS: LazyLock<RwLock<IndexMap<String, (String, i64)>>> =
    LazyLock::new(|| RwLock::new(IndexMap::new()));
|
||||
|
||||
pub fn get_access_token(client_name: &str) -> Result<String> {
|
||||
ACCESS_TOKENS
|
||||
.read()
|
||||
.get(client_name)
|
||||
.map(|(token, _)| token.clone())
|
||||
.ok_or_else(|| anyhow!("Invalid access token"))
|
||||
}
|
||||
|
||||
pub fn is_valid_access_token(client_name: &str) -> bool {
|
||||
let access_tokens = ACCESS_TOKENS.read();
|
||||
let (token, expires_at) = match access_tokens.get(client_name) {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
!token.is_empty() && Utc::now().timestamp() < *expires_at
|
||||
}
|
||||
|
||||
pub fn set_access_token(client_name: &str, token: String, expires_at: i64) {
|
||||
let mut access_tokens = ACCESS_TOKENS.write();
|
||||
let entry = access_tokens.entry(client_name.to_string()).or_default();
|
||||
entry.0 = token;
|
||||
entry.1 = expires_at;
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
use super::openai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Configuration for the Azure OpenAI client, deserialized from the user's
/// config file.
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
    // Optional display name for this client instance.
    pub name: Option<String>,
    // Azure resource endpoint, e.g. https://{RESOURCE}.openai.azure.com.
    pub api_base: Option<String>,
    pub api_key: Option<String>,
    // Models declared for this client; defaults to empty when omitted.
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl AzureOpenAIClient {
    // Generate `get_api_base` / `get_api_key` accessors from the config
    // fields (see `config_get_fn!` for the exact lookup behavior).
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);

    /// Interactive setup prompts: (config key, display label, optional hint).
    pub const PROMPTS: [PromptAction<'static>; 2] = [
        (
            "api_base",
            "API Base",
            Some("e.g. https://{RESOURCE}.openai.azure.com"),
        ),
        ("api_key", "API Key", None),
    ];
}
|
||||
|
||||
// Wire AzureOpenAIClient into the common Client trait: chat completions use
// the Azure-specific request preparation below with the generic OpenAI
// response handlers; embeddings likewise; rerank is unsupported (noop).
impl_client_trait!(
    AzureOpenAIClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &AzureOpenAIClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_base = self_.get_api_base()?;
|
||||
let api_key = self_.get_api_key()?;
|
||||
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version=2024-12-01-preview",
|
||||
&api_base,
|
||||
self_.model.real_name()
|
||||
);
|
||||
|
||||
let body = openai_build_chat_completions_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(self_: &AzureOpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_base = self_.get_api_base()?;
|
||||
let api_key = self_.get_api_key()?;
|
||||
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/embeddings?api-version=2024-10-21",
|
||||
&api_base,
|
||||
self_.model.real_name()
|
||||
);
|
||||
|
||||
let body = openai_build_embeddings_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
@@ -0,0 +1,643 @@
|
||||
use super::*;
|
||||
|
||||
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256, strip_think_tag};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
||||
use aws_smithy_eventstream::smithy::parse_response_headers;
|
||||
use bytes::BytesMut;
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures_util::StreamExt;
|
||||
use indexmap::IndexMap;
|
||||
use reqwest::{Client as ReqwestClient, Method, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BedrockConfig {
|
||||
pub name: Option<String>,
|
||||
pub access_key_id: Option<String>,
|
||||
pub secret_access_key: Option<String>,
|
||||
pub region: Option<String>,
|
||||
pub session_token: Option<String>,
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelData>,
|
||||
pub patch: Option<RequestPatch>,
|
||||
pub extra: Option<ExtraConfig>,
|
||||
}
|
||||
|
||||
impl BedrockClient {
    // Generate accessors for the credential/region config fields
    // (see `config_get_fn!` for the exact lookup behavior).
    config_get_fn!(access_key_id, get_access_key_id);
    config_get_fn!(secret_access_key, get_secret_access_key);
    config_get_fn!(region, get_region);
    config_get_fn!(session_token, get_session_token);

    /// Interactive setup prompts: (config key, display label, optional hint).
    /// The session token is intentionally not prompted for.
    pub const PROMPTS: [PromptAction<'static>; 3] = [
        ("access_key_id", "AWS Access Key ID", None),
        ("secret_access_key", "AWS Secret Access Key", None),
        ("region", "AWS Region", None),
    ];

    /// Build a SigV4-signed request for the Bedrock Converse API,
    /// selecting the streaming or non-streaming endpoint from `data.stream`.
    fn chat_completions_builder(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<RequestBuilder> {
        let access_key_id = self.get_access_key_id()?;
        let secret_access_key = self.get_secret_access_key()?;
        let region = self.get_region()?;
        // Session token is optional (only needed for temporary credentials).
        let session_token = self.get_session_token().ok();
        let host = format!("bedrock-runtime.{region}.amazonaws.com");

        let model_name = &self.model.real_name();

        let uri = if data.stream {
            format!("/model/{model_name}/converse-stream")
        } else {
            format!("/model/{model_name}/converse")
        };

        let body = build_chat_completions_body(data, &self.model)?;

        // The URL is assembled inside aws_fetch; RequestData is only used
        // here so user-configured patches can adjust headers/body first.
        let mut request_data = RequestData::new("", body);
        self.patch_request_data(&mut request_data);
        let RequestData {
            url: _,
            headers,
            body,
        } = request_data;

        let builder = aws_fetch(
            client,
            &AwsCredentials {
                access_key_id,
                secret_access_key,
                region,
                session_token,
            },
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers,
                body: body.to_string(),
            },
        )?;

        Ok(builder)
    }

    /// Build a SigV4-signed request for the model's invoke endpoint, with a
    /// Cohere-style embeddings body (`texts` + `input_type`).
    fn embeddings_builder(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<RequestBuilder> {
        let access_key_id = self.get_access_key_id()?;
        let secret_access_key = self.get_secret_access_key()?;
        let region = self.get_region()?;
        let session_token = self.get_session_token().ok();
        let host = format!("bedrock-runtime.{region}.amazonaws.com");

        let uri = format!("/model/{}/invoke", self.model.real_name());

        // Distinguish query-time vs indexing-time embeddings.
        let input_type = match data.query {
            true => "search_query",
            false => "search_document",
        };

        let body = json!({
            "texts": data.texts,
            "input_type": input_type,
        });

        let mut request_data = RequestData::new("", body);
        self.patch_request_data(&mut request_data);
        let RequestData {
            url: _,
            headers,
            body,
        } = request_data;

        let builder = aws_fetch(
            client,
            &AwsCredentials {
                access_key_id,
                secret_access_key,
                region,
                session_token,
            },
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers,
                body: body.to_string(),
            },
        )?;

        Ok(builder)
    }
}
|
||||
|
||||
#[async_trait::async_trait]
impl Client for BedrockClient {
    // Expands to the shared Client plumbing — see `client_common_fns!`.
    client_common_fns!();

    /// Non-streaming chat completions: sign and send, then parse the full
    /// JSON response.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        let builder = self.chat_completions_builder(client, data)?;
        chat_completions(builder).await
    }

    /// Streaming chat completions: sign and send, then decode the AWS
    /// event-stream frames into handler callbacks.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        let builder = self.chat_completions_builder(client, data)?;
        chat_completions_streaming(builder, handler).await
    }

    /// Embeddings via the model invoke endpoint.
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<EmbeddingsOutput> {
        let builder = self.embeddings_builder(client, data)?;
        embeddings(builder).await
    }
}
|
||||
|
||||
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
debug!("non-stream-data: {data}");
|
||||
extract_chat_completions(&data)
|
||||
}
|
||||
|
||||
/// Decode a Bedrock converse-stream response (AWS event-stream framing) and
/// forward text, reasoning, and tool-call events to `handler`.
///
/// Tool-call arguments arrive as partial JSON fragments across multiple
/// `contentBlockDelta` events; they are accumulated in `function_arguments`
/// and flushed either when the next tool use starts or when the content
/// block stops.
async fn chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if !status.is_success() {
        let data: Value = res.json().await?;
        catch_error(&data, status.as_u16())?;
        // catch_error should have returned Err; this is a safety net.
        bail!("Invalid response data: {data}");
    }

    // Accumulators for the tool call currently being streamed.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    // 0 = not inside reasoning output, 1 = inside (a "<think>" tag is open).
    let mut reasoning_state = 0;

    let mut stream = res.bytes_stream();
    let mut buffer = BytesMut::new();
    let mut decoder = MessageFrameDecoder::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        buffer.extend_from_slice(&chunk);
        // A single network chunk may contain zero or more complete frames;
        // the decoder keeps any trailing partial frame in `buffer`.
        while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? {
            let response_headers = parse_response_headers(&message)?;
            let message_type = response_headers.message_type.as_str();
            let smithy_type = response_headers.smithy_type.as_str();
            match (message_type, smithy_type) {
                ("event", _) => {
                    let data: Value = serde_json::from_slice(message.payload())?;
                    debug!("stream-data: {smithy_type} {data}");
                    match smithy_type {
                        "contentBlockStart" => {
                            // A new tool use starting: flush any previous
                            // tool call before tracking the new one.
                            if let Some(tool_use) = data["start"]["toolUse"].as_object() {
                                if let (Some(id), Some(name)) = (
                                    json_str_from_map(tool_use, "toolUseId"),
                                    json_str_from_map(tool_use, "name"),
                                ) {
                                    if !function_name.is_empty() {
                                        // No-argument tools stream no input;
                                        // treat that as an empty JSON object.
                                        if function_arguments.is_empty() {
                                            function_arguments = String::from("{}");
                                        }
                                        let arguments: Value =
                                            function_arguments.parse().with_context(|| {
                                                format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                            })?;
                                        handler.tool_call(ToolCall::new(
                                            function_name.clone(),
                                            arguments,
                                            Some(function_id.clone()),
                                        ))?;
                                    }
                                    function_arguments.clear();
                                    function_name = name.into();
                                    function_id = id.into();
                                }
                            }
                        }
                        "contentBlockDelta" => {
                            if let Some(text) = data["delta"]["text"].as_str() {
                                handler.text(text)?;
                            } else if let Some(text) =
                                data["delta"]["reasoningContent"]["text"].as_str()
                            {
                                // First reasoning delta: open the think tag.
                                if reasoning_state == 0 {
                                    handler.text("<think>\n")?;
                                    reasoning_state = 1;
                                }
                                handler.text(text)?;
                            } else if let Some(input) = data["delta"]["toolUse"]["input"].as_str() {
                                // Partial JSON fragment of tool arguments.
                                function_arguments.push_str(input);
                            }
                        }
                        "contentBlockStop" => {
                            // Close an open reasoning block, if any.
                            if reasoning_state == 1 {
                                handler.text("\n</think>\n\n")?;
                                reasoning_state = 0;
                            }
                            // Flush the completed tool call, if any.
                            if !function_name.is_empty() {
                                if function_arguments.is_empty() {
                                    function_arguments = String::from("{}");
                                }
                                let arguments: Value = function_arguments.parse().with_context(|| {
                                    format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                })?;
                                handler.tool_call(ToolCall::new(
                                    function_name.clone(),
                                    arguments,
                                    Some(function_id.clone()),
                                ))?;
                            }
                        }
                        // Other event types (messageStart, metadata, ...) are ignored.
                        _ => {}
                    }
                }
                ("exception", _) => {
                    // NOTE(review): the payload is base64-decoded here;
                    // presumably exception payloads arrive base64-encoded —
                    // verify against the event-stream spec.
                    let payload = base64_decode(message.payload())?;
                    let data = String::from_utf8_lossy(&payload);

                    bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
                }
                _ => {
                    bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
                }
            }
        }
    }
    Ok(())
}
|
||||
|
||||
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
let res_body: EmbeddingsResBody =
|
||||
serde_json::from_value(data).context("Invalid embeddings data")?;
|
||||
Ok(res_body.embeddings)
|
||||
}
|
||||
|
||||
/// Minimal response shape for the embeddings invoke endpoint: one vector per
/// input text.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    embeddings: Vec<Vec<f32>>,
}
|
||||
|
||||
/// Translate the internal chat request into a Bedrock Converse API body.
///
/// Mapping rules:
/// - the system message is extracted and sent as the top-level `system` field;
/// - assistant text messages other than the last message are stripped of
///   `<think>` blocks;
/// - multimodal parts become Converse content blocks; only data-URL images
///   are supported (network image URLs cause an error);
/// - a tool-calls message expands into an assistant `toolUse` message plus a
///   user `toolResult` message, as the Converse API expects;
/// - sampling parameters go under `inferenceConfig`, functions under
///   `toolConfig`.
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream: _,
    } = data;

    let system_message = extract_system_message(&mut messages);

    // Collected while mapping; non-empty afterwards means unsupported input.
    let mut network_image_urls = vec![];

    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                // Historical assistant turns: drop reasoning (<think>) text.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": [ { "text": strip_think_tag(&text) } ] })]
                }
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "content": [
                        {
                            "text": text,
                        }
                    ],
                })],
                MessageContent::Array(list) => {
                    let content: Vec<_> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => {
                                json!({"text": text})
                            }
                            MessageContentPart::ImageUrl {
                                image_url: ImageUrl { url },
                            } => {
                                // Only inline data URLs are supported:
                                // data:<mime>;base64,<payload>
                                if let Some((mime_type, data)) = url
                                    .strip_prefix("data:")
                                    .and_then(|v| v.split_once(";base64,"))
                                {
                                    json!({
                                        "image": {
                                            // Converse wants "png"/"jpeg"/...,
                                            // not the full mime type.
                                            "format": mime_type.replace("image/", ""),
                                            "source": {
                                                "bytes": data,
                                            }
                                        }
                                    })
                                } else {
                                    // Recorded to produce an error after the map.
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            }
                        })
                        .collect();
                    vec![json!({
                        "role": role,
                        "content": content,
                    })]
                }
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    // One assistant message carrying the toolUse blocks,
                    // followed by one user message carrying the toolResult
                    // blocks — the pairing the Converse API expects.
                    let mut assistant_parts = vec![];
                    let mut user_parts = vec![];
                    if !text.is_empty() {
                        assistant_parts.push(json!({
                            "text": text,
                        }))
                    }
                    for tool_result in tool_results {
                        assistant_parts.push(json!({
                            "toolUse": {
                                "toolUseId": tool_result.call.id,
                                "name": tool_result.call.name,
                                "input": tool_result.call.arguments,
                            }
                        }));
                        user_parts.push(json!({
                            "toolResult": {
                                "toolUseId": tool_result.call.id,
                                "content": [
                                    {
                                        "json": tool_result.output,
                                    }
                                ]
                            }
                        }));
                    }
                    vec![
                        json!({
                            "role": "assistant",
                            "content": assistant_parts,
                        }),
                        json!({
                            "role": "user",
                            "content": user_parts,
                        }),
                    ]
                }
            }
        })
        .collect();

    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }

    let mut body = json!({
        "inferenceConfig": {},
        "messages": messages,
    });
    if let Some(v) = system_message {
        body["system"] = json!([
            {
                "text": v,
            }
        ])
    }

    if let Some(v) = model.max_tokens_param() {
        body["inferenceConfig"]["maxTokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["inferenceConfig"]["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["inferenceConfig"]["topP"] = v.into();
    }
    if let Some(functions) = functions {
        let tools: Vec<_> = functions
            .iter()
            .map(|v| {
                json!({
                    "toolSpec": {
                        "name": v.name,
                        "description": v.description,
                        "inputSchema": {
                            "json": v.parameters,
                        },
                    }
                })
            })
            .collect();
        body["toolConfig"] = json!({
            "tools": tools,
        })
    }
    Ok(body)
}
|
||||
|
||||
/// Parse a non-streaming Converse response into text, optional reasoning
/// (re-wrapped in `<think>` tags), tool calls, and token usage.
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
    let mut text = String::new();
    let mut reasoning = None;
    let mut tool_calls = vec![];
    if let Some(array) = data["output"]["message"]["content"].as_array() {
        for item in array {
            if let Some(v) = item["text"].as_str() {
                // Multiple text blocks are joined with a blank line.
                if !text.is_empty() {
                    text.push_str("\n\n");
                }
                text.push_str(v);
            } else if let Some(reasoning_text) =
                item["reasoningContent"]["reasoningText"].as_object()
            {
                // Only the last reasoning block is kept if several appear.
                if let Some(text) = json_str_from_map(reasoning_text, "text") {
                    reasoning = Some(text.to_string());
                }
            } else if let Some(tool_use) = item["toolUse"].as_object() {
                if let (Some(id), Some(name), Some(input)) = (
                    json_str_from_map(tool_use, "toolUseId"),
                    json_str_from_map(tool_use, "name"),
                    tool_use.get("input"),
                ) {
                    tool_calls.push(ToolCall::new(
                        name.to_string(),
                        input.clone(),
                        Some(id.to_string()),
                    ))
                }
            }
        }
    }

    // Surface reasoning to the caller using the same <think> convention as
    // the streaming path.
    if let Some(reasoning) = reasoning {
        text = format!("<think>\n{reasoning}\n</think>\n\n{text}")
    }

    if text.is_empty() && tool_calls.is_empty() {
        bail!("Invalid response data: {data}");
    }

    let output = ChatCompletionsOutput {
        text,
        tool_calls,
        id: None,
        input_tokens: data["usage"]["inputTokens"].as_u64(),
        output_tokens: data["usage"]["outputTokens"].as_u64(),
    };
    Ok(output)
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AwsCredentials {
|
||||
access_key_id: String,
|
||||
secret_access_key: String,
|
||||
region: String,
|
||||
session_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AwsRequest {
|
||||
method: Method,
|
||||
host: String,
|
||||
service: String,
|
||||
uri: String,
|
||||
querystring: String,
|
||||
headers: IndexMap<String, String>,
|
||||
body: String,
|
||||
}
|
||||
|
||||
/// Sign `request` with AWS Signature Version 4 and return a ready-to-send
/// `RequestBuilder`.
///
/// Follows the standard SigV4 steps: canonical request -> string-to-sign ->
/// derived signing key -> Authorization header.
///
/// NOTE(review): SigV4 requires canonical headers to be lowercase and sorted
/// by name; this implementation signs them in the IndexMap's insertion order
/// ("host", "x-amz-date", optional "x-amz-security-token", plus any patched
/// headers) — verify callers never add headers that break that ordering.
fn aws_fetch(
    client: &ReqwestClient,
    credentials: &AwsCredentials,
    request: AwsRequest,
) -> Result<RequestBuilder> {
    let AwsRequest {
        method,
        host,
        service,
        uri,
        querystring,
        mut headers,
        body,
    } = request;
    let region = &credentials.region;

    let endpoint = format!("https://{host}{uri}");

    // SigV4 timestamps: x-amz-date (YYYYMMDDTHHMMSSZ) and its date prefix.
    let now: DateTime<Utc> = Utc::now();
    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
    let date_stamp = amz_date[0..8].to_string();
    headers.insert("host".into(), host.clone());
    headers.insert("x-amz-date".into(), amz_date.clone());
    if let Some(token) = credentials.session_token.clone() {
        headers.insert("x-amz-security-token".into(), token);
    }

    // "name:value\n" per signed header, concatenated.
    let canonical_headers = headers
        .iter()
        .map(|(key, value)| format!("{key}:{value}\n"))
        .collect::<Vec<_>>()
        .join("");

    // Semicolon-separated list of the signed header names.
    let signed_headers = headers
        .iter()
        .map(|(key, _)| key.as_str())
        .collect::<Vec<_>>()
        .join(";");

    let payload_hash = sha256(&body);

    let canonical_request = format!(
        "{}\n{}\n{}\n{}\n{}\n{}",
        method,
        encode_uri(&uri),
        querystring,
        canonical_headers,
        signed_headers,
        payload_hash
    );

    let algorithm = "AWS4-HMAC-SHA256";
    let credential_scope = format!("{date_stamp}/{region}/{service}/aws4_request");
    let string_to_sign = format!(
        "{}\n{}\n{}\n{}",
        algorithm,
        amz_date,
        credential_scope,
        sha256(&canonical_request)
    );

    // Derive the per-day/region/service signing key and sign.
    let signing_key = gen_signing_key(
        &credentials.secret_access_key,
        &date_stamp,
        region,
        &service,
    );
    let signature = hmac_sha256(&signing_key, &string_to_sign);
    let signature = hex_encode(&signature);

    let authorization_header = format!(
        "{} Credential={}/{}, SignedHeaders={}, Signature={}",
        algorithm, credentials.access_key_id, credential_scope, signed_headers, signature
    );

    headers.insert("authorization".into(), authorization_header);

    debug!("Request {endpoint} {body}");

    let mut request_builder = client.request(method, endpoint).body(body);

    for (key, value) in &headers {
        request_builder = request_builder.header(key, value);
    }

    Ok(request_builder)
}
|
||||
|
||||
fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
|
||||
let k_date = hmac_sha256(format!("AWS4{key}").as_bytes(), date_stamp);
|
||||
let k_region = hmac_sha256(&k_date, region);
|
||||
let k_service = hmac_sha256(&k_region, service);
|
||||
hmac_sha256(&k_service, "aws4_request")
|
||||
}
|
||||
@@ -0,0 +1,353 @@
|
||||
use super::*;
|
||||
|
||||
use crate::utils::strip_think_tag;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Default Anthropic API endpoint, used when `api_base` is not configured
/// (see `prepare_chat_completions`).
const API_BASE: &str = "https://api.anthropic.com/v1";

/// Configuration for the Claude (Anthropic) client, deserialized from the
/// user's config file.
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeConfig {
    // Optional display name for this client instance.
    pub name: Option<String>,
    pub api_key: Option<String>,
    // Overrides API_BASE when set; trailing slashes are tolerated.
    pub api_base: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl ClaudeClient {
    // Generate `get_api_key` / `get_api_base` accessors from the config
    // fields (see `config_get_fn!` for the exact lookup behavior).
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts: (config key, display label, optional hint).
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
|
||||
|
||||
// Wire ClaudeClient into the common Client trait: chat completions use the
// Claude-specific request/response handlers; embeddings and rerank are not
// supported (noop).
impl_client_trait!(
    ClaudeClient,
    (
        prepare_chat_completions,
        claude_chat_completions,
        claude_chat_completions_streaming
    ),
    (noop_prepare_embeddings, noop_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &ClaudeClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/messages", api_base.trim_end_matches('/'));
|
||||
let body = claude_build_chat_completions_body(data, &self_.model)?;
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("anthropic-version", "2023-06-01");
|
||||
request_data.header("x-api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
pub async fn claude_chat_completions(
|
||||
builder: RequestBuilder,
|
||||
_model: &Model,
|
||||
) -> Result<ChatCompletionsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
debug!("non-stream-data: {data}");
|
||||
claude_extract_chat_completions(&data)
|
||||
}
|
||||
|
||||
/// Stream a Claude chat-completions response, forwarding text deltas,
/// `<think>`-wrapped reasoning deltas, and assembled tool calls to `handler`.
pub async fn claude_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    // Accumulators for the tool call currently being streamed: Claude sends
    // the name/id in `content_block_start` and the JSON arguments in
    // `partial_json` chunks.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    // 0 = not inside a thinking block, 1 = inside one (<think> tag emitted).
    let mut reasoning_state = 0;
    let handle = |message: SseMessage| -> Result<bool> {
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(typ) = data["type"].as_str() {
            match typ {
                "content_block_start" => {
                    if let (Some("tool_use"), Some(name), Some(id)) = (
                        data["content_block"]["type"].as_str(),
                        data["content_block"]["name"].as_str(),
                        data["content_block"]["id"].as_str(),
                    ) {
                        // A new tool_use block starts; flush any previously
                        // accumulated tool call first.
                        if !function_name.is_empty() {
                            let arguments: Value =
                                function_arguments.parse().with_context(|| {
                                    format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                })?;
                            handler.tool_call(ToolCall::new(
                                function_name.clone(),
                                arguments,
                                Some(function_id.clone()),
                            ))?;
                        }
                        function_name = name.into();
                        function_arguments.clear();
                        function_id = id.into();
                    }
                }
                "content_block_delta" => {
                    if let Some(text) = data["delta"]["text"].as_str() {
                        handler.text(text)?;
                    } else if let Some(text) = data["delta"]["thinking"].as_str() {
                        // Open the <think> tag lazily on the first thinking delta.
                        if reasoning_state == 0 {
                            handler.text("<think>\n")?;
                            reasoning_state = 1;
                        }
                        handler.text(text)?;
                    } else if let (true, Some(partial_json)) = (
                        !function_name.is_empty(),
                        data["delta"]["partial_json"].as_str(),
                    ) {
                        function_arguments.push_str(partial_json);
                    }
                }
                "content_block_stop" => {
                    // Close an open thinking block.
                    if reasoning_state == 1 {
                        handler.text("\n</think>\n\n")?;
                        reasoning_state = 0;
                    }
                    // Emit the completed tool call, defaulting to `{}` when no
                    // argument chunks were streamed.
                    // NOTE(review): `function_name` is not cleared after this
                    // emit; a subsequent `content_block_start` for another
                    // tool_use would flush it again — confirm the event stream
                    // cannot produce a duplicate tool call here.
                    if !function_name.is_empty() {
                        let arguments: Value = if function_arguments.is_empty() {
                            json!({})
                        } else {
                            function_arguments.parse().with_context(|| {
                                format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                            })?
                        };
                        handler.tool_call(ToolCall::new(
                            function_name.clone(),
                            arguments,
                            Some(function_id.clone()),
                        ))?;
                    }
                }
                _ => {}
            }
        }
        Ok(false)
    };

    sse_stream(builder, handle).await
}
|
||||
|
||||
/// Build the JSON body for Claude's `/messages` endpoint from generic
/// chat-completions data.
///
/// Handles: moving the system message to the top-level `system` field,
/// stripping `<think>` tags from non-final assistant turns, converting
/// data-URL images into Claude's base64 `image` blocks, and expanding tool
/// results into paired assistant(`tool_use`) / user(`tool_result`) turns.
///
/// Errors if any message references a network image URL (unsupported here).
pub fn claude_build_chat_completions_body(
    data: ChatCompletionsData,
    model: &Model,
) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream,
    } = data;

    // Claude takes the system prompt as a top-level field, not as a message.
    let system_message = extract_system_message(&mut messages);

    // Collected so we can fail with one clear error after the transform.
    let mut network_image_urls = vec![];

    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                // Strip reasoning tags from earlier assistant turns; only the
                // last message keeps its text verbatim.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": strip_think_tag(&text) })]
                }
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "content": text,
                })],
                MessageContent::Array(list) => {
                    let content: Vec<_> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => {
                                json!({"type": "text", "text": text})
                            }
                            MessageContentPart::ImageUrl {
                                image_url: ImageUrl { url },
                            } => {
                                // Only inline `data:<mime>;base64,<data>` URLs
                                // are supported; remember network URLs so we
                                // can bail below.
                                if let Some((mime_type, data)) = url
                                    .strip_prefix("data:")
                                    .and_then(|v| v.split_once(";base64,"))
                                {
                                    json!({
                                        "type": "image",
                                        "source": {
                                            "type": "base64",
                                            "media_type": mime_type,
                                            "data": data,
                                        }
                                    })
                                } else {
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            }
                        })
                        .collect();
                    vec![json!({
                        "role": role,
                        "content": content,
                    })]
                }
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    // Claude expects the tool invocation on an assistant turn
                    // and the tool output on the following user turn.
                    let mut assistant_parts = vec![];
                    let mut user_parts = vec![];
                    if !text.is_empty() {
                        assistant_parts.push(json!({
                            "type": "text",
                            "text": text,
                        }))
                    }
                    for tool_result in tool_results {
                        assistant_parts.push(json!({
                            "type": "tool_use",
                            "id": tool_result.call.id,
                            "name": tool_result.call.name,
                            "input": tool_result.call.arguments,
                        }));
                        user_parts.push(json!({
                            "type": "tool_result",
                            "tool_use_id": tool_result.call.id,
                            "content": tool_result.output.to_string(),
                        }));
                    }
                    vec![
                        json!({
                            "role": "assistant",
                            "content": assistant_parts,
                        }),
                        json!({
                            "role": "user",
                            "content": user_parts,
                        }),
                    ]
                }
            }
        })
        .collect();

    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }

    let mut body = json!({
        "model": model.real_name(),
        "messages": messages,
    });
    if let Some(v) = system_message {
        body["system"] = v.into();
    }
    if let Some(v) = model.max_tokens_param() {
        body["max_tokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["top_p"] = v.into();
    }
    if stream {
        body["stream"] = true.into();
    }
    if let Some(functions) = functions {
        // Claude's tool schema field is `input_schema` (vs OpenAI's `parameters`).
        body["tools"] = functions
            .iter()
            .map(|v| {
                json!({
                    "name": v.name,
                    "description": v.description,
                    "input_schema": v.parameters,
                })
            })
            .collect();
    }
    Ok(body)
}
|
||||
|
||||
pub fn claude_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
let mut text = String::new();
|
||||
let mut reasoning = None;
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(list) = data["content"].as_array() {
|
||||
for item in list {
|
||||
match item["type"].as_str() {
|
||||
Some("thinking") => {
|
||||
if let Some(v) = item["thinking"].as_str() {
|
||||
reasoning = Some(v.to_string());
|
||||
}
|
||||
}
|
||||
Some("text") => {
|
||||
if let Some(v) = item["text"].as_str() {
|
||||
if !text.is_empty() {
|
||||
text.push_str("\n\n");
|
||||
}
|
||||
text.push_str(v);
|
||||
}
|
||||
}
|
||||
Some("tool_use") => {
|
||||
if let (Some(name), Some(input), Some(id)) = (
|
||||
item["name"].as_str(),
|
||||
item.get("input"),
|
||||
item["id"].as_str(),
|
||||
) {
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
input.clone(),
|
||||
Some(id.to_string()),
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(reasoning) = reasoning {
|
||||
text = format!("<think>\n{reasoning}\n</think>\n\n{text}")
|
||||
}
|
||||
|
||||
if text.is_empty() && tool_calls.is_empty() {
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
|
||||
let output = ChatCompletionsOutput {
|
||||
text: text.to_string(),
|
||||
tool_calls,
|
||||
id: data["id"].as_str().map(|v| v.to_string()),
|
||||
input_tokens: data["usage"]["input_tokens"].as_u64(),
|
||||
output_tokens: data["usage"]["output_tokens"].as_u64(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
use super::openai::*;
|
||||
use super::openai_compatible::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Default Cohere v2 API endpoint.
const API_BASE: &str = "https://api.cohere.ai/v2";

/// User-facing configuration for the Cohere client.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct CohereConfig {
    // Optional display name for this client entry.
    pub name: Option<String>,
    pub api_key: Option<String>,
    // Overrides the default `API_BASE` when set.
    pub api_base: Option<String>,
    // Models exposed by this client entry.
    #[serde(default)]
    pub models: Vec<ModelData>,
    // Optional request patch applied to outgoing requests.
    pub patch: Option<RequestPatch>,
    // Proxy / timeout options shared by all clients.
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl CohereClient {
    // Generate the `get_api_key` / `get_api_base` accessors from the config
    // fields (macro defined elsewhere in the crate).
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts; only the API key is requested.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
|
||||
|
||||
// Wire CohereClient into the shared client plumbing; all three APIs
// (chat, embeddings, rerank) are natively supported.
impl_client_trait!(
    CohereClient,
    (
        prepare_chat_completions,
        chat_completions,
        chat_completions_streaming
    ),
    (prepare_embeddings, embeddings),
    (prepare_rerank, generic_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &CohereClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/chat", api_base.trim_end_matches('/'));
|
||||
let mut body = openai_build_chat_completions_body(data, &self_.model);
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
if let Some(top_p) = obj.remove("top_p") {
|
||||
obj.insert("p".to_string(), top_p);
|
||||
}
|
||||
}
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(self_: &CohereClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/embed", api_base.trim_end_matches('/'));
|
||||
|
||||
let input_type = match data.query {
|
||||
true => "search_query",
|
||||
false => "search_document",
|
||||
};
|
||||
|
||||
let body = json!({
|
||||
"model": self_.model.real_name(),
|
||||
"texts": data.texts,
|
||||
"input_type": input_type,
|
||||
"embedding_types": ["float"],
|
||||
});
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_rerank(self_: &CohereClient, data: &RerankData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{}/rerank", api_base.trim_end_matches('/'));
|
||||
let body = generic_build_rerank_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
async fn chat_completions(
|
||||
builder: RequestBuilder,
|
||||
_model: &Model,
|
||||
) -> Result<ChatCompletionsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
|
||||
debug!("non-stream-data: {data}");
|
||||
extract_chat_completions(&data)
|
||||
}
|
||||
|
||||
/// Stream a Cohere `/chat` response, forwarding text and tool-plan deltas and
/// assembling streamed tool calls before handing them to `handler`.
async fn chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    // Accumulators for the tool call currently being streamed: Cohere sends
    // name/id in `tool-call-start`, argument chunks in `tool-call-delta`, and
    // the terminator in `tool-call-end`.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    let handle = |message: SseMessage| -> Result<bool> {
        // Cohere terminates the stream with a literal "[DONE]" sentinel.
        if message.data == "[DONE]" {
            return Ok(true);
        }
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(typ) = data["type"].as_str() {
            match typ {
                "content-delta" => {
                    if let Some(text) = data["delta"]["message"]["content"]["text"].as_str() {
                        handler.text(text)?;
                    }
                }
                // The model's tool plan is surfaced as regular text.
                "tool-plan-delta" => {
                    if let Some(text) = data["delta"]["message"]["tool_plan"].as_str() {
                        handler.text(text)?;
                    }
                }
                "tool-call-start" => {
                    if let (Some(function), Some(id)) = (
                        data["delta"]["message"]["tool_calls"]["function"].as_object(),
                        data["delta"]["message"]["tool_calls"]["id"].as_str(),
                    ) {
                        if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
                            function_name = name.to_string();
                        }
                        function_id = id.to_string();
                    }
                }
                "tool-call-delta" => {
                    if let Some(text) =
                        data["delta"]["message"]["tool_calls"]["function"]["arguments"].as_str()
                    {
                        function_arguments.push_str(text);
                    }
                }
                "tool-call-end" => {
                    // Emit the completed call, then reset the accumulators so
                    // the next tool call starts clean.
                    if !function_name.is_empty() {
                        let arguments: Value = function_arguments.parse().with_context(|| {
                            format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                        })?;
                        handler.tool_call(ToolCall::new(
                            function_name.clone(),
                            arguments,
                            Some(function_id.clone()),
                        ))?;
                    }
                    function_name.clear();
                    function_arguments.clear();
                    function_id.clear();
                }
                _ => {}
            }
        }
        Ok(false)
    };

    sse_stream(builder, handle).await
}
|
||||
|
||||
/// Execute a Cohere `/embed` request and extract the float embeddings.
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let res_body: EmbeddingsResBody =
        serde_json::from_value(data).context("Invalid embeddings data")?;
    Ok(res_body.embeddings.float)
}

/// Response shape for `/embed`; only the float variant is requested.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    embeddings: EmbeddingsResBodyEmbeddings,
}

#[derive(Deserialize)]
struct EmbeddingsResBodyEmbeddings {
    float: Vec<Vec<f32>>,
}
|
||||
|
||||
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
let mut text = data["message"]["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(calls) = data["message"]["tool_calls"].as_array() {
|
||||
if text.is_empty() {
|
||||
if let Some(tool_plain) = data["message"]["tool_plan"].as_str() {
|
||||
text = tool_plain.to_string();
|
||||
}
|
||||
}
|
||||
for call in calls {
|
||||
if let (Some(name), Some(arguments), Some(id)) = (
|
||||
call["function"]["name"].as_str(),
|
||||
call["function"]["arguments"].as_str(),
|
||||
call["id"].as_str(),
|
||||
) {
|
||||
let arguments: Value = arguments.parse().with_context(|| {
|
||||
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
|
||||
})?;
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
arguments,
|
||||
Some(id.to_string()),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if text.is_empty() && tool_calls.is_empty() {
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
let output = ChatCompletionsOutput {
|
||||
text,
|
||||
tool_calls,
|
||||
id: data["id"].as_str().map(|v| v.to_string()),
|
||||
input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(),
|
||||
output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
@@ -0,0 +1,678 @@
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
config::{Config, GlobalConfig, Input},
|
||||
function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult},
|
||||
render::render_stream,
|
||||
utils::*,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use fancy_regex::Regex;
|
||||
use indexmap::IndexMap;
|
||||
use inquire::{
|
||||
list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text,
|
||||
};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::unbounded_channel;
|
||||
|
||||
// Bundled models metadata shipped with the binary.
const MODELS_YAML: &str = include_str!("../../models.yaml");

/// All known provider/model definitions; a local override (when it parses
/// successfully) takes precedence over the bundled `models.yaml`.
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
    Config::local_models_override()
        .ok()
        .unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
});

/// Heuristic for recognizing embedding models by name.
static EMBEDDING_MODEL_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"((^|/)(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap()
});

/// Matches unescaped '/' characters; used to escape patch-map keys before
/// they are embedded in a regex pattern.
static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/").unwrap());
|
||||
|
||||
/// Common interface implemented by every LLM provider client.
///
/// Implementors supply the config accessors and the provider-specific
/// `*_inner` methods; the provided methods handle shared plumbing (HTTP
/// client setup, dry-run short-circuiting, abort handling, request patching).
#[async_trait::async_trait]
pub trait Client: Sync + Send {
    fn global_config(&self) -> &GlobalConfig;

    fn extra_config(&self) -> Option<&ExtraConfig>;

    fn patch_config(&self) -> Option<&RequestPatch>;

    fn name(&self) -> &str;

    fn model(&self) -> &Model;

    fn model_mut(&mut self) -> &mut Model;

    /// Build a reqwest client honoring the configured proxy, user agent and
    /// connect timeout (default: 10 seconds).
    fn build_client(&self) -> Result<ReqwestClient> {
        let mut builder = ReqwestClient::builder();
        let extra = self.extra_config();
        let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
        if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
            builder = set_proxy(builder, proxy)?;
        }
        if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
            builder = builder.user_agent(user_agent);
        }
        let client = builder
            .connect_timeout(Duration::from_secs(timeout))
            .build()
            .with_context(|| "Failed to build client")?;
        Ok(client)
    }

    /// Non-streaming chat completion. In dry-run mode the input is echoed
    /// back without any network call.
    async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
        if self.global_config().read().dry_run {
            let content = input.echo_messages();
            return Ok(ChatCompletionsOutput::new(&content));
        }
        let client = self.build_client()?;
        let data = input.prepare_completion_data(self.model(), false)?;
        self.chat_completions_inner(&client, data)
            .await
            .with_context(|| "Failed to call chat-completions api")
    }

    /// Streaming chat completion; races the request against the handler's
    /// abort signal and always calls `handler.done()` when either side
    /// finishes.
    async fn chat_completions_streaming(
        &self,
        input: &Input,
        handler: &mut SseHandler,
    ) -> Result<()> {
        let abort_signal = handler.abort();
        let input = input.clone();
        tokio::select! {
            ret = async {
                if self.global_config().read().dry_run {
                    let content = input.echo_messages();
                    handler.text(&content)?;
                    return Ok(());
                }
                let client = self.build_client()?;
                let data = input.prepare_completion_data(self.model(), true)?;
                self.chat_completions_streaming_inner(&client, handler, data).await
            } => {
                handler.done();
                ret.with_context(|| "Failed to call chat-completions api")
            }
            _ = wait_abort_signal(&abort_signal) => {
                handler.done();
                Ok(())
            },
        }
    }

    /// Compute embeddings for `data.texts`.
    async fn embeddings(&self, data: &EmbeddingsData) -> Result<Vec<Vec<f32>>> {
        let client = self.build_client()?;
        self.embeddings_inner(&client, data)
            .await
            .context("Failed to call embeddings api")
    }

    /// Rerank `data.documents` against `data.query`.
    async fn rerank(&self, data: &RerankData) -> Result<RerankOutput> {
        let client = self.build_client()?;
        self.rerank_inner(&client, data)
            .await
            .context("Failed to call rerank api")
    }

    /// Provider-specific non-streaming chat-completions implementation.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput>;

    /// Provider-specific streaming chat-completions implementation.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()>;

    /// Default implementation: embeddings unsupported.
    async fn embeddings_inner(
        &self,
        _client: &ReqwestClient,
        _data: &EmbeddingsData,
    ) -> Result<EmbeddingsOutput> {
        bail!("The client doesn't support embeddings api")
    }

    /// Default implementation: rerank unsupported.
    async fn rerank_inner(
        &self,
        _client: &ReqwestClient,
        _data: &RerankData,
    ) -> Result<RerankOutput> {
        bail!("The client doesn't support rerank api")
    }

    /// Apply configured patches to `request_data`, then turn it into a
    /// reqwest builder.
    fn request_builder(
        &self,
        client: &reqwest::Client,
        mut request_data: RequestData,
    ) -> RequestBuilder {
        self.patch_request_data(&mut request_data);
        request_data.into_builder(client)
    }

    /// Apply request patches: first any model-level patch, then a patch map
    /// read from an environment variable (name derived via `get_env_name`)
    /// falling back to the client's `patch` config. The first map entry whose
    /// (slash-escaped) key, treated as a full-match regex, matches the model
    /// name is applied.
    fn patch_request_data(&self, request_data: &mut RequestData) {
        let model_type = self.model().model_type();
        if let Some(patch) = self.model().patch() {
            request_data.apply_patch(patch.clone());
        }

        let patch_map = std::env::var(get_env_name(&format!(
            "patch_{}_{}",
            self.model().client_name(),
            model_type.api_name(),
        )))
        .ok()
        .and_then(|v| serde_json::from_str(&v).ok())
        .or_else(|| {
            self.patch_config()
                .and_then(|v| model_type.extract_patch(v))
                .cloned()
        });
        let patch_map = match patch_map {
            Some(v) => v,
            _ => return,
        };
        for (key, patch) in patch_map {
            // Escape bare '/' so keys containing slashes stay valid regex.
            let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
            if let Ok(regex) = Regex::new(&format!("^({key})$")) {
                if let Ok(true) = regex.is_match(self.model().name()) {
                    request_data.apply_patch(patch);
                    return;
                }
            }
        }
    }
}
|
||||
|
||||
impl Default for ClientConfig {
    /// Fall back to an OpenAI client configuration when none is specified.
    fn default() -> Self {
        Self::OpenAIConfig(OpenAIConfig::default())
    }
}
|
||||
|
||||
/// Extra connection options shared by all clients.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
    // Proxy address passed to `set_proxy`.
    pub proxy: Option<String>,
    // Connect timeout in seconds; `build_client` defaults to 10 when unset.
    pub connect_timeout: Option<u64>,
}

/// Per-API request patches, one optional map per API kind.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RequestPatch {
    pub chat_completions: Option<ApiPatch>,
    pub embeddings: Option<ApiPatch>,
    pub rerank: Option<ApiPatch>,
}

/// Maps a model-name pattern to a patch value (url / headers / body).
pub type ApiPatch = IndexMap<String, Value>;

/// An HTTP request about to be sent: target url, headers, and JSON body.
pub struct RequestData {
    pub url: String,
    pub headers: IndexMap<String, String>,
    pub body: Value,
}
|
||||
|
||||
impl RequestData {
    /// Create a request for `url` with the given JSON body and no headers.
    pub fn new<T>(url: T, body: Value) -> Self
    where
        T: std::fmt::Display,
    {
        Self {
            url: url.to_string(),
            headers: Default::default(),
            body,
        }
    }

    /// Set the `authorization: Bearer …` header.
    pub fn bearer_auth<T>(&mut self, auth: T)
    where
        T: std::fmt::Display,
    {
        self.headers
            .insert("authorization".into(), format!("Bearer {auth}"));
    }

    /// Set (or overwrite) an arbitrary header.
    pub fn header<K, V>(&mut self, key: K, value: V)
    where
        K: std::fmt::Display,
        V: std::fmt::Display,
    {
        self.headers.insert(key.to_string(), value.to_string());
    }

    /// Convert into a reqwest POST builder carrying all headers and the JSON body.
    pub fn into_builder(self, client: &ReqwestClient) -> RequestBuilder {
        let RequestData { url, headers, body } = self;
        debug!("Request {url} {body}");

        let mut builder = client.post(url);
        for (key, value) in headers {
            builder = builder.header(key, value);
        }
        builder = builder.json(&body);
        builder
    }

    /// Apply a patch object: an optional "url" replacement, a "body" merged
    /// into the current body via `json_patch::merge`, and "headers" where a
    /// string value sets the header and a null value removes it.
    pub fn apply_patch(&mut self, patch: Value) {
        if let Some(patch_url) = patch["url"].as_str() {
            self.url = patch_url.into();
        }
        if let Some(patch_body) = patch.get("body") {
            json_patch::merge(&mut self.body, patch_body)
        }
        if let Some(patch_headers) = patch["headers"].as_object() {
            for (key, value) in patch_headers {
                if let Some(value) = value.as_str() {
                    self.header(key, value)
                } else if value.is_null() {
                    self.headers.swap_remove(key);
                }
            }
        }
    }
}
|
||||
|
||||
/// Inputs for a chat-completions call.
#[derive(Debug)]
pub struct ChatCompletionsData {
    pub messages: Vec<Message>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    // Tool/function declarations offered to the model, if any.
    pub functions: Option<Vec<FunctionDeclaration>>,
    // Whether the request is for streaming output.
    pub stream: bool,
}

/// Result of a chat-completions call.
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionsOutput {
    pub text: String,
    pub tool_calls: Vec<ToolCall>,
    // Provider-assigned response id, when present.
    pub id: Option<String>,
    pub input_tokens: Option<u64>,
    pub output_tokens: Option<u64>,
}

impl ChatCompletionsOutput {
    /// Build a text-only output (used e.g. for dry runs).
    pub fn new(text: &str) -> Self {
        Self {
            text: text.to_string(),
            ..Default::default()
        }
    }
}
|
||||
|
||||
/// Inputs for an embeddings call.
#[derive(Debug)]
pub struct EmbeddingsData {
    pub texts: Vec<String>,
    // True when embedding a search query, false for documents.
    pub query: bool,
}

impl EmbeddingsData {
    pub fn new(texts: Vec<String>, query: bool) -> Self {
        Self { texts, query }
    }
}

pub type EmbeddingsOutput = Vec<Vec<f32>>;

/// Inputs for a rerank call.
#[derive(Debug)]
pub struct RerankData {
    pub query: String,
    pub documents: Vec<String>,
    // Number of top results requested.
    pub top_n: usize,
}

impl RerankData {
    pub fn new(query: String, documents: Vec<String>, top_n: usize) -> Self {
        Self {
            query,
            documents,
            top_n,
        }
    }
}

pub type RerankOutput = Vec<RerankResult>;

/// One reranked document: its original index and relevance score.
#[derive(Debug, Deserialize)]
pub struct RerankResult {
    pub index: usize,
    pub relevance_score: f64,
}

/// (config key, human-readable description, optional help message) used by
/// interactive client setup prompts.
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>);
|
||||
|
||||
pub async fn create_config(
|
||||
prompts: &[PromptAction<'static>],
|
||||
client: &str,
|
||||
) -> Result<(String, Value)> {
|
||||
let mut config = json!({
|
||||
"type": client,
|
||||
});
|
||||
for (key, desc, help_message) in prompts {
|
||||
let env_name = format!("{client}_{key}").to_ascii_uppercase();
|
||||
let required = std::env::var(&env_name).is_err();
|
||||
let value = prompt_input_string(desc, required, *help_message)?;
|
||||
if !value.is_empty() {
|
||||
config[key] = value.into();
|
||||
}
|
||||
}
|
||||
let model = set_client_models_config(&mut config, client).await?;
|
||||
let clients = json!(vec![config]);
|
||||
Ok((model, clients))
|
||||
}
|
||||
|
||||
pub async fn create_openai_compatible_client_config(
|
||||
client: &str,
|
||||
) -> Result<Option<(String, Value)>> {
|
||||
let api_base = OPENAI_COMPATIBLE_PROVIDERS
|
||||
.into_iter()
|
||||
.find(|(name, _)| client == *name)
|
||||
.map(|(_, api_base)| api_base)
|
||||
.unwrap_or("http(s)://{API_ADDR}/v1");
|
||||
|
||||
let name = if client == OpenAICompatibleClient::NAME {
|
||||
let value = prompt_input_string("Provider Name", true, None)?;
|
||||
value.replace(' ', "-")
|
||||
} else {
|
||||
client.to_string()
|
||||
};
|
||||
|
||||
let mut config = json!({
|
||||
"type": OpenAICompatibleClient::NAME,
|
||||
"name": &name,
|
||||
});
|
||||
|
||||
let api_base = if api_base.contains('{') {
|
||||
prompt_input_string("API Base", true, Some(&format!("e.g. {api_base}")))?
|
||||
} else {
|
||||
api_base.to_string()
|
||||
};
|
||||
config["api_base"] = api_base.into();
|
||||
|
||||
let api_key = prompt_input_string("API Key", false, None)?;
|
||||
if !api_key.is_empty() {
|
||||
config["api_key"] = api_key.into();
|
||||
}
|
||||
|
||||
let model = set_client_models_config(&mut config, &name).await?;
|
||||
let clients = json!(vec![config]);
|
||||
Ok(Some((model, clients)))
|
||||
}
|
||||
|
||||
pub async fn call_chat_completions(
|
||||
input: &Input,
|
||||
print: bool,
|
||||
extract_code: bool,
|
||||
client: &dyn Client,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<(String, Vec<ToolResult>)> {
|
||||
let ret = abortable_run_with_spinner(
|
||||
client.chat_completions(input.clone()),
|
||||
"Generating",
|
||||
abort_signal,
|
||||
)
|
||||
.await;
|
||||
|
||||
match ret {
|
||||
Ok(ret) => {
|
||||
let ChatCompletionsOutput {
|
||||
mut text,
|
||||
tool_calls,
|
||||
..
|
||||
} = ret;
|
||||
if !text.is_empty() {
|
||||
if extract_code {
|
||||
text = extract_code_block(&strip_think_tag(&text)).to_string();
|
||||
}
|
||||
if print {
|
||||
client.global_config().read().print_markdown(&text)?;
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
text,
|
||||
eval_tool_calls(client.global_config(), tool_calls).await?,
|
||||
))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a streaming chat completion: the sending task and the terminal
/// renderer run concurrently, connected by an unbounded channel.
pub async fn call_chat_completions_streaming(
    input: &Input,
    client: &dyn Client,
    abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
    let (tx, rx) = unbounded_channel();
    let mut handler = SseHandler::new(tx, abort_signal.clone());

    let (send_ret, render_ret) = tokio::join!(
        client.chat_completions_streaming(input, &mut handler),
        render_stream(rx, client.global_config(), abort_signal.clone()),
    );

    if handler.abort().aborted() {
        bail!("Aborted.");
    }

    // Surface renderer failures before inspecting the send result.
    render_ret?;

    let (text, tool_calls) = handler.take();
    match send_ret {
        Ok(_) => {
            // Keep the terminal tidy: ensure the output ends with a newline.
            if !text.is_empty() && !text.ends_with('\n') {
                println!();
            }
            Ok((
                text,
                eval_tool_calls(client.global_config(), tool_calls).await?,
            ))
        }
        Err(err) => {
            // Partial output was already rendered; break the line before the
            // error is reported.
            if !text.is_empty() {
                println!();
            }
            Err(err)
        }
    }
}
|
||||
|
||||
/// Stub for clients without an embeddings endpoint; always errors.
pub fn noop_prepare_embeddings<T>(_client: &T, _data: &EmbeddingsData) -> Result<RequestData> {
    bail!("The client doesn't support embeddings api")
}

/// Stub embeddings executor; always errors.
pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    bail!("The client doesn't support embeddings api")
}

/// Stub for clients without a rerank endpoint; always errors.
pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
    bail!("The client doesn't support rerank api")
}

/// Stub rerank executor; always errors.
pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
    bail!("The client doesn't support rerank api")
}
|
||||
|
||||
pub fn catch_error(data: &Value, status: u16) -> Result<()> {
|
||||
if (200..300).contains(&status) {
|
||||
return Ok(());
|
||||
}
|
||||
debug!("Invalid response, status: {status}, data: {data}");
|
||||
if let Some(error) = data["error"].as_object() {
|
||||
if let (Some(typ), Some(message)) = (
|
||||
json_str_from_map(error, "type"),
|
||||
json_str_from_map(error, "message"),
|
||||
) {
|
||||
bail!("{message} (type: {typ})");
|
||||
} else if let (Some(typ), Some(message)) = (
|
||||
json_str_from_map(error, "code"),
|
||||
json_str_from_map(error, "message"),
|
||||
) {
|
||||
bail!("{message} (code: {typ})");
|
||||
}
|
||||
} else if let Some(error) = data["errors"][0].as_object() {
|
||||
if let (Some(code), Some(message)) = (
|
||||
error.get("code").and_then(|v| v.as_u64()),
|
||||
json_str_from_map(error, "message"),
|
||||
) {
|
||||
bail!("{message} (status: {code})")
|
||||
}
|
||||
} else if let Some(error) = data[0]["error"].as_object() {
|
||||
if let (Some(status), Some(message)) = (
|
||||
json_str_from_map(error, "status"),
|
||||
json_str_from_map(error, "message"),
|
||||
) {
|
||||
bail!("{message} (status: {status})")
|
||||
}
|
||||
} else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
|
||||
{
|
||||
bail!("{detail} (status: {status})");
|
||||
} else if let Some(error) = data["error"].as_str() {
|
||||
bail!("{error}");
|
||||
} else if let Some(message) = data["message"].as_str() {
|
||||
bail!("{message}");
|
||||
}
|
||||
bail!("Invalid response data: {data} (status: {status})");
|
||||
}
|
||||
|
||||
pub fn json_str_from_map<'a>(
|
||||
map: &'a serde_json::Map<String, Value>,
|
||||
field_name: &str,
|
||||
) -> Option<&'a str> {
|
||||
map.get(field_name).and_then(|v| v.as_str())
|
||||
}
|
||||
|
||||
/// Interactively fills in the `models` section for a client config and returns
/// the chosen default model id in `client:model` form.
///
/// Resolution order:
/// 1. If `client` matches a built-in provider in `ALL_PROVIDER_MODELS`, pick a
///    chat model from the bundled list (no config mutation needed).
/// 2. For an openai-compatible client that has an `api_base`, try fetching the
///    model list from the API (api_key comes from config or `<CLIENT>_API_KEY`
///    env var) and let the user multi-select.
/// 3. Otherwise — or if the fetch fails/yields nothing — fall back to manual
///    comma-separated input.
async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
    // Known provider: use the bundled model list, filtered to chat models.
    if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
        let models: Vec<String> = provider
            .models
            .iter()
            .filter(|v| v.model_type == "chat")
            .map(|v| v.name.clone())
            .collect();
        let model_name = select_model(models)?;
        return Ok(format!("{client}:{model_name}"));
    }
    let mut model_names = vec![];
    // Fetch models from the API only for openai-compatible clients with an
    // api_base; api_key is optional (config value, else env var).
    if let (Some(true), Some(api_base), api_key) = (
        client_config["type"]
            .as_str()
            .map(|v| v == OpenAICompatibleClient::NAME),
        client_config["api_base"].as_str(),
        client_config["api_key"]
            .as_str()
            .map(|v| v.to_string())
            .or_else(|| {
                // e.g. client "groq" -> env var "GROQ_API_KEY"
                let env_name = format!("{client}_api_key").to_ascii_uppercase();
                std::env::var(&env_name).ok()
            }),
    ) {
        match abortable_run_with_spinner(
            fetch_models(api_base, api_key.as_deref()),
            "Fetching models",
            create_abort_signal(),
        )
        .await
        {
            Ok(fetched_models) => {
                // Require at least one selection.
                model_names = MultiSelect::new("LLMs to include (required):", fetched_models)
                    .with_validator(|list: &[ListOption<&String>]| {
                        if list.is_empty() {
                            Ok(Validation::Invalid(
                                "At least one item must be selected".into(),
                            ))
                        } else {
                            Ok(Validation::Valid)
                        }
                    })
                    .prompt()?;
            }
            Err(err) => {
                // Non-fatal: fall through to manual entry below.
                eprintln!("✗ Fetch models failed: {err}");
            }
        }
    }
    // Manual fallback: comma-separated names, trimmed, empties dropped.
    if model_names.is_empty() {
        model_names = prompt_input_string(
            "LLMs to add",
            true,
            Some("Separated by commas, e.g. llama3.3,qwen2.5"),
        )?
        .split(',')
        .filter_map(|v| {
            let v = v.trim();
            if v.is_empty() {
                None
            } else {
                Some(v.to_string())
            }
        })
        .collect::<Vec<_>>();
    }
    if model_names.is_empty() {
        bail!("No models");
    }
    // Build the model entries; the model type is inferred from its name.
    let models: Vec<Value> = model_names
        .iter()
        .map(|v| {
            let l = v.to_lowercase();
            if l.contains("rank") {
                json!({
                    "name": v,
                    "type": "reranker",
                })
            } else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) {
                // NOTE(review): `is_match` returning Result suggests fancy_regex;
                // a regex error silently falls through to the chat branch — confirm intended.
                json!({
                    "name": v,
                    "type": "embedding",
                    "default_chunk_size": 1000,
                    "max_batch_size": 100
                })
            } else if v.contains("vision") {
                json!({
                    "name": v,
                    "supports_vision": true
                })
            } else {
                // Plain chat model.
                json!({
                    "name": v,
                })
            }
        })
        .collect();
    client_config["models"] = models.into();
    let model_name = select_model(model_names)?;
    Ok(format!("{client}:{model_name}"))
}
|
||||
|
||||
fn select_model(model_names: Vec<String>) -> Result<String> {
|
||||
if model_names.is_empty() {
|
||||
bail!("No models");
|
||||
}
|
||||
let model = if model_names.len() == 1 {
|
||||
model_names[0].clone()
|
||||
} else {
|
||||
Select::new("Default Model (required):", model_names).prompt()?
|
||||
};
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
fn prompt_input_string(desc: &str, required: bool, help_message: Option<&str>) -> Result<String> {
|
||||
let desc = if required {
|
||||
format!("{desc} (required):")
|
||||
} else {
|
||||
format!("{desc} (optional):")
|
||||
};
|
||||
let mut text = Text::new(&desc);
|
||||
if required {
|
||||
text = text.with_validator(required!("This field is required"))
|
||||
}
|
||||
if let Some(help_message) = help_message {
|
||||
text = text.with_help_message(help_message);
|
||||
}
|
||||
let text = text.prompt()?;
|
||||
Ok(text)
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
use super::vertexai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
// Default Gemini API endpoint, used when no `api_base` is configured.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
|
||||
|
||||
/// User-facing configuration for the Gemini client, deserialized from the
/// `clients` section of the config file.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiConfig {
    /// Optional display/client name override (defaults to "gemini").
    pub name: Option<String>,
    /// API key; may instead come from the environment (see `config_get_fn!`).
    pub api_key: Option<String>,
    /// Endpoint override; falls back to `API_BASE` when unset.
    pub api_base: Option<String>,
    /// Models declared for this client; empty means use built-in defaults.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patch (headers/body overrides).
    pub patch: Option<RequestPatch>,
    /// Extra per-client settings (proxy, timeout, ...).
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl GeminiClient {
    // Generate env-var-aware getters for the credential fields.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts: only the API key is asked for.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
|
||||
|
||||
// Wire GeminiClient into the `Client` trait: chat (sync + streaming) and
// embeddings are supported; rerank is a no-op that errors.
impl_client_trait!(
    GeminiClient,
    (
        prepare_chat_completions,
        gemini_chat_completions,
        gemini_chat_completions_streaming
    ),
    (prepare_embeddings, embeddings),
    (noop_prepare_rerank, noop_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &GeminiClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let func = match data.stream {
|
||||
true => "streamGenerateContent",
|
||||
false => "generateContent",
|
||||
};
|
||||
|
||||
let url = format!(
|
||||
"{}/models/{}:{}",
|
||||
api_base.trim_end_matches('/'),
|
||||
self_.model.real_name(),
|
||||
func
|
||||
);
|
||||
|
||||
let body = gemini_build_chat_completions_body(data, &self_.model)?;
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.header("x-goog-api-key", api_key);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!(
|
||||
"{}/models/{}:batchEmbedContents?key={}",
|
||||
api_base.trim_end_matches('/'),
|
||||
self_.model.real_name(),
|
||||
api_key
|
||||
);
|
||||
|
||||
let model_id = format!("models/{}", self_.model.real_name());
|
||||
|
||||
let requests: Vec<_> = data
|
||||
.texts
|
||||
.iter()
|
||||
.map(|text| {
|
||||
json!({
|
||||
"model": model_id,
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": text
|
||||
}
|
||||
]
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let body = json!({
|
||||
"requests": requests,
|
||||
});
|
||||
|
||||
let request_data = RequestData::new(url, body);
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
let res_body: EmbeddingsResBody =
|
||||
serde_json::from_value(data).context("Invalid embeddings data")?;
|
||||
let output = res_body
|
||||
.embeddings
|
||||
.into_iter()
|
||||
.map(|embedding| embedding.values)
|
||||
.collect();
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Minimal shape of the `batchEmbedContents` response.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    // One entry per input text, in order.
    embeddings: Vec<EmbeddingsResBodyEmbedding>,
}

/// A single embedding vector in the response.
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
    values: Vec<f32>,
}
|
||||
@@ -0,0 +1,245 @@
|
||||
/// Declares all client modules and generates the plumbing that ties them
/// together: the `ClientConfig` enum (tagged by `type` in the config file),
/// one struct per client with `init`/`list_models`/`name`, plus module-level
/// helpers (`init_client`, `list_client_types`, `create_client_config`,
/// `list_client_names`, `list_all_models`, `list_models`).
#[macro_export]
macro_rules! register_client {
    (
        $(($module:ident, $name:literal, $config:ident, $client:ident),)+
    ) => {
        $(
            mod $module;
        )+
        $(
            use self::$module::$config;
        )+

        // Config enum: each variant deserializes from `type: "<name>"`;
        // unrecognized types land in `Unknown` instead of failing the parse.
        #[derive(Debug, Clone, serde::Deserialize)]
        #[serde(tag = "type")]
        pub enum ClientConfig {
            $(
                #[serde(rename = $name)]
                $config($config),
            )+
            #[serde(other)]
            Unknown,
        }

        $(
            #[derive(Debug)]
            pub struct $client {
                global_config: $crate::config::GlobalConfig,
                config: $config,
                model: $crate::client::Model,
            }

            impl $client {
                pub const NAME: &'static str = $name;

                /// Builds this client iff some config entry's effective name
                /// matches the model's client name; returns None otherwise so
                /// `init_client` can try the next client type.
                pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
                    let config = global_config.read().clients.iter().find_map(|client_config| {
                        if let ClientConfig::$config(c) = client_config {
                            if Self::name(c) == model.client_name() {
                                return Some(c.clone())
                            }
                        }
                        None
                    })?;

                    Some(Box::new(Self {
                        global_config: global_config.clone(),
                        config,
                        model: model.clone(),
                    }))
                }

                /// Models for one config entry. An empty `models` list falls
                /// back to the bundled provider catalog; openai-compatible
                /// entries match the catalog by name prefix.
                pub fn list_models(local_config: &$config) -> Vec<Model> {
                    let client_name = Self::name(local_config);
                    if local_config.models.is_empty() {
                        if let Some(v) = $crate::client::ALL_PROVIDER_MODELS.iter().find(|v| {
                            v.provider == $name ||
                                ($name == OpenAICompatibleClient::NAME
                                    && local_config.name.as_ref().map(|name| name.starts_with(&v.provider)).unwrap_or_default())
                        }) {
                            return Model::from_config(client_name, &v.models);
                        }
                        vec![]
                    } else {
                        Model::from_config(client_name, &local_config.models)
                    }
                }

                /// Effective client name: config `name` override or the
                /// built-in type name.
                pub fn name(local_config: &$config) -> &str {
                    local_config.name.as_deref().unwrap_or(Self::NAME)
                }
            }

        )+

        /// Resolves a client for `model` (default: the configured model) by
        /// asking each registered client type in declaration order.
        pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
            let model = model.unwrap_or_else(|| config.read().model.clone());
            None
            $(.or_else(|| $client::init(config, &model)))+
            .ok_or_else(|| {
                anyhow::anyhow!("Invalid model '{}'", model.id())
            })
        }

        /// Built-in client type names plus all openai-compatible provider names.
        pub fn list_client_types() -> Vec<&'static str> {
            let mut client_types: Vec<_> = vec![$($client::NAME,)+];
            client_types.extend($crate::client::OPENAI_COMPATIBLE_PROVIDERS.iter().map(|(name, _)| *name));
            client_types
        }

        /// Runs the interactive config flow for `client`; openai-compatible
        /// names (and unknown ones) take the compatible-provider path.
        pub async fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
            $(
                if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
                    return create_config(&$client::PROMPTS, $client::NAME).await
                }
            )+
            if let Some(ret) = create_openai_compatible_client_config(client).await? {
                return Ok(ret);
            }
            anyhow::bail!("Unknown client '{}'", client)
        }

        // Cached once per process: the config is treated as immutable after startup.
        static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();

        /// Effective names of all configured clients (cached on first call).
        pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
            let names = ALL_CLIENT_NAMES.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .flat_map(|v| match v {
                        $(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+
                        ClientConfig::Unknown => vec![],
                    })
                    .collect()
            });
            names.iter().collect()
        }

        // Cached once per process, same rationale as ALL_CLIENT_NAMES.
        static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();

        /// Every model from every configured client (cached on first call).
        pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
            let models = ALL_MODELS.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .flat_map(|v| match v {
                        $(ClientConfig::$config(c) => $client::list_models(c),)+
                        ClientConfig::Unknown => vec![],
                    })
                    .collect()
            });
            models.iter().collect()
        }

        /// `list_all_models` filtered to one model type (chat/embedding/reranker).
        pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
            list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
        }
    };
}
|
||||
|
||||
/// Expands to the boilerplate accessor methods shared by every `Client`
/// implementation (config/model getters delegating to the generated struct
/// fields). Used inside `impl_client_trait!`.
#[macro_export]
macro_rules! client_common_fns {
    () => {
        fn global_config(&self) -> &$crate::config::GlobalConfig {
            &self.global_config
        }

        fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
            self.config.extra.as_ref()
        }

        fn patch_config(&self) -> Option<&$crate::client::RequestPatch> {
            self.config.patch.as_ref()
        }

        // Effective name (config override or built-in type name).
        fn name(&self) -> &str {
            Self::name(&self.config)
        }

        fn model(&self) -> &Model {
            &self.model
        }

        fn model_mut(&mut self) -> &mut Model {
            &mut self.model
        }
    };
}
|
||||
|
||||
/// Implements the `Client` trait for a client struct from three
/// (prepare, execute[, streaming]) function pairs: every `*_inner` method is
/// "prepare the RequestData, build the reqwest builder, delegate".
#[macro_export]
macro_rules! impl_client_trait {
    (
        $client:ident,
        ($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path),
        ($prepare_embeddings:path, $embeddings:path),
        ($prepare_rerank:path, $rerank:path),
    ) => {
        #[async_trait::async_trait]
        impl $crate::client::Client for $crate::client::$client {
            client_common_fns!();

            async fn chat_completions_inner(
                &self,
                client: &reqwest::Client,
                data: $crate::client::ChatCompletionsData,
            ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions(builder, self.model()).await
            }

            async fn chat_completions_streaming_inner(
                &self,
                client: &reqwest::Client,
                handler: &mut $crate::client::SseHandler,
                data: $crate::client::ChatCompletionsData,
            ) -> Result<()> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions_streaming(builder, handler, self.model()).await
            }

            async fn embeddings_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::EmbeddingsData,
            ) -> Result<$crate::client::EmbeddingsOutput> {
                let request_data = $prepare_embeddings(self, data)?;
                let builder = self.request_builder(client, request_data);
                $embeddings(builder, self.model()).await
            }

            async fn rerank_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::RerankData,
            ) -> Result<$crate::client::RerankOutput> {
                let request_data = $prepare_rerank(self, data)?;
                let builder = self.request_builder(client, request_data);
                $rerank(builder, self.model()).await
            }
        }
    };
}
|
||||
|
||||
/// Generates a getter for an optional config field that prefers the
/// `<CLIENT_NAME>_<FIELD>` environment variable over the config value and
/// errors when neither is set.
#[macro_export]
macro_rules! config_get_fn {
    ($field_name:ident, $fn_name:ident) => {
        fn $fn_name(&self) -> anyhow::Result<String> {
            // Env var is derived from the *effective* client name, so a
            // renamed client reads e.g. MYGEMINI_API_KEY.
            let env_prefix = Self::name(&self.config);
            let env_name =
                format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
            std::env::var(&env_name)
                .ok()
                .or_else(|| self.config.$field_name.clone())
                .ok_or_else(|| anyhow::anyhow!("Miss '{}'", stringify!($field_name)))
        }
    };
}
|
||||
|
||||
/// Bails with a uniform "Unsupported model" error message.
#[macro_export]
macro_rules! unsupported_model {
    ($name:expr) => {
        anyhow::bail!("Unsupported model '{}'", $name)
    };
}
|
||||
@@ -0,0 +1,235 @@
|
||||
use super::Model;
|
||||
|
||||
use crate::{function::ToolResult, multiline_text, utils::dimmed_text};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A single chat message: who said it and what it contains.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
    pub role: MessageRole,
    pub content: MessageContent,
}
|
||||
|
||||
impl Default for Message {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Text(String::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Message {
    pub fn new(role: MessageRole, content: MessageContent) -> Self {
        Self { role, content }
    }

    /// Folds a system prompt into this message's content, keeping the system
    /// text *first* in every combination. Used when a model forbids dedicated
    /// system messages. ToolCalls content is left untouched.
    pub fn merge_system(&mut self, system: MessageContent) {
        match (&mut self.content, system) {
            // Text + Text -> two-part array, system part first.
            (MessageContent::Text(text), MessageContent::Text(system_text)) => {
                self.content = MessageContent::Array(vec![
                    MessageContentPart::Text { text: system_text },
                    MessageContentPart::Text {
                        text: text.to_string(),
                    },
                ])
            }
            // Array + Text -> prepend the system part.
            (MessageContent::Array(list), MessageContent::Text(system_text)) => {
                list.insert(0, MessageContentPart::Text { text: system_text })
            }
            // Text + Array -> append own text after the system parts.
            (MessageContent::Text(text), MessageContent::Array(mut system_list)) => {
                system_list.push(MessageContentPart::Text {
                    text: text.to_string(),
                });
                self.content = MessageContent::Array(system_list);
            }
            // Array + Array -> system parts first, then own parts.
            (MessageContent::Array(list), MessageContent::Array(mut system_list)) => {
                system_list.append(list);
                self.content = MessageContent::Array(system_list);
            }
            // ToolCalls on either side: nothing sensible to merge.
            _ => {}
        }
    }
}
|
||||
|
||||
/// Speaker of a message; serialized as lowercase snake_case ("system", ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    System,
    Assistant,
    User,
    Tool,
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl MessageRole {
|
||||
pub fn is_system(&self) -> bool {
|
||||
matches!(self, MessageRole::System)
|
||||
}
|
||||
|
||||
pub fn is_user(&self) -> bool {
|
||||
matches!(self, MessageRole::User)
|
||||
}
|
||||
|
||||
pub fn is_assistant(&self) -> bool {
|
||||
matches!(self, MessageRole::Assistant)
|
||||
}
|
||||
}
|
||||
|
||||
/// Message payload. Untagged: a bare string deserializes to `Text`, a list of
/// parts to `Array`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
    Text(String),
    Array(Vec<MessageContentPart>),
    // Note: This type is primarily for convenience and does not exist in OpenAI's API.
    ToolCalls(MessageContentToolCalls),
}
|
||||
|
||||
impl MessageContent {
    /// Renders the content back into REPL-input form for display/replay:
    /// plain text as-is, mixed text+files as a `.file ...` command, tool
    /// calls as dimmed "Call ..." lines. `resolve_url_fn` maps stored image
    /// URLs back to user-facing paths; `agent_info` is `(agent_name,
    /// agent_function_names)` used to prefix agent-owned calls.
    pub fn render_input(
        &self,
        resolve_url_fn: impl Fn(&str) -> String,
        agent_info: &Option<(String, Vec<String>)>,
    ) -> String {
        match self {
            MessageContent::Text(text) => multiline_text(text),
            MessageContent::Array(list) => {
                let (mut concated_text, mut files) = (String::new(), vec![]);
                for item in list {
                    match item {
                        MessageContentPart::Text { text } => {
                            concated_text = format!("{concated_text} {text}")
                        }
                        MessageContentPart::ImageUrl { image_url } => {
                            files.push(resolve_url_fn(&image_url.url))
                        }
                    }
                }
                if !concated_text.is_empty() {
                    concated_text = format!(" -- {}", multiline_text(&concated_text))
                }
                // e.g. ".file a.png b.png -- describe these"
                format!(".file {}{}", files.join(" "), concated_text)
            }
            MessageContent::ToolCalls(MessageContentToolCalls {
                tool_results, text, ..
            }) => {
                let mut lines = vec![];
                if !text.is_empty() {
                    lines.push(text.clone())
                }
                for tool_result in tool_results {
                    // "Call [agent] <fn> <args>"; agent name only for
                    // functions belonging to the active agent.
                    let mut parts = vec!["Call".to_string()];
                    if let Some((agent_name, functions)) = agent_info {
                        if functions.contains(&tool_result.call.name) {
                            parts.push(agent_name.clone())
                        }
                    }
                    parts.push(tool_result.call.name.clone());
                    parts.push(tool_result.call.arguments.to_string());
                    lines.push(dimmed_text(&parts.join(" ")));
                }
                lines.join("\n")
            }
        }
    }

    /// Applies a prompt-template substitution to the *first* text in the
    /// content (the whole text, or the first text part of an array; an empty
    /// array gains one part). Tool-call content is untouched.
    pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) {
        match self {
            MessageContent::Text(text) => *text = replace_fn(text),
            MessageContent::Array(list) => {
                if list.is_empty() {
                    list.push(MessageContentPart::Text {
                        text: replace_fn(""),
                    })
                } else if let Some(MessageContentPart::Text { text }) = list.get_mut(0) {
                    *text = replace_fn(text)
                }
            }
            MessageContent::ToolCalls(_) => {}
        }
    }

    /// Flattens to plain text: image parts are dropped, text parts are joined
    /// with blank lines, tool calls become empty.
    pub fn to_text(&self) -> String {
        match self {
            MessageContent::Text(text) => text.to_string(),
            MessageContent::Array(list) => {
                let mut parts = vec![];
                for item in list {
                    if let MessageContentPart::Text { text } = item {
                        parts.push(text.clone())
                    }
                }
                parts.join("\n\n")
            }
            MessageContent::ToolCalls(_) => String::new(),
        }
    }
}
|
||||
|
||||
/// One part of a multi-part message, tagged by `type` ("text"/"image_url")
/// to match OpenAI's content-part wire format.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessageContentPart {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}
|
||||
|
||||
/// Wrapper object for an image reference, mirroring OpenAI's `image_url` shape.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
    pub url: String,
}
|
||||
|
||||
/// Accumulated tool-call results attached to a message.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContentToolCalls {
    pub tool_results: Vec<ToolResult>,
    // Assistant text that accompanied the calls; cleared once calls are merged.
    pub text: String,
    // True when results from multiple rounds were merged together.
    pub sequence: bool,
}
|
||||
|
||||
impl MessageContentToolCalls {
|
||||
pub fn new(tool_results: Vec<ToolResult>, text: String) -> Self {
|
||||
Self {
|
||||
tool_results,
|
||||
text,
|
||||
sequence: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge(&mut self, tool_results: Vec<ToolResult>, _text: String) {
|
||||
self.tool_results.extend(tool_results);
|
||||
self.text.clear();
|
||||
self.sequence = true;
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjusts a message list to the model's quirks:
/// 1. prepends/merges the model's `system_prompt_prefix`, if any;
/// 2. for models that reject system messages, removes a leading system
///    message and folds its content into the following message.
pub fn patch_messages(messages: &mut Vec<Message>, model: &Model) {
    if messages.is_empty() {
        return;
    }
    if let Some(prefix) = model.system_prompt_prefix() {
        if messages[0].role.is_system() {
            // Existing system message: prefix goes in front of its content.
            messages[0].merge_system(MessageContent::Text(prefix.to_string()));
        } else {
            messages.insert(
                0,
                Message {
                    role: MessageRole::System,
                    content: MessageContent::Text(prefix.to_string()),
                },
            );
        }
    }
    if model.no_system_message() && messages[0].role.is_system() {
        let system_message = messages.remove(0);
        // Merge into the (new) first message; if the list became empty the
        // system content is silently dropped.
        if let (Some(message), system) = (messages.get_mut(0), system_message.content) {
            message.merge_system(system);
        }
    }
}
|
||||
|
||||
pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
|
||||
if messages[0].role.is_system() {
|
||||
let system_message = messages.remove(0);
|
||||
return Some(system_message.content.to_text());
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
mod access_token;
|
||||
mod common;
|
||||
mod message;
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
mod model;
|
||||
mod stream;
|
||||
|
||||
pub use crate::function::ToolCall;
|
||||
pub use common::*;
|
||||
pub use message::*;
|
||||
pub use model::*;
|
||||
pub use stream::*;
|
||||
|
||||
// Registers every built-in client: (module, config `type` tag, config struct,
// client struct). This single invocation generates ClientConfig, the client
// structs, and the module-level listing/init helpers.
register_client!(
    (openai, "openai", OpenAIConfig, OpenAIClient),
    (
        openai_compatible,
        "openai-compatible",
        OpenAICompatibleConfig,
        OpenAICompatibleClient
    ),
    (gemini, "gemini", GeminiConfig, GeminiClient),
    (claude, "claude", ClaudeConfig, ClaudeClient),
    (cohere, "cohere", CohereConfig, CohereClient),
    (
        azure_openai,
        "azure-openai",
        AzureOpenAIConfig,
        AzureOpenAIClient
    ),
    (vertexai, "vertexai", VertexAIConfig, VertexAIClient),
    (bedrock, "bedrock", BedrockConfig, BedrockClient),
);
|
||||
|
||||
/// Known OpenAI-compatible providers as `(name, api_base)` pairs. These are
/// offered as client types in addition to the built-in clients; placeholders
/// like `{ACCOUNT_ID}` are substituted during interactive setup.
pub const OPENAI_COMPATIBLE_PROVIDERS: [(&str, &str); 18] = [
    ("ai21", "https://api.ai21.com/studio/v1"),
    (
        "cloudflare",
        "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/v1",
    ),
    ("deepinfra", "https://api.deepinfra.com/v1/openai"),
    ("deepseek", "https://api.deepseek.com"),
    ("ernie", "https://qianfan.baidubce.com/v2"),
    ("github", "https://models.inference.ai.azure.com"),
    ("groq", "https://api.groq.com/openai/v1"),
    ("hunyuan", "https://api.hunyuan.cloud.tencent.com/v1"),
    ("minimax", "https://api.minimax.chat/v1"),
    ("mistral", "https://api.mistral.ai/v1"),
    ("moonshot", "https://api.moonshot.cn/v1"),
    ("openrouter", "https://openrouter.ai/api/v1"),
    ("perplexity", "https://api.perplexity.ai"),
    (
        "qianwen",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
    ),
    ("xai", "https://api.x.ai/v1"),
    ("zhipuai", "https://open.bigmodel.cn/api/paas/v4"),
    // RAG-dedicated
    ("jina", "https://api.jina.ai/v1"),
    ("voyageai", "https://api.voyageai.com/v1"),
];
|
||||
@@ -0,0 +1,407 @@
|
||||
use super::{
|
||||
list_all_models, list_client_names,
|
||||
message::{Message, MessageContent, MessageContentPart},
|
||||
ApiPatch, MessageContentToolCalls, RequestPatch,
|
||||
};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::utils::{estimate_token_length, strip_think_tag};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::fmt::Display;
|
||||
|
||||
// Fixed per-message token overhead assumed when estimating totals.
const PER_MESSAGES_TOKENS: usize = 5;
// Small constant safety margin added in guard_max_input_tokens.
const BASIS_TOKENS: usize = 2;
|
||||
|
||||
/// A concrete model bound to the client that serves it. The id is
/// `client_name:data.name`.
#[derive(Debug, Clone)]
pub struct Model {
    client_name: String,
    data: ModelData,
}
|
||||
|
||||
impl Default for Model {
|
||||
fn default() -> Self {
|
||||
Model::new("", "")
|
||||
}
|
||||
}
|
||||
|
||||
impl Model {
    pub fn new(client_name: &str, name: &str) -> Self {
        Self {
            client_name: client_name.into(),
            data: ModelData::new(name),
        }
    }

    /// Binds every config-declared `ModelData` to `client_name`.
    pub fn from_config(client_name: &str, models: &[ModelData]) -> Vec<Self> {
        models
            .iter()
            .map(|v| Model {
                client_name: client_name.to_string(),
                data: v.clone(),
            })
            .collect()
    }

    /// Resolves `model_id` ("client", "client:" or "client:model") against the
    /// configured models, requiring the given `model_type`.
    ///
    /// With an explicit model name that is not configured, a model is still
    /// synthesized on the fly when the client exists and the type allows it;
    /// with no model name, the client's first model of that type wins.
    pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result<Self> {
        let models = list_all_models(config);
        // "client:" (trailing colon, empty model) is treated like bare "client".
        let (client_name, model_name) = match model_id.split_once(':') {
            Some((client_name, model_name)) => {
                if model_name.is_empty() {
                    (client_name, None)
                } else {
                    (client_name, Some(model_name))
                }
            }
            None => (model_id, None),
        };
        match model_name {
            Some(model_name) => {
                if let Some(model) = models.iter().find(|v| v.id() == model_id) {
                    if model.model_type() == model_type {
                        return Ok((*model).clone());
                    } else {
                        bail!("Model '{model_id}' is not a {model_type} model")
                    }
                }
                // Unconfigured model name, but known client: synthesize it.
                if list_client_names(config)
                    .into_iter()
                    .any(|v| *v == client_name)
                    && model_type.can_create_from_name()
                {
                    let mut new_model = Self::new(client_name, model_name);
                    new_model.data.model_type = model_type.to_string();
                    return Ok(new_model);
                }
            }
            None => {
                // No model given: first configured model of the right type.
                if let Some(found) = models
                    .iter()
                    .find(|v| v.client_name == client_name && v.model_type() == model_type)
                {
                    return Ok((*found).clone());
                }
            }
        };
        bail!("Unknown {model_type} model '{model_id}'")
    }

    /// `client:name`, or just the client name when the model name is empty.
    pub fn id(&self) -> String {
        if self.data.name.is_empty() {
            self.client_name.to_string()
        } else {
            format!("{}:{}", self.client_name, self.data.name)
        }
    }

    pub fn client_name(&self) -> &str {
        &self.client_name
    }

    pub fn name(&self) -> &str {
        &self.data.name
    }

    /// Name sent over the wire; may differ from the display name via
    /// `real_name` in the config.
    pub fn real_name(&self) -> &str {
        self.data.real_name.as_deref().unwrap_or(&self.data.name)
    }

    /// Classifies by the `type` string prefix ("embed...", "rerank...");
    /// anything else counts as a chat model.
    pub fn model_type(&self) -> ModelType {
        if self.data.model_type.starts_with("embed") {
            ModelType::Embedding
        } else if self.data.model_type.starts_with("rerank") {
            ModelType::Reranker
        } else {
            ModelType::Chat
        }
    }

    pub fn data(&self) -> &ModelData {
        &self.data
    }

    pub fn data_mut(&mut self) -> &mut ModelData {
        &mut self.data
    }

    /// One-line summary for model listings: token limits, prices, and
    /// capability glyphs for chat; limits and price for embedding; empty for
    /// reranker.
    pub fn description(&self) -> String {
        match self.model_type() {
            ModelType::Chat => {
                let ModelData {
                    max_input_tokens,
                    max_output_tokens,
                    input_price,
                    output_price,
                    supports_vision,
                    supports_function_calling,
                    ..
                } = &self.data;
                let max_input_tokens = stringify_option_value(max_input_tokens);
                let max_output_tokens = stringify_option_value(max_output_tokens);
                let input_price = stringify_option_value(input_price);
                let output_price = stringify_option_value(output_price);
                // 👁 = vision, ⚒ = function calling.
                let mut capabilities = vec![];
                if *supports_vision {
                    capabilities.push('👁');
                };
                if *supports_function_calling {
                    capabilities.push('⚒');
                };
                let capabilities: String = capabilities
                    .into_iter()
                    .map(|v| format!("{v} "))
                    .collect::<Vec<String>>()
                    .join("");
                format!(
                    "{max_input_tokens:>8} / {max_output_tokens:>8} | {input_price:>6} / {output_price:>6} {capabilities:>6}"
                )
            }
            ModelType::Embedding => {
                let ModelData {
                    input_price,
                    max_tokens_per_chunk,
                    max_batch_size,
                    ..
                } = &self.data;
                let max_tokens = stringify_option_value(max_tokens_per_chunk);
                let max_batch = stringify_option_value(max_batch_size);
                let price = stringify_option_value(input_price);
                format!("max-tokens:{max_tokens};max-batch:{max_batch};price:{price}")
            }
            ModelType::Reranker => String::new(),
        }
    }

    pub fn patch(&self) -> Option<&Value> {
        self.data.patch.as_ref()
    }

    pub fn max_input_tokens(&self) -> Option<usize> {
        self.data.max_input_tokens
    }

    pub fn max_output_tokens(&self) -> Option<isize> {
        self.data.max_output_tokens
    }

    pub fn no_stream(&self) -> bool {
        self.data.no_stream
    }

    pub fn no_system_message(&self) -> bool {
        self.data.no_system_message
    }

    pub fn system_prompt_prefix(&self) -> Option<&str> {
        self.data.system_prompt_prefix.as_deref()
    }

    pub fn max_tokens_per_chunk(&self) -> Option<usize> {
        self.data.max_tokens_per_chunk
    }

    pub fn default_chunk_size(&self) -> usize {
        self.data.default_chunk_size.unwrap_or(1000)
    }

    pub fn max_batch_size(&self) -> Option<usize> {
        self.data.max_batch_size
    }

    /// Value for the request's max-tokens parameter: only set when the model
    /// *requires* the parameter; None means omit it.
    pub fn max_tokens_param(&self) -> Option<isize> {
        if self.data.require_max_tokens {
            self.data.max_output_tokens
        } else {
            None
        }
    }

    /// Overrides max output tokens; `None` or `0` clears the limit.
    pub fn set_max_tokens(
        &mut self,
        max_output_tokens: Option<isize>,
        require_max_tokens: bool,
    ) -> &mut Self {
        match max_output_tokens {
            None | Some(0) => self.data.max_output_tokens = None,
            _ => self.data.max_output_tokens = max_output_tokens,
        }
        self.data.require_max_tokens = require_max_tokens;
        self
    }

    /// Estimated token count of message contents only (no per-message
    /// overhead). Think-tags are stripped from non-final assistant messages;
    /// images count as 0.
    pub fn messages_tokens(&self, messages: &[Message]) -> usize {
        let messages_len = messages.len();
        messages
            .iter()
            .enumerate()
            .map(|(i, v)| match &v.content {
                MessageContent::Text(text) => {
                    if v.role.is_assistant() && i != messages_len - 1 {
                        estimate_token_length(&strip_think_tag(text))
                    } else {
                        estimate_token_length(text)
                    }
                }
                MessageContent::Array(list) => list
                    .iter()
                    .map(|v| match v {
                        MessageContentPart::Text { text } => estimate_token_length(text),
                        MessageContentPart::ImageUrl { .. } => 0,
                    })
                    .sum(),
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    // Tool results are costed by their JSON serialization.
                    estimate_token_length(text)
                        + tool_results
                            .iter()
                            .map(|v| {
                                serde_json::to_string(v)
                                    .map(|v| estimate_token_length(&v))
                                    .unwrap_or_default()
                            })
                            .sum::<usize>()
                }
            })
            .sum()
    }

    /// Content tokens plus per-message overhead; a trailing non-user message
    /// is not charged the overhead.
    pub fn total_tokens(&self, messages: &[Message]) -> usize {
        if messages.is_empty() {
            return 0;
        }
        let num_messages = messages.len();
        let message_tokens = self.messages_tokens(messages);
        if messages[num_messages - 1].role.is_user() {
            num_messages * PER_MESSAGES_TOKENS + message_tokens
        } else {
            (num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
        }
    }

    /// Errors when the estimated total (plus a small margin) reaches the
    /// model's input-token limit; no-op when no limit is configured.
    pub fn guard_max_input_tokens(&self, messages: &[Message]) -> Result<()> {
        let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
        if let Some(max_input_tokens) = self.data.max_input_tokens {
            if total_tokens >= max_input_tokens {
                bail!("Exceed max_input_tokens limit")
            }
        }
        Ok(())
    }
}
|
||||
|
||||
/// Serializable description of a model as listed in configuration files.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelData {
    /// Model identifier as the user refers to it.
    pub name: String,
    /// Model kind ("chat", "embedding", or "reranker"); defaults to "chat".
    #[serde(default = "default_model_type", rename = "type")]
    pub model_type: String,
    /// Name sent to the API when it differs from `name`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub real_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_input_tokens: Option<usize>,
    // NOTE(review): presumably price per fixed token unit — confirm units
    // against the upstream model listings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_price: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_price: Option<f64>,
    /// Raw request patch applied to calls for this model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub patch: Option<Value>,

    // chat-only properties
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<isize>,
    /// When true, requests must always carry the max-tokens parameter.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub require_max_tokens: bool,
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub supports_vision: bool,
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub supports_function_calling: bool,
    // Disable streaming responses for this model.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    no_stream: bool,
    // The model does not accept a system message.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    no_system_message: bool,
    // Prefix prepended to the system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_prompt_prefix: Option<String>,

    // embedding-only properties
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens_per_chunk: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_chunk_size: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_batch_size: Option<usize>,
}
|
||||
|
||||
impl ModelData {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
model_type: default_model_type(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The set of models offered by a single provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderModels {
    /// Provider name, e.g. as used in client configuration.
    pub provider: String,
    pub models: Vec<ModelData>,
}
|
||||
|
||||
/// Serde default for `ModelData::model_type`.
fn default_model_type() -> String {
    String::from("chat")
}
|
||||
|
||||
/// The three kinds of model the client can talk to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelType {
    Chat,
    Embedding,
    Reranker,
}
|
||||
|
||||
impl Display for ModelType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ModelType::Chat => write!(f, "chat"),
|
||||
ModelType::Embedding => write!(f, "embedding"),
|
||||
ModelType::Reranker => write!(f, "reranker"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelType {
|
||||
pub fn can_create_from_name(self) -> bool {
|
||||
match self {
|
||||
ModelType::Chat => true,
|
||||
ModelType::Embedding => false,
|
||||
ModelType::Reranker => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn api_name(self) -> &'static str {
|
||||
match self {
|
||||
ModelType::Chat => "chat_completions",
|
||||
ModelType::Embedding => "embeddings",
|
||||
ModelType::Reranker => "rerank",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_patch(self, patch: &RequestPatch) -> Option<&ApiPatch> {
|
||||
match self {
|
||||
ModelType::Chat => patch.chat_completions.as_ref(),
|
||||
ModelType::Embedding => patch.embeddings.as_ref(),
|
||||
ModelType::Reranker => patch.rerank.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Render an optional value for display, using "-" when absent.
fn stringify_option_value<T: Display>(value: &Option<T>) -> String {
    value
        .as_ref()
        .map_or_else(|| "-".to_string(), |v| v.to_string())
}
|
||||
@@ -0,0 +1,408 @@
|
||||
use super::*;
|
||||
|
||||
use crate::utils::strip_think_tag;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Default OpenAI endpoint, used when no `api_base` is configured.
const API_BASE: &str = "https://api.openai.com/v1";

/// User configuration for the OpenAI client.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct OpenAIConfig {
    pub name: Option<String>,
    pub api_key: Option<String>,
    pub api_base: Option<String>,
    /// Sent as the `OpenAI-Organization` header when present.
    pub organization_id: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl OpenAIClient {
    // `config_get_fn!` generates `get_api_key` / `get_api_base` accessors
    // for the corresponding config fields (macro defined elsewhere).
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Settings the user can be interactively prompted for.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
|
||||
|
||||
// Wire up the Client trait for OpenAI: chat completions (plain + streaming),
// embeddings, and a no-op rerank (handlers named noop_*).
impl_client_trait!(
    OpenAIClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
|
||||
|
||||
/// Build the request for OpenAI `/chat/completions`.
fn prepare_chat_completions(
    self_: &OpenAIClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    // Fall back to the public OpenAI endpoint when no api_base is configured.
    let api_base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());

    // Trim a trailing slash so configured bases like ".../v1/" don't double up.
    let url = format!("{}/chat/completions", api_base.trim_end_matches('/'));

    let body = openai_build_chat_completions_body(data, &self_.model);

    let mut request_data = RequestData::new(url, body);

    request_data.bearer_auth(api_key);
    // Optional header for users belonging to multiple organizations.
    if let Some(organization_id) = &self_.config.organization_id {
        request_data.header("OpenAI-Organization", organization_id);
    }

    Ok(request_data)
}
|
||||
|
||||
fn prepare_embeddings(self_: &OpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key()?;
|
||||
let api_base = self_
|
||||
.get_api_base()
|
||||
.unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
let url = format!("{api_base}/embeddings");
|
||||
|
||||
let body = openai_build_embeddings_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
request_data.bearer_auth(api_key);
|
||||
if let Some(organization_id) = &self_.config.organization_id {
|
||||
request_data.header("OpenAI-Organization", organization_id);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
/// Send a non-streaming chat-completions request and parse the response.
pub async fn openai_chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<ChatCompletionsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        // Convert the API error payload into an error.
        catch_error(&data, status.as_u16())?;
    }

    debug!("non-stream-data: {data}");
    openai_extract_chat_completions(&data)
}
|
||||
|
||||
/// Stream a chat-completions response, forwarding text deltas to `handler`.
///
/// Reasoning deltas ("reasoning_content"/"reasoning") are wrapped inline in
/// `<think>…</think>` tags. Tool-call fragments are accumulated across
/// deltas (name/arguments/id arrive in pieces) and flushed as a complete
/// `ToolCall` when a new call starts or on the final "[DONE]" sentinel.
pub async fn openai_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    // `call_id` identifies the tool call currently being accumulated,
    // as "id/index"; reasoning_state: 1 = inside an open <think> block.
    let mut call_id = String::new();
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    let mut reasoning_state = 0;
    let handle = |message: SseMessage| -> Result<bool> {
        if message.data == "[DONE]" {
            // Flush the last pending tool call, if any.
            if !function_name.is_empty() {
                if function_arguments.is_empty() {
                    function_arguments = String::from("{}");
                }
                let arguments: Value = function_arguments.parse().with_context(|| {
                    format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
                })?;
                handler.tool_call(ToolCall::new(
                    function_name.clone(),
                    arguments,
                    normalize_function_id(&function_id),
                ))?;
            }
            return Ok(true);
        }
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(text) = data["choices"][0]["delta"]["content"]
            .as_str()
            .filter(|v| !v.is_empty())
        {
            // Regular content closes any open <think> block.
            if reasoning_state == 1 {
                handler.text("\n</think>\n\n")?;
                reasoning_state = 0;
            }
            handler.text(text)?;
        } else if let Some(text) = data["choices"][0]["delta"]["reasoning_content"]
            .as_str()
            .or_else(|| data["choices"][0]["delta"]["reasoning"].as_str())
            .filter(|v| !v.is_empty())
        {
            if reasoning_state == 0 {
                handler.text("<think>\n")?;
                reasoning_state = 1;
            }
            handler.text(text)?;
        }
        if let (Some(function), index, id) = (
            data["choices"][0]["delta"]["tool_calls"][0]["function"].as_object(),
            data["choices"][0]["delta"]["tool_calls"][0]["index"].as_u64(),
            data["choices"][0]["delta"]["tool_calls"][0]["id"]
                .as_str()
                .filter(|v| !v.is_empty()),
        ) {
            if reasoning_state == 1 {
                handler.text("\n</think>\n\n")?;
                reasoning_state = 0;
            }
            let maybe_call_id = format!("{}/{}", id.unwrap_or_default(), index.unwrap_or_default());
            // A new (longer or different) call id means the previous call is
            // complete: flush it, then reset the accumulators.
            if maybe_call_id != call_id && maybe_call_id.len() >= call_id.len() {
                if !function_name.is_empty() {
                    if function_arguments.is_empty() {
                        function_arguments = String::from("{}");
                    }
                    let arguments: Value = function_arguments.parse().with_context(|| {
                        format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                    })?;
                    handler.tool_call(ToolCall::new(
                        function_name.clone(),
                        arguments,
                        normalize_function_id(&function_id),
                    ))?;
                }
                function_name.clear();
                function_arguments.clear();
                function_id.clear();
                call_id = maybe_call_id;
            }
            if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
                // Some providers resend the full name, others stream suffixes.
                if name.starts_with(&function_name) {
                    function_name = name.to_string();
                } else {
                    function_name.push_str(name);
                }
            }
            if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) {
                function_arguments.push_str(arguments);
            }
            if let Some(id) = id {
                function_id = id.to_string();
            }
        }
        Ok(false)
    };

    sse_stream(builder, handle).await
}
|
||||
|
||||
/// Send an embeddings request and return one vector per input text.
pub async fn openai_embeddings(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<EmbeddingsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let res_body: EmbeddingsResBody =
        serde_json::from_value(data).context("Invalid embeddings data")?;
    let output = res_body.data.into_iter().map(|v| v.embedding).collect();
    Ok(output)
}
|
||||
|
||||
/// Minimal shape of an embeddings response: just the vectors.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    data: Vec<EmbeddingsResBodyEmbedding>,
}

/// One entry of the response's `data` array.
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
    embedding: Vec<f32>,
}
|
||||
|
||||
/// Build the JSON body of an OpenAI-style `/chat/completions` request.
pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
    let ChatCompletionsData {
        messages,
        temperature,
        top_p,
        functions,
        stream,
    } = data;

    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results,
                    text: _,
                    sequence,
                }) => {
                    if !sequence {
                        // Non-sequential: one assistant message carrying all
                        // tool calls, then one "tool" message per result.
                        let tool_calls: Vec<_> = tool_results
                            .iter()
                            .map(|tool_result| {
                                json!({
                                    "id": tool_result.call.id,
                                    "type": "function",
                                    "function": {
                                        "name": tool_result.call.name,
                                        "arguments": tool_result.call.arguments.to_string(),
                                    },
                                })
                            })
                            .collect();
                        let mut messages = vec![
                            json!({ "role": MessageRole::Assistant, "tool_calls": tool_calls }),
                        ];
                        for tool_result in tool_results {
                            messages.push(json!({
                                "role": "tool",
                                "content": tool_result.output.to_string(),
                                "tool_call_id": tool_result.call.id,
                            }));
                        }
                        messages
                    } else {
                        // Sequential: alternate assistant tool_call / tool
                        // result pairs, one pair per result.
                        tool_results.into_iter().flat_map(|tool_result| {
                            vec![
                                json!({
                                    "role": MessageRole::Assistant,
                                    "tool_calls": [
                                        {
                                            "id": tool_result.call.id,
                                            "type": "function",
                                            "function": {
                                                "name": tool_result.call.name,
                                                "arguments": tool_result.call.arguments.to_string(),
                                            },
                                        }
                                    ]
                                }),
                                json!({
                                    "role": "tool",
                                    "content": tool_result.output.to_string(),
                                    "tool_call_id": tool_result.call.id,
                                })
                            ]

                        }).collect()
                    }
                }
                // Strip <think> blocks from non-final assistant turns before resending.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": strip_think_tag(&text) }
                    )]
                }
                _ => vec![json!({ "role": role, "content": content })],
            }
        })
        .collect();

    let mut body = json!({
        "model": &model.real_name(),
        "messages": messages,
    });

    if let Some(v) = model.max_tokens_param() {
        // A patch of `body.max_tokens: null` signals that this model wants
        // the newer "max_completion_tokens" field instead of "max_tokens".
        if model
            .patch()
            .and_then(|v| v.get("body").and_then(|v| v.get("max_tokens")))
            == Some(&Value::Null)
        {
            body["max_completion_tokens"] = v.into();
        } else {
            body["max_tokens"] = v.into();
        }
    }
    if let Some(v) = temperature {
        body["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["top_p"] = v.into();
    }
    if stream {
        body["stream"] = true.into();
    }
    if let Some(functions) = functions {
        body["tools"] = functions
            .iter()
            .map(|v| {
                json!({
                    "type": "function",
                    "function": v,
                })
            })
            .collect();
    }
    body
}
|
||||
|
||||
pub fn openai_build_embeddings_body(data: &EmbeddingsData, model: &Model) -> Value {
|
||||
json!({
|
||||
"input": data.texts,
|
||||
"model": model.real_name()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
let text = data["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.unwrap_or_default();
|
||||
|
||||
let reasoning = data["choices"][0]["message"]["reasoning_content"]
|
||||
.as_str()
|
||||
.or_else(|| data["choices"][0]["message"]["reasoning"].as_str())
|
||||
.unwrap_or_default()
|
||||
.trim();
|
||||
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(calls) = data["choices"][0]["message"]["tool_calls"].as_array() {
|
||||
for call in calls {
|
||||
if let (Some(name), Some(arguments), Some(id)) = (
|
||||
call["function"]["name"].as_str(),
|
||||
call["function"]["arguments"].as_str(),
|
||||
call["id"].as_str(),
|
||||
) {
|
||||
let arguments: Value = arguments.parse().with_context(|| {
|
||||
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
|
||||
})?;
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
arguments,
|
||||
Some(id.to_string()),
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if text.is_empty() && tool_calls.is_empty() {
|
||||
bail!("Invalid response data: {data}");
|
||||
}
|
||||
let text = if !reasoning.is_empty() {
|
||||
format!("<think>\n{reasoning}\n</think>\n\n{text}")
|
||||
} else {
|
||||
text.to_string()
|
||||
};
|
||||
let output = ChatCompletionsOutput {
|
||||
text,
|
||||
tool_calls,
|
||||
id: data["id"].as_str().map(|v| v.to_string()),
|
||||
input_tokens: data["usage"]["prompt_tokens"].as_u64(),
|
||||
output_tokens: data["usage"]["completion_tokens"].as_u64(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Convert an accumulated tool-call id into `Option`, mapping "" to `None`.
fn normalize_function_id(value: &str) -> Option<String> {
    (!value.is_empty()).then(|| value.to_string())
}
|
||||
@@ -0,0 +1,162 @@
|
||||
use super::openai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// User configuration for an OpenAI-compatible provider.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAICompatibleConfig {
    pub name: Option<String>,
    pub api_base: Option<String>,
    /// Optional: many compatible servers (e.g. local ones) need no key.
    pub api_key: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl OpenAICompatibleClient {
    // `config_get_fn!` generates `get_api_base` / `get_api_key` accessors
    // for the corresponding config fields (macro defined elsewhere).
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);

    /// No interactive prompts: every setting is optional here.
    pub const PROMPTS: [PromptAction<'static>; 0] = [];
}
|
||||
|
||||
// Wire up the Client trait: chat completions (plain + streaming),
// embeddings, and a generic rerank endpoint.
impl_client_trait!(
    OpenAICompatibleClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (prepare_rerank, generic_rerank),
);
|
||||
|
||||
fn prepare_chat_completions(
|
||||
self_: &OpenAICompatibleClient,
|
||||
data: ChatCompletionsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key().ok();
|
||||
let api_base = get_api_base_ext(self_)?;
|
||||
|
||||
let url = format!("{api_base}/chat/completions");
|
||||
|
||||
let body = openai_build_chat_completions_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
request_data.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_embeddings(
|
||||
self_: &OpenAICompatibleClient,
|
||||
data: &EmbeddingsData,
|
||||
) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key().ok();
|
||||
let api_base = get_api_base_ext(self_)?;
|
||||
|
||||
let url = format!("{api_base}/embeddings");
|
||||
|
||||
let body = openai_build_embeddings_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
request_data.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result<RequestData> {
|
||||
let api_key = self_.get_api_key().ok();
|
||||
let api_base = get_api_base_ext(self_)?;
|
||||
|
||||
let url = if self_.name().starts_with("ernie") {
|
||||
format!("{api_base}/rerankers")
|
||||
} else {
|
||||
format!("{api_base}/rerank")
|
||||
};
|
||||
|
||||
let body = generic_build_rerank_body(data, &self_.model);
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
request_data.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
Ok(request_data)
|
||||
}
|
||||
|
||||
fn get_api_base_ext(self_: &OpenAICompatibleClient) -> Result<String> {
|
||||
let api_base = match self_.get_api_base() {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
match OPENAI_COMPATIBLE_PROVIDERS
|
||||
.into_iter()
|
||||
.find_map(|(name, api_base)| {
|
||||
if name == self_.model.client_name() {
|
||||
Some(api_base.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
Some(v) => v,
|
||||
None => return Err(err),
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(api_base.trim_end_matches('/').to_string())
|
||||
}
|
||||
|
||||
pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
|
||||
let res = builder.send().await?;
|
||||
let status = res.status();
|
||||
let mut data: Value = res.json().await?;
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
if data.get("results").is_none() && data.get("data").is_some() {
|
||||
if let Some(data_obj) = data.as_object_mut() {
|
||||
if let Some(value) = data_obj.remove("data") {
|
||||
data_obj.insert("results".to_string(), value);
|
||||
}
|
||||
}
|
||||
}
|
||||
let res_body: GenericRerankResBody =
|
||||
serde_json::from_value(data).context("Invalid rerank data")?;
|
||||
Ok(res_body.results)
|
||||
}
|
||||
|
||||
/// Response wrapper: rerank output under "results".
#[derive(Deserialize)]
pub struct GenericRerankResBody {
    pub results: RerankOutput,
}
|
||||
|
||||
pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value {
|
||||
let RerankData {
|
||||
query,
|
||||
documents,
|
||||
top_n,
|
||||
} = data;
|
||||
|
||||
let mut body = json!({
|
||||
"model": model.real_name(),
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
});
|
||||
if model.client_name().starts_with("voyageai") {
|
||||
body["top_k"] = (*top_n).into()
|
||||
} else {
|
||||
body["top_n"] = (*top_n).into()
|
||||
}
|
||||
body
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
use super::{catch_error, ToolCall};
|
||||
use crate::utils::AbortSignal;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use futures_util::{Stream, StreamExt};
|
||||
use reqwest::RequestBuilder;
|
||||
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
|
||||
/// Bridges a streaming reply to its consumer: forwards text deltas over a
/// channel while accumulating the full text and any tool calls.
pub struct SseHandler {
    sender: UnboundedSender<SseEvent>,
    abort_signal: AbortSignal,
    // Accumulated text of the whole reply.
    buffer: String,
    tool_calls: Vec<ToolCall>,
}
|
||||
|
||||
impl SseHandler {
    /// Create a handler that forwards events to `sender` until `abort_signal` fires.
    pub fn new(sender: UnboundedSender<SseEvent>, abort_signal: AbortSignal) -> Self {
        Self {
            sender,
            abort_signal,
            buffer: String::new(),
            tool_calls: Vec::new(),
        }
    }

    /// Append `text` to the transcript buffer and forward it to the receiver.
    /// A send failure is swallowed when the stream was aborted by the user.
    pub fn text(&mut self, text: &str) -> Result<()> {
        // debug!("HandleText: {}", text);
        if text.is_empty() {
            return Ok(());
        }
        self.buffer.push_str(text);
        let ret = self
            .sender
            .send(SseEvent::Text(text.to_string()))
            .with_context(|| "Failed to send SseEvent:Text");
        if let Err(err) = ret {
            if self.abort_signal.aborted() {
                return Ok(());
            }
            return Err(err);
        }
        Ok(())
    }

    /// Signal end-of-stream to the receiver; only warns on failure.
    pub fn done(&mut self) {
        // debug!("HandleDone");
        let ret = self.sender.send(SseEvent::Done);
        if ret.is_err() {
            if self.abort_signal.aborted() {
                return;
            }
            warn!("Failed to send SseEvent:Done");
        }
    }

    /// Record a completed tool call for later retrieval.
    pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
        // debug!("HandleCall: {:?}", call);
        self.tool_calls.push(call);
        Ok(())
    }

    /// Clone of the abort signal associated with this stream.
    pub fn abort(&self) -> AbortSignal {
        self.abort_signal.clone()
    }

    /// Tool calls collected so far.
    pub fn tool_calls(&self) -> &[ToolCall] {
        &self.tool_calls
    }

    /// Consume the handler, yielding the accumulated text and tool calls.
    pub fn take(self) -> (String, Vec<ToolCall>) {
        let Self {
            buffer, tool_calls, ..
        } = self;
        (buffer, tool_calls)
    }
}
|
||||
|
||||
/// Event forwarded to the consumer of a streaming reply.
#[derive(Debug)]
pub enum SseEvent {
    /// A chunk of reply text.
    Text(String),
    /// The stream has finished.
    Done,
}

/// A raw server-sent-event message.
#[derive(Debug)]
pub struct SseMessage {
    #[allow(unused)]
    pub event: String,
    pub data: String,
}
|
||||
|
||||
/// Drive an SSE request, invoking `handle` for each message until the
/// handler returns `Ok(true)` or the stream ends; converts transport and
/// status errors into `anyhow` errors.
pub async fn sse_stream<F>(builder: RequestBuilder, mut handle: F) -> Result<()>
where
    F: FnMut(SseMessage) -> Result<bool>,
{
    let mut es = builder.eventsource()?;
    while let Some(event) = es.next().await {
        match event {
            Ok(Event::Open) => {}
            Ok(Event::Message(message)) => {
                let message = SseMessage {
                    event: message.event,
                    data: message.data,
                };
                // `Ok(true)` from the handler means "done, stop reading".
                if handle(message)? {
                    break;
                }
            }
            Err(err) => {
                match err {
                    // Normal termination, not an error.
                    EventSourceError::StreamEnded => {}
                    EventSourceError::InvalidStatusCode(status, res) => {
                        // Try to surface the API's structured error payload;
                        // fall back to the raw text.
                        let text = res.text().await?;
                        let data: Value = match text.parse() {
                            Ok(data) => data,
                            Err(_) => {
                                bail!(
                                    "Invalid response data: {text} (status: {})",
                                    status.as_u16()
                                );
                            }
                        };
                        catch_error(&data, status.as_u16())?;
                    }
                    EventSourceError::InvalidContentType(header_value, res) => {
                        let text = res.text().await?;
                        bail!(
                            "Invalid response event-stream. content-type: {}, data: {text}",
                            header_value.to_str().unwrap_or_default()
                        );
                    }
                    _ => {
                        bail!("{}", err);
                    }
                }
                es.close();
            }
        }
    }
    Ok(())
}
|
||||
|
||||
/// Consume a byte stream carrying concatenated / NDJSON / array-wrapped
/// JSON objects, calling `handle` with the text of each complete
/// top-level object.
pub async fn json_stream<S, F, E>(mut stream: S, mut handle: F) -> Result<()>
where
    S: Stream<Item = Result<bytes::Bytes, E>> + Unpin,
    F: FnMut(&str) -> Result<()>,
    E: std::error::Error,
{
    let mut parser = JsonStreamParser::default();
    // Bytes held back because they don't yet form valid UTF-8 (a chunk
    // boundary may split a multi-byte code point).
    let mut unparsed_bytes = vec![];
    while let Some(chunk_bytes) = stream.next().await {
        let chunk_bytes =
            chunk_bytes.map_err(|err| anyhow!("Failed to read json stream, {err}"))?;
        unparsed_bytes.extend(chunk_bytes);
        match std::str::from_utf8(&unparsed_bytes) {
            Ok(text) => {
                parser.process(text, &mut handle)?;
                unparsed_bytes.clear();
            }
            Err(_) => {
                // Wait for more bytes to complete the partial sequence.
                continue;
            }
        }
    }
    // Flush any trailing bytes; a UTF-8 error here is a real failure.
    if !unparsed_bytes.is_empty() {
        let text = std::str::from_utf8(&unparsed_bytes)?;
        parser.process(text, &mut handle)?;
    }

    Ok(())
}
|
||||
|
||||
/// Incremental scanner that finds complete top-level JSON objects in a
/// stream of text, tolerating objects split across chunks.
#[derive(Debug, Default)]
struct JsonStreamParser {
    // All text seen so far, as chars for safe index slicing.
    buffer: Vec<char>,
    // Index of the first not-yet-scanned char in `buffer`.
    cursor: usize,
    // Start index of the object currently being scanned, if any.
    start: Option<usize>,
    // Stack of currently open '{' / '[' delimiters.
    balances: Vec<char>,
    // Currently inside a string literal.
    quoting: bool,
    // Previous char was an unconsumed backslash escape.
    escape: bool,
}
|
||||
|
||||
impl JsonStreamParser {
    /// Feed more text; invokes `handle` with each complete top-level `{…}`
    /// object found (string contents and a top-level array wrapper are
    /// skipped over, not treated as structure).
    fn process<F>(&mut self, text: &str, handle: &mut F) -> Result<()>
    where
        F: FnMut(&str) -> Result<()>,
    {
        self.buffer.extend(text.chars());

        for i in self.cursor..self.buffer.len() {
            let ch = self.buffer[i];
            if self.quoting {
                if ch == '\\' {
                    // Toggle so a doubled backslash does not leave us escaped.
                    self.escape = !self.escape;
                } else {
                    if !self.escape && ch == '"' {
                        self.quoting = false;
                    }
                    self.escape = false;
                }
                continue;
            }
            match ch {
                '"' => {
                    self.quoting = true;
                    self.escape = false;
                }
                '{' => {
                    if self.balances.is_empty() {
                        // Start of a new top-level object.
                        self.start = Some(i);
                    }
                    self.balances.push(ch);
                }
                '[' => {
                    // Only tracked inside an object; a top-level array
                    // wrapper is intentionally ignored.
                    if self.start.is_some() {
                        self.balances.push(ch);
                    }
                }
                '}' => {
                    self.balances.pop();
                    if self.balances.is_empty() {
                        // Object closed: emit it.
                        if let Some(start) = self.start.take() {
                            let value: String = self.buffer[start..=i].iter().collect();
                            handle(&value)?;
                        }
                    }
                }
                ']' => {
                    self.balances.pop();
                }
                _ => {}
            }
        }
        self.cursor = self.buffer.len();
        Ok(())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use futures_util::stream;
    use rand::Rng;

    /// Split `text` into three chunks at two random cut points, simulating
    /// arbitrary network chunking.
    fn split_chunks(text: &str) -> Vec<Vec<u8>> {
        let mut rng = rand::rng();
        let len = text.len();
        let cut1 = rng.random_range(1..len - 1);
        let cut2 = rng.random_range(cut1 + 1..len);
        let chunk1 = text.as_bytes()[..cut1].to_vec();
        let chunk2 = text.as_bytes()[cut1..cut2].to_vec();
        let chunk3 = text.as_bytes()[cut2..].to_vec();
        vec![chunk1, chunk2, chunk3]
    }

    /// Run `json_stream` over randomly-chunked `$input` and assert the
    /// emitted objects, joined by newlines, equal `$output`.
    macro_rules! assert_json_stream {
        ($input:expr, $output:expr) => {
            let chunks: Vec<_> = split_chunks($input)
                .into_iter()
                .map(|chunk| Ok::<_, std::convert::Infallible>(Bytes::from(chunk)))
                .collect();
            let stream = stream::iter(chunks);
            let mut output = vec![];
            let ret = json_stream(stream, |data| {
                output.push(data.to_string());
                Ok(())
            })
            .await;
            assert!(ret.is_ok());
            assert_eq!($output.replace("\r\n", "\n"), output.join("\n"))
        };
    }

    #[tokio::test]
    async fn test_json_stream_ndjson() {
        let data = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
        assert_json_stream!(data, data);
    }

    #[tokio::test]
    async fn test_json_stream_array() {
        let input = r#"[
{"key": "value"},
{"key": "value2"},
{"key": "value3"},"#;
        let output = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
        assert_json_stream!(input, output);
    }
}
|
||||
@@ -0,0 +1,537 @@
|
||||
use super::access_token::*;
|
||||
use super::claude::*;
|
||||
use super::openai::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use chrono::{Duration, Utc};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::{path::PathBuf, str::FromStr};
|
||||
|
||||
/// User configuration for the Google Vertex AI client.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
    pub name: Option<String>,
    /// Google Cloud project id.
    pub project_id: Option<String>,
    /// Vertex AI location (region).
    pub location: Option<String>,
    /// Path to an application-default-credentials file.
    pub adc_file: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
|
||||
|
||||
impl VertexAIClient {
    // `config_get_fn!` generates `get_project_id` / `get_location` accessors
    // for the corresponding config fields (macro defined elsewhere).
    config_get_fn!(project_id, get_project_id);
    config_get_fn!(location, get_location);

    /// Settings the user can be interactively prompted for.
    pub const PROMPTS: [PromptAction<'static>; 2] = [
        ("project_id", "Project ID", None),
        ("location", "Location", None),
    ];
}
|
||||
|
||||
#[async_trait::async_trait]
impl Client for VertexAIClient {
    client_common_fns!();

    /// Non-streaming chat completions; dispatches on the model family
    /// (Gemini / Claude / Mistral) parsed from the model's real name.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        // Obtain/refresh the gcloud access token before building the request.
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => gemini_chat_completions(builder, model).await,
            ModelCategory::Claude => claude_chat_completions(builder, model).await,
            ModelCategory::Mistral => openai_chat_completions(builder, model).await,
        }
    }

    /// Streaming variant of [`Self::chat_completions_inner`].
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => {
                gemini_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Claude => {
                claude_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Mistral => {
                openai_chat_completions_streaming(builder, handler, model).await
            }
        }
    }

    /// Text embeddings via the Vertex AI embeddings endpoint.
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<Vec<Vec<f32>>> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let request_data = prepare_embeddings(self, data)?;
        let builder = self.request_builder(client, request_data);
        embeddings(builder, self.model()).await
    }
}
|
||||
|
||||
/// Builds the URL, JSON body and bearer auth for a Vertex AI chat request.
///
/// Endpoint path and payload shape differ per model family: Gemini uses
/// generateContent/streamGenerateContent, Claude uses streamRawPredict,
/// Mistral uses rawPredict/streamRawPredict.
fn prepare_chat_completions(
    self_: &VertexAIClient,
    data: ChatCompletionsData,
    model_category: &ModelCategory,
) -> Result<RequestData> {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    let access_token = get_access_token(self_.name())?;

    // The "global" location uses the bare API hostname; regional locations
    // use a "{location}-" prefixed hostname.
    let base_url = if location == "global" {
        format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
    } else {
        format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
    };

    let model_name = self_.model.real_name();

    let url = match model_category {
        ModelCategory::Gemini => {
            let func = match data.stream {
                true => "streamGenerateContent",
                false => "generateContent",
            };
            format!("{base_url}/google/models/{model_name}:{func}")
        }
        ModelCategory::Claude => {
            // Claude requests always go through streamRawPredict here,
            // regardless of data.stream.
            format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
        }
        ModelCategory::Mistral => {
            let func = match data.stream {
                true => "streamRawPredict",
                false => "rawPredict",
            };
            format!("{base_url}/mistralai/models/{model_name}:{func}")
        }
    };

    let body = match model_category {
        ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self_.model)?,
        ModelCategory::Claude => {
            let mut body = claude_build_chat_completions_body(data, &self_.model)?;
            // On Vertex the model is addressed via the URL, so the "model"
            // field is dropped and replaced by the anthropic_version marker.
            if let Some(body_obj) = body.as_object_mut() {
                body_obj.remove("model");
            }
            body["anthropic_version"] = "vertex-2023-10-16".into();
            body
        }
        ModelCategory::Mistral => {
            let mut body = openai_build_chat_completions_body(data, &self_.model);
            if let Some(body_obj) = body.as_object_mut() {
                // Vertex expects the model name without the "@version" suffix.
                body_obj["model"] = strip_model_version(self_.model.real_name()).into();
            }
            body
        }
    };

    let mut request_data = RequestData::new(url, body);

    request_data.bearer_auth(access_token);

    Ok(request_data)
}
|
||||
|
||||
/// Builds the URL, JSON body and bearer auth for a Vertex AI embeddings
/// request (text-embedding `:predict` endpoint).
fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<RequestData> {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    let access_token = get_access_token(self_.name())?;

    // Same host-selection rule as prepare_chat_completions: "global" uses
    // the bare hostname, other locations get a region prefix.
    let base_url = if location == "global" {
        format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
    } else {
        format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
    };
    let url = format!(
        "{base_url}/google/models/{}:predict",
        self_.model.real_name()
    );

    // One instance per input text, per the Vertex text-embedding API shape.
    let instances: Vec<_> = data.texts.iter().map(|v| json!({"content": v})).collect();

    let body = json!({
        "instances": instances,
    });

    let mut request_data = RequestData::new(url, body);

    request_data.bearer_auth(access_token);

    Ok(request_data)
}
|
||||
|
||||
/// Non-streaming chat completions for Gemini models: sends the request,
/// surfaces API errors, and extracts text/tool calls from the response.
pub async fn gemini_chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<ChatCompletionsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        // Turn the API error payload into an error result.
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    gemini_extract_chat_completions_text(&data)
}
|
||||
|
||||
/// Streaming chat completions for Gemini models.
///
/// The response body is a streamed JSON array; `json_stream` feeds each
/// element to `handle`, which forwards text parts and function calls to
/// the SSE handler and errors out on safety blocks.
pub async fn gemini_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if !status.is_success() {
        let data: Value = res.json().await?;
        catch_error(&data, status.as_u16())?;
    } else {
        let handle = |value: &str| -> Result<()> {
            let data: Value = serde_json::from_str(value)?;
            debug!("stream-data: {data}");
            if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
                for (i, part) in parts.iter().enumerate() {
                    if let Some(text) = part["text"].as_str() {
                        // Separate multiple text parts with a blank line.
                        if i > 0 {
                            handler.text("\n\n")?;
                        }
                        handler.text(text)?;
                    } else if let (Some(name), Some(args)) = (
                        part["functionCall"]["name"].as_str(),
                        part["functionCall"]["args"].as_object(),
                    ) {
                        handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?;
                    }
                }
            } else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
                .as_str()
                .or_else(|| data["candidates"][0]["finishReason"].as_str())
            {
                // No parts at all and a SAFETY block reason: abort the stream.
                bail!("Blocked due to safety")
            }

            Ok(())
        };
        json_stream(res.bytes_stream(), handle).await?;
    }
    Ok(())
}
|
||||
|
||||
/// Sends an embeddings request and flattens the Vertex AI prediction
/// response into raw `Vec<f32>` vectors, one per input text.
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let res_body: EmbeddingsResBody =
        serde_json::from_value(data).context("Invalid embeddings data")?;
    let output = res_body
        .predictions
        .into_iter()
        .map(|v| v.embeddings.values)
        .collect();
    Ok(output)
}
|
||||
|
||||
// Top-level response shape of the Vertex AI text-embedding predict call.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    predictions: Vec<EmbeddingsResBodyPrediction>,
}
|
||||
|
||||
// One prediction per input instance, wrapping the embeddings payload.
#[derive(Deserialize)]
struct EmbeddingsResBodyPrediction {
    embeddings: EmbeddingsResBodyPredictionEmbeddings,
}
|
||||
|
||||
// The embedding vector itself.
#[derive(Deserialize)]
struct EmbeddingsResBodyPredictionEmbeddings {
    values: Vec<f32>,
}
|
||||
|
||||
/// Extracts text and tool calls from a non-streaming Gemini response.
///
/// Multiple text parts are joined with blank lines. If neither text nor
/// tool calls are present, errors — distinguishing a SAFETY block from an
/// otherwise malformed payload.
fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
    let mut text_parts = vec![];
    let mut tool_calls = vec![];
    if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
        for part in parts {
            if let Some(text) = part["text"].as_str() {
                text_parts.push(text);
            }
            // A part may carry a functionCall instead of (or alongside) text.
            if let (Some(name), Some(args)) = (
                part["functionCall"]["name"].as_str(),
                part["functionCall"]["args"].as_object(),
            ) {
                tool_calls.push(ToolCall::new(name.to_string(), json!(args), None));
            }
        }
    }

    let text = text_parts.join("\n\n");
    if text.is_empty() && tool_calls.is_empty() {
        if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
            .as_str()
            .or_else(|| data["candidates"][0]["finishReason"].as_str())
        {
            bail!("Blocked due to safety")
        } else {
            bail!("Invalid response data: {data}");
        }
    }
    let output = ChatCompletionsOutput {
        text,
        tool_calls,
        id: None,
        // Token usage as reported by the API, when present.
        input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
        output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
    };
    Ok(output)
}
|
||||
|
||||
/// Converts generic chat-completions data into the Gemini request body
/// (`contents`, `systemInstruction`, `generationConfig`, `tools`).
///
/// Errors when a message references a network image URL, since only
/// inline base64 data is supported here.
pub fn gemini_build_chat_completions_body(
    data: ChatCompletionsData,
    model: &Model,
) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream: _,
    } = data;

    // Gemini carries the system prompt separately from `contents`.
    let system_message = extract_system_message(&mut messages);

    let mut network_image_urls = vec![];
    let contents: Vec<Value> = messages
        .into_iter()
        .flat_map(|message| {
            let Message { role, content } = message;
            // Gemini only knows "user" and "model" roles.
            let role = match role {
                MessageRole::User => "user",
                _ => "model",
            };
            match content {
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "parts": [{ "text": text }]
                })],
                MessageContent::Array(list) => {
                    let parts: Vec<Value> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => json!({"text": text}),
                            MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => {
                                // Only data-URL images are inlined; anything else is
                                // collected and reported as unsupported below.
                                if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
                                    json!({ "inline_data": { "mime_type": mime_type, "data": data } })
                                } else {
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            },
                        })
                        .collect();
                    vec![json!({ "role": role, "parts": parts })]
                },
                MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. }) => {
                    // Tool results expand into a model turn (the calls) followed
                    // by a function turn (the responses).
                    let model_parts: Vec<Value> = tool_results.iter().map(|tool_result| {
                        json!({
                            "functionCall": {
                                "name": tool_result.call.name,
                                "args": tool_result.call.arguments,
                            }
                        })
                    }).collect();
                    let function_parts: Vec<Value> = tool_results.into_iter().map(|tool_result| {
                        json!({
                            "functionResponse": {
                                "name": tool_result.call.name,
                                "response": {
                                    "name": tool_result.call.name,
                                    "content": tool_result.output,
                                }
                            }
                        })
                    }).collect();
                    vec![
                        json!({ "role": "model", "parts": model_parts }),
                        json!({ "role": "function", "parts": function_parts }),
                    ]
                }
            }
        })
        .collect();

    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }

    let mut body = json!({ "contents": contents, "generationConfig": {} });

    if let Some(v) = system_message {
        body["systemInstruction"] = json!({ "parts": [{"text": v }] });
    }

    if let Some(v) = model.max_tokens_param() {
        body["generationConfig"]["maxOutputTokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["generationConfig"]["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["generationConfig"]["topP"] = v.into();
    }

    if let Some(functions) = functions {
        // Gemini doesn't support functions with parameters that have empty properties, so we need to patch it.
        let function_declarations: Vec<_> = functions
            .into_iter()
            .map(|function| {
                if function.parameters.is_empty_properties() {
                    json!({
                        "name": function.name,
                        "description": function.description,
                    })
                } else {
                    json!(function)
                }
            })
            .collect();
        body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
    }

    Ok(body)
}
|
||||
|
||||
/// Model families served through Vertex AI publishers; each family has
/// its own endpoint path and payload format (see prepare_chat_completions).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
    Gemini,
    Claude,
    Mistral,
}
|
||||
|
||||
impl FromStr for ModelCategory {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
if s.starts_with("gemini") {
|
||||
Ok(ModelCategory::Gemini)
|
||||
} else if s.starts_with("claude") {
|
||||
Ok(ModelCategory::Claude)
|
||||
} else if s.starts_with("mistral") || s.starts_with("codestral") {
|
||||
Ok(ModelCategory::Mistral)
|
||||
} else {
|
||||
unsupported_model!(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensures a valid gcloud OAuth access token is cached for `client_name`,
/// fetching a fresh one via the ADC refresh-token flow when the cached
/// token is missing or expired.
pub async fn prepare_gcloud_access_token(
    client: &reqwest::Client,
    client_name: &str,
    adc_file: &Option<String>,
) -> Result<()> {
    if !is_valid_access_token(client_name) {
        let (token, expires_in) = fetch_access_token(client, adc_file)
            .await
            .with_context(|| "Failed to fetch access token")?;
        // Convert the relative expires_in (seconds) to an absolute timestamp.
        let expires_at = Utc::now()
            + Duration::try_seconds(expires_in)
                .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
        set_access_token(client_name, token, expires_at.timestamp())
    }
    Ok(())
}
|
||||
|
||||
/// Exchanges the ADC refresh token for an access token via Google's
/// OAuth token endpoint; returns (access_token, expires_in_seconds).
async fn fetch_access_token(
    client: &reqwest::Client,
    file: &Option<String>,
) -> Result<(String, i64)> {
    let credentials = load_adc(file).await?;
    let value: Value = client
        .post("https://oauth2.googleapis.com/token")
        .json(&credentials)
        .send()
        .await?
        .json()
        .await?;

    if let (Some(access_token), Some(expires_in)) =
        (value["access_token"].as_str(), value["expires_in"].as_i64())
    {
        Ok((access_token.to_string(), expires_in))
    } else if let Some(err_msg) = value["error_description"].as_str() {
        // Prefer the OAuth error description when present.
        bail!("{err_msg}")
    } else {
        bail!("Invalid response data: {value}")
    }
}
|
||||
|
||||
/// Reads application default credentials (from `file` or the standard
/// gcloud location) and shapes them into an OAuth refresh_token grant body.
async fn load_adc(file: &Option<String>) -> Result<Value> {
    let adc_file = file
        .as_ref()
        .map(PathBuf::from)
        .or_else(default_adc_file)
        .ok_or_else(|| anyhow!("No application_default_credentials.json"))?;
    let data = tokio::fs::read_to_string(adc_file).await?;
    let data: Value = serde_json::from_str(&data)?;
    if let (Some(client_id), Some(client_secret), Some(refresh_token)) = (
        data["client_id"].as_str(),
        data["client_secret"].as_str(),
        data["refresh_token"].as_str(),
    ) {
        Ok(json!({
            "client_id": client_id,
            "client_secret": client_secret,
            "refresh_token": refresh_token,
            "grant_type": "refresh_token",
        }))
    } else {
        bail!("Invalid application_default_credentials.json")
    }
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
fn default_adc_file() -> Option<PathBuf> {
|
||||
let mut path = dirs::home_dir()?;
|
||||
path.push(".config");
|
||||
path.push("gcloud");
|
||||
path.push("application_default_credentials.json");
|
||||
Some(path)
|
||||
}
|
||||
|
||||
/// Default ADC location on Windows:
/// `{config_dir}\gcloud\application_default_credentials.json`.
#[cfg(windows)]
fn default_adc_file() -> Option<PathBuf> {
    dirs::config_dir()
        .map(|dir| dir.join("gcloud").join("application_default_credentials.json"))
}
|
||||
|
||||
/// Drops a trailing "@version" suffix from a model name, e.g.
/// "codestral@2405" -> "codestral"; names without '@' pass through.
fn strip_model_version(name: &str) -> &str {
    name.split_once('@').map_or(name, |(base, _)| base)
}
|
||||
@@ -0,0 +1,570 @@
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
client::Model,
|
||||
function::{run_llm_function, Functions},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use inquire::{validator::Validation, Text};
|
||||
use std::{fs::read_to_string, path::Path};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Fixed name under which an agent's RAG store is created and loaded.
const DEFAULT_AGENT_NAME: &str = "rag";
|
||||
|
||||
/// Ordered name -> value map of resolved agent variables.
pub type AgentVariables = IndexMap<String, String>;
|
||||
|
||||
/// A runtime agent: its loaded config plus per-run state (variables,
/// dynamic instructions, functions, optional RAG store, active model).
#[derive(Debug, Clone)]
pub struct Agent {
    name: String,
    config: AgentConfig,
    // Variables shared across sessions.
    shared_variables: AgentVariables,
    // Session-scoped variables; shadow shared_variables when Some.
    session_variables: Option<AgentVariables>,
    // Instructions produced by the agent's `_instructions` function.
    shared_dynamic_instructions: Option<String>,
    session_dynamic_instructions: Option<String>,
    functions: Functions,
    rag: Option<Arc<Rag>>,
    model: Model,
}
|
||||
|
||||
impl Agent {
|
||||
pub async fn init(
|
||||
config: &GlobalConfig,
|
||||
name: &str,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
let agent_data_dir = Config::agent_data_dir(name);
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME);
|
||||
let config_path = Config::agent_config_file(name);
|
||||
let mut agent_config = if config_path.exists() {
|
||||
AgentConfig::load(&config_path)?
|
||||
} else {
|
||||
bail!("Agent config file not found at '{}'", config_path.display())
|
||||
};
|
||||
let mut functions = Functions::init_agent(name, &agent_config.global_tools)?;
|
||||
|
||||
config.write().functions.clear_mcp_meta_functions();
|
||||
let mcp_servers =
|
||||
(!agent_config.mcp_servers.is_empty()).then(|| agent_config.mcp_servers.join(","));
|
||||
let registry = config
|
||||
.write()
|
||||
.mcp_registry
|
||||
.take()
|
||||
.expect("MCP registry should be initialized");
|
||||
let new_mcp_registry =
|
||||
McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?;
|
||||
|
||||
if !new_mcp_registry.is_empty() {
|
||||
functions.append_mcp_meta_functions(new_mcp_registry.list_servers());
|
||||
}
|
||||
|
||||
config.write().mcp_registry = Some(new_mcp_registry);
|
||||
agent_config.replace_tools_placeholder(&functions);
|
||||
|
||||
agent_config.load_envs(&config.read());
|
||||
|
||||
let model = {
|
||||
let config = config.read();
|
||||
match agent_config.model_id.as_ref() {
|
||||
Some(model_id) => Model::retrieve_model(&config, model_id, ModelType::Chat)?,
|
||||
None => {
|
||||
if agent_config.temperature.is_none() {
|
||||
agent_config.temperature = config.temperature;
|
||||
}
|
||||
if agent_config.top_p.is_none() {
|
||||
agent_config.top_p = config.top_p;
|
||||
}
|
||||
config.current_model().clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let rag = if rag_path.exists() {
|
||||
Some(Arc::new(Rag::load(config, DEFAULT_AGENT_NAME, &rag_path)?))
|
||||
} else if !agent_config.documents.is_empty() && !config.read().info_flag {
|
||||
let mut ans = false;
|
||||
if *IS_STDOUT_TERMINAL {
|
||||
ans = Confirm::new("The agent has documents attached, init RAG?")
|
||||
.with_default(true)
|
||||
.prompt()?;
|
||||
}
|
||||
if ans {
|
||||
let mut document_paths = vec![];
|
||||
for path in &agent_config.documents {
|
||||
if is_url(path) {
|
||||
document_paths.push(path.to_string());
|
||||
} else if is_loader_protocol(&loaders, path) {
|
||||
let (protocol, document_path) = path
|
||||
.split_once(':')
|
||||
.with_context(|| "Invalid loader protocol path")?;
|
||||
let resolved_path = resolve_home_dir(document_path);
|
||||
let new_path = if Path::new(&resolved_path).is_relative() {
|
||||
safe_join_path(&agent_data_dir, resolved_path)
|
||||
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?
|
||||
} else {
|
||||
PathBuf::from(&resolved_path)
|
||||
};
|
||||
document_paths.push(format!("{}:{}", protocol, new_path.display()));
|
||||
} else if Path::new(&resolve_home_dir(path)).is_relative() {
|
||||
let new_path = safe_join_path(&agent_data_dir, path)
|
||||
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
|
||||
document_paths.push(new_path.display().to_string())
|
||||
} else {
|
||||
document_paths.push(path.to_string())
|
||||
}
|
||||
}
|
||||
let rag =
|
||||
Rag::init(config, "rag", &rag_path, &document_paths, abort_signal).await?;
|
||||
Some(Arc::new(rag))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
name: name.to_string(),
|
||||
config: agent_config,
|
||||
shared_variables: Default::default(),
|
||||
session_variables: None,
|
||||
shared_dynamic_instructions: None,
|
||||
session_dynamic_instructions: None,
|
||||
functions,
|
||||
rag,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn init_agent_variables(
|
||||
agent_variables: &[AgentVariable],
|
||||
no_interaction: bool,
|
||||
) -> Result<AgentVariables> {
|
||||
let mut output = IndexMap::new();
|
||||
if agent_variables.is_empty() {
|
||||
return Ok(output);
|
||||
}
|
||||
let mut printed = false;
|
||||
let mut unset_variables = vec![];
|
||||
for agent_variable in agent_variables {
|
||||
let key = agent_variable.name.clone();
|
||||
if let Some(value) = agent_variable.default.clone() {
|
||||
output.insert(key, value);
|
||||
continue;
|
||||
}
|
||||
if no_interaction {
|
||||
continue;
|
||||
}
|
||||
if *IS_STDOUT_TERMINAL {
|
||||
if !printed {
|
||||
println!("⚙ Init agent variables...");
|
||||
printed = true;
|
||||
}
|
||||
let value = Text::new(&format!(
|
||||
"{} ({}):",
|
||||
agent_variable.name, agent_variable.description
|
||||
))
|
||||
.with_validator(|input: &str| {
|
||||
if input.trim().is_empty() {
|
||||
Ok(Validation::Invalid("This field is required".into()))
|
||||
} else {
|
||||
Ok(Validation::Valid)
|
||||
}
|
||||
})
|
||||
.prompt()?;
|
||||
output.insert(key, value);
|
||||
} else {
|
||||
unset_variables.push(agent_variable)
|
||||
}
|
||||
}
|
||||
if !unset_variables.is_empty() {
|
||||
bail!(
|
||||
"The following agent variables are required:\n{}",
|
||||
unset_variables
|
||||
.iter()
|
||||
.map(|v| format!(" - {}: {}", v.name, v.description))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
)
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn export(&self) -> Result<String> {
|
||||
let mut value = json!({});
|
||||
value["name"] = json!(self.name());
|
||||
let variables = self.variables();
|
||||
if !variables.is_empty() {
|
||||
value["variables"] = serde_json::to_value(variables)?;
|
||||
}
|
||||
value["config"] = json!(self.config);
|
||||
let mut config = self.config.clone();
|
||||
config.instructions = self.interpolated_instructions();
|
||||
value["definition"] = json!(config);
|
||||
value["data_dir"] = Config::agent_data_dir(&self.name)
|
||||
.display()
|
||||
.to_string()
|
||||
.into();
|
||||
value["config_file"] = Config::agent_config_file(&self.name)
|
||||
.display()
|
||||
.to_string()
|
||||
.into();
|
||||
let data = serde_yaml::to_string(&value)?;
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
pub fn banner(&self) -> String {
|
||||
self.config.banner()
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn functions(&self) -> &Functions {
|
||||
&self.functions
|
||||
}
|
||||
|
||||
pub fn rag(&self) -> Option<Arc<Rag>> {
|
||||
self.rag.clone()
|
||||
}
|
||||
|
||||
pub fn conversation_starters(&self) -> &[String] {
|
||||
&self.config.conversation_starters
|
||||
}
|
||||
|
||||
pub fn interpolated_instructions(&self) -> String {
|
||||
let mut output = self
|
||||
.session_dynamic_instructions
|
||||
.clone()
|
||||
.or_else(|| self.shared_dynamic_instructions.clone())
|
||||
.unwrap_or_else(|| self.config.instructions.clone());
|
||||
for (k, v) in self.variables() {
|
||||
output = output.replace(&format!("{{{{{k}}}}}"), v)
|
||||
}
|
||||
interpolate_variables(&mut output);
|
||||
output
|
||||
}
|
||||
|
||||
pub fn agent_prelude(&self) -> Option<&str> {
|
||||
self.config.agent_prelude.as_deref()
|
||||
}
|
||||
|
||||
pub fn variables(&self) -> &AgentVariables {
|
||||
match &self.session_variables {
|
||||
Some(variables) => variables,
|
||||
None => &self.shared_variables,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn variable_envs(&self) -> HashMap<String, String> {
|
||||
self.variables()
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
(
|
||||
format!("LLM_AGENT_VAR_{}", normalize_env_name(k)),
|
||||
v.clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn shared_variables(&self) -> &AgentVariables {
|
||||
&self.shared_variables
|
||||
}
|
||||
|
||||
pub fn set_shared_variables(&mut self, shared_variables: AgentVariables) {
|
||||
self.shared_variables = shared_variables;
|
||||
}
|
||||
|
||||
pub fn set_session_variables(&mut self, session_variables: AgentVariables) {
|
||||
self.session_variables = Some(session_variables);
|
||||
}
|
||||
|
||||
pub fn defined_variables(&self) -> &[AgentVariable] {
|
||||
&self.config.variables
|
||||
}
|
||||
|
||||
pub fn exit_session(&mut self) {
|
||||
self.session_variables = None;
|
||||
self.session_dynamic_instructions = None;
|
||||
}
|
||||
|
||||
pub fn is_dynamic_instructions(&self) -> bool {
|
||||
self.config.dynamic_instructions
|
||||
}
|
||||
|
||||
pub fn update_shared_dynamic_instructions(&mut self, force: bool) -> Result<()> {
|
||||
if self.is_dynamic_instructions() && (force || self.shared_dynamic_instructions.is_none()) {
|
||||
self.shared_dynamic_instructions = Some(self.run_instructions_fn()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_session_dynamic_instructions(&mut self, value: Option<String>) -> Result<()> {
|
||||
if self.is_dynamic_instructions() {
|
||||
let value = match value {
|
||||
Some(v) => v,
|
||||
None => self.run_instructions_fn()?,
|
||||
};
|
||||
self.session_dynamic_instructions = Some(value);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_instructions_fn(&self) -> Result<String> {
|
||||
let value = run_llm_function(
|
||||
self.name().to_string(),
|
||||
vec!["_instructions".into(), "{}".into()],
|
||||
self.variable_envs(),
|
||||
)?;
|
||||
match value {
|
||||
Some(v) => Ok(v),
|
||||
_ => bail!("No return value from '_instructions' function"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RoleLike for Agent {
    /// Builds an unnamed Role whose prompt is the interpolated instructions,
    /// synced with this agent's model/params.
    fn to_role(&self) -> Role {
        let prompt = self.interpolated_instructions();
        let mut role = Role::new("", &prompt);
        role.sync(self);
        role
    }

    fn model(&self) -> &Model {
        &self.model
    }

    fn temperature(&self) -> Option<f64> {
        self.config.temperature
    }

    fn top_p(&self) -> Option<f64> {
        self.config.top_p
    }

    // NOTE(review): returns Some even when the list is empty (joined to "");
    // confirm callers treat Some("") and None the same.
    fn use_tools(&self) -> Option<String> {
        self.config.global_tools.clone().join(",").into()
    }

    fn use_mcp_servers(&self) -> Option<String> {
        self.config.mcp_servers.clone().join(",").into()
    }

    fn set_model(&mut self, model: Model) {
        // Pin the chosen model id back into the agent config.
        self.config.model_id = Some(model.id());
        self.model = model;
    }

    fn set_temperature(&mut self, value: Option<f64>) {
        self.config.temperature = value;
    }

    fn set_top_p(&mut self, value: Option<f64>) {
        self.config.top_p = value;
    }

    /// Parses a comma-separated tool list (trimmed, empties dropped);
    /// None clears the list.
    fn set_use_tools(&mut self, value: Option<String>) {
        match value {
            Some(tools) => {
                let tools = tools
                    .split(',')
                    .map(|v| v.trim().to_string())
                    .filter(|v| !v.is_empty())
                    .collect::<Vec<_>>();
                self.config.global_tools = tools;
            }
            None => {
                self.config.global_tools.clear();
            }
        }
    }

    /// Parses a comma-separated MCP server list (trimmed, empties dropped);
    /// None clears the list.
    fn set_use_mcp_servers(&mut self, value: Option<String>) {
        match value {
            Some(servers) => {
                let servers = servers
                    .split(',')
                    .map(|v| v.trim().to_string())
                    .filter(|v| !v.is_empty())
                    .collect::<Vec<_>>();
                self.config.mcp_servers = servers;
            }
            None => {
                self.config.mcp_servers.clear();
            }
        }
    }
}
|
||||
|
||||
/// Agent definition as loaded from the agent's YAML config file.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentConfig {
    pub name: String,
    // Serialized/deserialized as "model".
    #[serde(rename(serialize = "model", deserialize = "model"))]
    pub model_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_prelude: Option<String>,
    #[serde(default)]
    pub description: String,
    #[serde(default)]
    pub version: String,
    #[serde(default)]
    pub mcp_servers: Vec<String>,
    #[serde(default)]
    pub global_tools: Vec<String>,
    // System instructions; may contain {{variable}} and {{__tools__}} placeholders.
    #[serde(default)]
    pub instructions: String,
    // When true, instructions come from the agent's `_instructions` function.
    #[serde(default)]
    pub dynamic_instructions: bool,
    #[serde(default)]
    pub variables: Vec<AgentVariable>,
    #[serde(default)]
    pub conversation_starters: Vec<String>,
    // Documents to index into the agent's RAG store.
    #[serde(default)]
    pub documents: Vec<String>,
}
|
||||
|
||||
impl AgentConfig {
    /// Loads and parses an agent YAML config file.
    pub fn load(path: &Path) -> Result<Self> {
        let contents = read_to_string(path)
            .with_context(|| format!("Failed to read agent config file at '{}'", path.display()))?;
        let agent_config: Self = serde_yaml::from_str(&contents)
            .with_context(|| format!("Failed to load agent config at '{}'", path.display()))?;

        Ok(agent_config)
    }

    /// Applies global-config fallbacks and `<NAME>_*` environment overrides
    /// for model, temperature, top_p, agent_prelude and variables.
    fn load_envs(&mut self, config: &Config) {
        let name = &self.name;
        let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}"));

        if self.agent_prelude.is_none() {
            self.agent_prelude = config.agent_prelude.clone();
        }

        if let Some(v) = read_env_value::<String>(&with_prefix("model")) {
            self.model_id = v;
        }
        if let Some(v) = read_env_value::<f64>(&with_prefix("temperature")) {
            self.temperature = v;
        }
        if let Some(v) = read_env_value::<f64>(&with_prefix("top_p")) {
            self.top_p = v;
        }
        if let Some(v) = read_env_value::<String>(&with_prefix("agent_prelude")) {
            self.agent_prelude = v;
        }
        // Variables may be supplied as JSON in the env; invalid JSON is ignored.
        if let Ok(v) = env::var(with_prefix("variables")) {
            if let Ok(v) = serde_json::from_str(&v) {
                self.variables = v;
            }
        }
    }

    /// Renders the markdown banner shown when the agent starts.
    fn banner(&self) -> String {
        let AgentConfig {
            name,
            description,
            version,
            conversation_starters,
            ..
        } = self;
        let starters = if conversation_starters.is_empty() {
            String::new()
        } else {
            let starters = conversation_starters
                .iter()
                .map(|v| format!("- {v}"))
                .collect::<Vec<_>>()
                .join("\n");
            format!(
                r#"

## Conversation Starters
{starters}"#
            )
        };
        format!(
            r#"# {name} {version}
{description}{starters}"#
        )
    }

    /// Replaces the {{__tools__}} placeholder in the instructions with a
    /// numbered list of tools (first line of each tool's description).
    fn replace_tools_placeholder(&mut self, functions: &Functions) {
        let tools_placeholder: &str = "{{__tools__}}";
        if self.instructions.contains(tools_placeholder) {
            let tools = functions
                .declarations()
                .iter()
                .enumerate()
                .map(|(i, v)| {
                    let description = match v.description.split_once('\n') {
                        Some((v, _)) => v,
                        None => &v.description,
                    };
                    format!("{}. {}: {description}", i + 1, v.name)
                })
                .collect::<Vec<String>>()
                .join("\n");
            self.instructions = self.instructions.replace(tools_placeholder, &tools);
        }
    }
}
|
||||
|
||||
/// A user-settable variable declared by an agent's config.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentVariable {
    /// Variable name as referenced by the agent.
    pub name: String,
    /// Human-readable description shown in completions/prompts.
    pub description: String,
    /// Optional default value; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
    /// Runtime value; never read from config files (skip_deserializing),
    /// filled in after the user supplies it.
    #[serde(skip_deserializing, default)]
    pub value: String,
}
|
||||
|
||||
pub fn list_agents() -> Vec<String> {
|
||||
let agents_file = Config::config_dir().join("agents.txt");
|
||||
let contents = match read_to_string(agents_file) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return vec![],
|
||||
};
|
||||
contents
|
||||
.split('\n')
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
None
|
||||
} else {
|
||||
Some(line.to_string())
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn complete_agent_variables(agent_name: &str) -> Vec<(String, Option<String>)> {
|
||||
let config_path = Config::agent_config_file(agent_name);
|
||||
if !config_path.exists() {
|
||||
return vec![];
|
||||
}
|
||||
let Ok(config) = AgentConfig::load(&config_path) else {
|
||||
return vec![];
|
||||
};
|
||||
config
|
||||
.variables
|
||||
.iter()
|
||||
.map(|v| {
|
||||
let description = match &v.default {
|
||||
Some(default) => format!("{} [default: {default}]", v.description),
|
||||
None => v.description.clone(),
|
||||
};
|
||||
(format!("{}=", v.name), Some(description))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -0,0 +1,545 @@
|
||||
use super::*;
|
||||
|
||||
use crate::client::{
|
||||
init_client, patch_messages, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
|
||||
MessageContentPart, MessageContentToolCalls, MessageRole, Model,
|
||||
};
|
||||
use crate::function::ToolResult;
|
||||
use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use indexmap::IndexSet;
|
||||
use std::{collections::HashMap, fs::File, io::Read};
|
||||
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
|
||||
|
||||
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
|
||||
const SUMMARY_MAX_WIDTH: usize = 80;
|
||||
|
||||
/// A single user input: the text to send, any attached media, and the
/// role/session/agent context it will be evaluated under.
#[derive(Debug, Clone)]
pub struct Input {
    // Shared handle to global configuration.
    config: GlobalConfig,
    // Fully assembled text (raw text plus loaded documents).
    text: String,
    // The original (text, file-args) pair as the user typed them.
    raw: (String, Vec<String>),
    // RAG-rewritten text; when set it takes precedence over `text`.
    patched_text: Option<String>,
    // Previous assistant reply pulled in via the `%%` marker.
    last_reply: Option<String>,
    // Accumulated partial output when continuing a truncated response.
    continue_output: Option<String>,
    // True when this input re-runs a previous request.
    regenerate: bool,
    // Image attachments as data URLs or remote URLs.
    medias: Vec<String>,
    // Maps sha256(data URL) -> original file path/URL, for display.
    data_urls: HashMap<String, String>,
    // Pending tool-call results to replay into the message list.
    tool_calls: Option<MessageContentToolCalls>,
    // The resolved role this input runs under.
    role: Role,
    // Name of the RAG used to patch the text, if any.
    rag_name: Option<String>,
    // Whether the active session participates in this input.
    with_session: bool,
    // Whether the active agent participates in this input.
    with_agent: bool,
}
|
||||
|
||||
impl Input {
    /// Build an input from plain text; the effective role is the explicit
    /// one when given, otherwise extracted from the config.
    pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
        let (role, with_session, with_agent) = resolve_role(&config.read(), role);
        Self {
            config: config.clone(),
            text: text.to_string(),
            raw: (text.to_string(), vec![]),
            patched_text: None,
            last_reply: None,
            continue_output: None,
            regenerate: false,
            medias: Default::default(),
            data_urls: Default::default(),
            tool_calls: None,
            role,
            rag_name: None,
            with_session,
            with_agent,
        }
    }

    /// Build an input from text plus a list of path-like arguments
    /// (local files, URLs, backtick commands, loader protocols, `%%`).
    ///
    /// # Errors
    /// Fails when paths cannot be resolved/loaded, or when `%%` was
    /// requested but no last reply (and no other content) is available.
    pub async fn from_files(
        config: &GlobalConfig,
        raw_text: &str,
        paths: Vec<String>,
        role: Option<Role>,
    ) -> Result<Self> {
        let loaders = config.read().document_loaders.clone();
        let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
            resolve_paths(&loaders, paths)?;
        let mut last_reply = None;
        let (documents, medias, data_urls) = load_documents(
            &loaders,
            local_paths,
            remote_urls,
            external_cmds,
            protocol_paths,
        )
        .await
        .context("Failed to load files")?;
        let mut texts = vec![];
        if !raw_text.is_empty() {
            texts.push(raw_text.to_string());
        };
        if with_last_reply {
            // Prefer the stored output; fall back to the previous input's
            // own last_reply when the output was empty.
            if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
                if !output.is_empty() {
                    last_reply = Some(output.clone())
                } else if let Some(v) = input.last_reply.as_ref() {
                    last_reply = Some(v.clone());
                }
                if let Some(v) = last_reply.clone() {
                    texts.push(format!("\n{v}"));
                }
            }
            if last_reply.is_none() && documents.is_empty() && medias.is_empty() {
                bail!("No last reply found");
            }
        }
        let documents_len = documents.len();
        for (kind, path, contents) in documents {
            // A single document with no accompanying text is inlined
            // without a header banner.
            if documents_len == 1 && raw_text.is_empty() {
                texts.push(format!("\n{contents}"));
            } else {
                texts.push(format!(
                    "\n============ {kind}: {path} ============\n{contents}"
                ));
            }
        }
        let (role, with_session, with_agent) = resolve_role(&config.read(), role);
        Ok(Self {
            config: config.clone(),
            text: texts.join("\n"),
            raw: (raw_text.to_string(), raw_paths),
            patched_text: None,
            last_reply,
            continue_output: None,
            regenerate: false,
            medias,
            data_urls,
            tool_calls: Default::default(),
            role,
            rag_name: None,
            with_session,
            with_agent,
        })
    }

    /// Like [`Self::from_files`], but shows an abortable spinner while
    /// the files are loading.
    pub async fn from_files_with_spinner(
        config: &GlobalConfig,
        raw_text: &str,
        paths: Vec<String>,
        role: Option<Role>,
        abort_signal: AbortSignal,
    ) -> Result<Self> {
        abortable_run_with_spinner(
            Input::from_files(config, raw_text, paths, role),
            "Loading files",
            abort_signal,
        )
        .await
    }

    /// True when there is neither text nor media to send.
    pub fn is_empty(&self) -> bool {
        self.text.is_empty() && self.medias.is_empty()
    }

    /// Snapshot of the data-URL → original-path map.
    pub fn data_urls(&self) -> HashMap<String, String> {
        self.data_urls.clone()
    }

    /// Pending tool-call results attached to this input, if any.
    pub fn tool_calls(&self) -> &Option<MessageContentToolCalls> {
        &self.tool_calls
    }

    /// Effective text: the RAG-patched text when present, otherwise the
    /// assembled original text.
    pub fn text(&self) -> String {
        match self.patched_text.clone() {
            Some(text) => text,
            None => self.text.clone(),
        }
    }

    /// Drop the RAG-patched text, reverting to the original.
    pub fn clear_patch(&mut self) {
        self.patched_text = None;
    }

    /// Replace the assembled text outright.
    pub fn set_text(&mut self, text: String) {
        self.text = text;
    }

    /// Whether this request should stream (config enables streaming and
    /// the model supports it).
    pub fn stream(&self) -> bool {
        self.config.read().stream && !self.role().model().no_stream()
    }

    /// Partial output accumulated so far when continuing a response.
    pub fn continue_output(&self) -> Option<&str> {
        self.continue_output.as_deref()
    }

    /// Append `output` to the accumulated continuation output.
    pub fn set_continue_output(&mut self, output: &str) {
        let output = match &self.continue_output {
            Some(v) => format!("{v}{output}"),
            None => output.to_string(),
        };
        self.continue_output = Some(output);
    }

    /// True when this input re-runs a previous request.
    pub fn regenerate(&self) -> bool {
        self.regenerate
    }

    /// Mark this input as a regeneration: refresh the role from config
    /// (same-named role only) and discard stale tool calls.
    pub fn set_regenerate(&mut self) {
        let role = self.config.read().extract_role();
        if role.name() == self.role().name() {
            self.role = role;
        }
        self.regenerate = true;
        self.tool_calls = None;
    }

    /// Run the active RAG (if any) over the text and store the rewritten
    /// prompt as the patched text.
    pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
        if self.text.is_empty() {
            return Ok(());
        }
        let rag = self.config.read().rag.clone();
        if let Some(rag) = rag {
            let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
            self.patched_text = Some(result);
            self.rag_name = Some(rag.name().to_string());
        }
        Ok(())
    }

    /// Name of the RAG used for this input, if embeddings were applied.
    pub fn rag_name(&self) -> Option<&str> {
        self.rag_name.as_deref()
    }

    /// Fold a round of tool results (plus the model output that produced
    /// them) into this input, consuming and returning it.
    pub fn merge_tool_results(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
        match self.tool_calls.as_mut() {
            Some(exist_tool_results) => {
                exist_tool_results.merge(tool_results, output);
            }
            None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)),
        }
        self
    }

    /// Build a client for the model this input's role selects.
    pub fn create_client(&self) -> Result<Box<dyn Client>> {
        init_client(&self.config, Some(self.role().model().clone()))
    }

    /// Run a one-shot (non-streaming) chat completion and return the
    /// reply text with any `<think>` tag stripped.
    pub async fn fetch_chat_text(&self) -> Result<String> {
        let client = self.create_client()?;
        let text = client.chat_completions(self.clone()).await?.text;
        let text = strip_think_tag(&text).to_string();
        Ok(text)
    }

    /// Assemble the full request payload: messages (patched for the
    /// model), sampling parameters, and selected functions.
    ///
    /// # Errors
    /// Fails when the messages exceed the model's input-token limit.
    pub fn prepare_completion_data(
        &self,
        model: &Model,
        stream: bool,
    ) -> Result<ChatCompletionsData> {
        let mut messages = self.build_messages()?;
        patch_messages(&mut messages, model);
        model.guard_max_input_tokens(&messages)?;
        let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
        let functions = self.config.read().select_functions(self.role());
        if let Some(vec) = &functions {
            for def in vec {
                debug!("Function definition: {:?}", def.name);
            }
        }
        Ok(ChatCompletionsData {
            messages,
            temperature,
            top_p,
            functions,
            stream,
        })
    }

    /// Build the message list: from the session when one is active,
    /// otherwise from the role; pending tool calls are appended as an
    /// assistant message.
    pub fn build_messages(&self) -> Result<Vec<Message>> {
        let mut messages = if let Some(session) = self.session(&self.config.read().session) {
            session.build_messages(self)
        } else {
            self.role().build_messages(self)
        };
        if let Some(tool_calls) = &self.tool_calls {
            messages.push(Message::new(
                MessageRole::Assistant,
                MessageContent::ToolCalls(tool_calls.clone()),
            ))
        }
        Ok(messages)
    }

    /// Human-readable echo of the messages that would be sent.
    pub fn echo_messages(&self) -> String {
        if let Some(session) = self.session(&self.config.read().session) {
            session.echo_messages(self)
        } else {
            self.role().echo_messages(self)
        }
    }

    /// The role this input runs under.
    pub fn role(&self) -> &Role {
        &self.role
    }

    /// The active session, but only when this input participates in it.
    pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
        if self.with_session {
            session.as_ref()
        } else {
            None
        }
    }

    /// Mutable variant of [`Self::session`].
    pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
        if self.with_session {
            session.as_mut()
        } else {
            None
        }
    }

    /// Whether this input participates in the active agent.
    pub fn with_agent(&self) -> bool {
        self.with_agent
    }

    /// One-line summary of the text, control characters replaced with
    /// spaces and truncated (with "...") to a CJK-aware display width of
    /// `SUMMARY_MAX_WIDTH` columns.
    pub fn summary(&self) -> String {
        let text: String = self
            .text
            .trim()
            .chars()
            .map(|c| if c.is_control() { ' ' } else { c })
            .collect();
        if text.width_cjk() > SUMMARY_MAX_WIDTH {
            let mut sum_width = 0;
            let mut chars = vec![];
            for c in text.chars() {
                sum_width += c.width_cjk().unwrap_or(1);
                // Reserve 3 columns for the "..." ellipsis.
                if sum_width > SUMMARY_MAX_WIDTH - 3 {
                    chars.extend(['.', '.', '.']);
                    break;
                }
                chars.push(c);
            }
            chars.into_iter().collect()
        } else {
            text
        }
    }

    /// Reconstruct the REPL command line that would produce this input
    /// (`.file <paths> -- <text>` when files were attached).
    pub fn raw(&self) -> String {
        let (text, files) = &self.raw;
        let mut segments = files.to_vec();
        if !segments.is_empty() {
            segments.insert(0, ".file".into());
        }
        if !text.is_empty() {
            if !segments.is_empty() {
                segments.push("--".into());
            }
            segments.push(text.clone());
        }
        segments.join(" ")
    }

    /// Display form of the input: plain text, or `.file <medias> -- text`
    /// with data URLs mapped back to their original paths.
    pub fn render(&self) -> String {
        let text = self.text();
        if self.medias.is_empty() {
            return text;
        }
        let tail_text = if text.is_empty() {
            String::new()
        } else {
            format!(" -- {text}")
        };
        let files: Vec<String> = self
            .medias
            .iter()
            .cloned()
            .map(|url| resolve_data_url(&self.data_urls, url))
            .collect();
        format!(".file {}{}", files.join(" "), tail_text)
    }

    /// Convert to API message content: plain text, or a multi-part array
    /// of image URLs with the text (when non-empty) prepended.
    pub fn message_content(&self) -> MessageContent {
        if self.medias.is_empty() {
            MessageContent::Text(self.text())
        } else {
            let mut list: Vec<MessageContentPart> = self
                .medias
                .iter()
                .cloned()
                .map(|url| MessageContentPart::ImageUrl {
                    image_url: ImageUrl { url },
                })
                .collect();
            if !self.text.is_empty() {
                list.insert(0, MessageContentPart::Text { text: self.text() });
            }
            MessageContent::Array(list)
        }
    }
}
|
||||
|
||||
/// Determine the effective role for an input.
///
/// An explicitly supplied role bypasses both the active session and the
/// active agent; otherwise the role is extracted from the config and the
/// returned flags reflect whether a session/agent is currently active.
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
    if let Some(explicit) = role {
        return (explicit, false, false);
    }
    let with_session = config.session.is_some();
    let with_agent = config.agent.is_some();
    (config.extract_role(), with_session, with_agent)
}
|
||||
|
||||
/// Output of `resolve_paths`, in order:
/// raw paths (as typed, deduplicated), local file paths, remote URLs,
/// external commands (backticks stripped), loader-protocol paths, and
/// whether `%%` (include last reply) was requested.
type ResolvePathsOutput = (
    Vec<String>,
    Vec<String>,
    Vec<String>,
    Vec<String>,
    Vec<String>,
    bool,
);
|
||||
|
||||
fn resolve_paths(
|
||||
loaders: &HashMap<String, String>,
|
||||
paths: Vec<String>,
|
||||
) -> Result<ResolvePathsOutput> {
|
||||
let mut raw_paths = IndexSet::new();
|
||||
let mut local_paths = IndexSet::new();
|
||||
let mut remote_urls = IndexSet::new();
|
||||
let mut external_cmds = IndexSet::new();
|
||||
let mut protocol_paths = IndexSet::new();
|
||||
let mut with_last_reply = false;
|
||||
for path in paths {
|
||||
if path == "%%" {
|
||||
with_last_reply = true;
|
||||
raw_paths.insert(path);
|
||||
} else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') {
|
||||
external_cmds.insert(path[1..path.len() - 1].to_string());
|
||||
raw_paths.insert(path);
|
||||
} else if is_url(&path) {
|
||||
if path.strip_suffix("**").is_some() {
|
||||
bail!("Invalid website '{path}'");
|
||||
}
|
||||
remote_urls.insert(path.clone());
|
||||
raw_paths.insert(path);
|
||||
} else if is_loader_protocol(loaders, &path) {
|
||||
protocol_paths.insert(path.clone());
|
||||
raw_paths.insert(path);
|
||||
} else {
|
||||
let resolved_path = resolve_home_dir(&path);
|
||||
let absolute_path = to_absolute_path(&resolved_path)
|
||||
.with_context(|| format!("Invalid path '{path}'"))?;
|
||||
local_paths.insert(resolved_path);
|
||||
raw_paths.insert(absolute_path);
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
raw_paths.into_iter().collect(),
|
||||
local_paths.into_iter().collect(),
|
||||
remote_urls.into_iter().collect(),
|
||||
external_cmds.into_iter().collect(),
|
||||
protocol_paths.into_iter().collect(),
|
||||
with_last_reply,
|
||||
))
|
||||
}
|
||||
|
||||
/// Load all classified inputs into memory.
///
/// Returns `(files, medias, data_urls)` where `files` is a list of
/// `(kind, path, contents)` triples (kind is "CMD"/"FILE"/"URL"/"FROM"),
/// `medias` holds image contents (data URLs), and `data_urls` maps
/// sha256(contents) back to the original path/URL for display.
async fn load_documents(
    loaders: &HashMap<String, String>,
    local_paths: Vec<String>,
    remote_urls: Vec<String>,
    external_cmds: Vec<String>,
    protocol_paths: Vec<String>,
) -> Result<(
    Vec<(&'static str, String, String)>,
    Vec<String>,
    HashMap<String, String>,
)> {
    let mut files = vec![];
    let mut medias = vec![];
    let mut data_urls = HashMap::new();

    for cmd in external_cmds {
        // Best-effort: a failing command contributes its error text as
        // output instead of aborting the whole load.
        let output = duct::cmd(&SHELL.cmd, &[&SHELL.arg, &cmd])
            .stderr_to_stdout()
            .unchecked()
            .read()
            .unwrap_or_else(|err| err.to_string());
        files.push(("CMD", cmd, output));
    }

    let local_files = expand_glob_paths(&local_paths, true).await?;
    for file_path in local_files {
        if is_image(&file_path) {
            // Images become data URLs; remember the original path so the
            // UI can show it instead of the base64 blob.
            let contents = read_media_to_data_url(&file_path)
                .with_context(|| format!("Unable to read media '{file_path}'"))?;
            data_urls.insert(sha256(&contents), file_path);
            medias.push(contents)
        } else {
            let document = load_file(loaders, &file_path)
                .await
                .with_context(|| format!("Unable to read file '{file_path}'"))?;
            files.push(("FILE", file_path, document.contents));
        }
    }

    for file_url in remote_urls {
        let (contents, extension) = fetch_with_loaders(loaders, &file_url, true)
            .await
            .with_context(|| format!("Failed to load url '{file_url}'"))?;
        // A sentinel extension marks the fetched contents as media.
        if extension == MEDIA_URL_EXTENSION {
            data_urls.insert(sha256(&contents), file_url);
            medias.push(contents)
        } else {
            files.push(("URL", file_url, contents));
        }
    }

    for protocol_path in protocol_paths {
        let documents = load_protocol_path(loaders, &protocol_path)
            .with_context(|| format!("Failed to load from '{protocol_path}'"))?;
        files.extend(
            documents
                .into_iter()
                .map(|document| ("FROM", document.path, document.contents)),
        );
    }

    Ok((files, medias, data_urls))
}
|
||||
|
||||
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
|
||||
if data_url.starts_with("data:") {
|
||||
let hash = sha256(&data_url);
|
||||
if let Some(path) = data_urls.get(&hash) {
|
||||
return path.to_string();
|
||||
}
|
||||
data_url
|
||||
} else {
|
||||
data_url
|
||||
}
|
||||
}
|
||||
|
||||
fn is_image(path: &str) -> bool {
|
||||
get_patch_extension(path)
|
||||
.map(|v| IMAGE_EXTS.contains(&v.as_str()))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn read_media_to_data_url(image_path: &str) -> Result<String> {
|
||||
let extension = get_patch_extension(image_path).unwrap_or_default();
|
||||
let mime_type = match extension.as_str() {
|
||||
"png" => "image/png",
|
||||
"jpg" | "jpeg" => "image/jpeg",
|
||||
"webp" => "image/webp",
|
||||
"gif" => "image/gif",
|
||||
_ => bail!("Unexpected media type"),
|
||||
};
|
||||
let mut file = File::open(image_path)?;
|
||||
let mut buffer = Vec::new();
|
||||
file.read_to_end(&mut buffer)?;
|
||||
|
||||
let encoded_image = base64_encode(buffer);
|
||||
let data_url = format!("data:{mime_type};base64,{encoded_image}");
|
||||
|
||||
Ok(data_url)
|
||||
}
|
||||
+3034
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,416 @@
|
||||
use super::*;
|
||||
|
||||
use crate::client::{Message, MessageContent, MessageRole, Model};
|
||||
|
||||
use anyhow::Result;
|
||||
use fancy_regex::Regex;
|
||||
use rust_embed::Embed;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
// Names of the built-in roles wired to CLI shortcuts.
pub const SHELL_ROLE: &str = "shell";
pub const EXPLAIN_SHELL_ROLE: &str = "explain-shell";
pub const CODE_ROLE: &str = "code";
pub const CREATE_TITLE_ROLE: &str = "create-title";

/// Placeholder inside a role prompt that is replaced by the user input.
pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
|
||||
|
||||
/// Built-in role definition files (`assets/roles/*.md`) embedded at
/// compile time.
#[derive(Embed)]
#[folder = "assets/roles/"]
struct RolesAsset;

// Splits a role file into a `---`-fenced metadata header (group 1) and
// the prompt body that follows (group 2); `(?s)` lets `.` span newlines.
static RE_METADATA: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?s)-{3,}\s*(.*?)\s*-{3,}\s*(.*)").unwrap());
|
||||
|
||||
/// Common accessors for anything that can act as a role (roles,
/// sessions, agents): the selected model, sampling parameters, and
/// tool / MCP-server selection.
pub trait RoleLike {
    /// Snapshot this object as a plain [`Role`].
    fn to_role(&self) -> Role;
    /// The resolved model this role-like runs on.
    fn model(&self) -> &Model;
    /// Sampling temperature override, if set.
    fn temperature(&self) -> Option<f64>;
    /// Nucleus-sampling override, if set.
    fn top_p(&self) -> Option<f64>;
    /// Tool selection spec, if set.
    fn use_tools(&self) -> Option<String>;
    /// MCP server selection spec, if set.
    fn use_mcp_servers(&self) -> Option<String>;
    fn set_model(&mut self, model: Model);
    fn set_temperature(&mut self, value: Option<f64>);
    fn set_top_p(&mut self, value: Option<f64>);
    fn set_use_tools(&mut self, value: Option<String>);
    fn set_use_mcp_servers(&mut self, value: Option<String>);
}
|
||||
|
||||
/// A named prompt plus optional model/sampling/tool overrides, parsed
/// from a role file (`---` metadata header + prompt body).
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Role {
    // Role name; empty for derived (anonymous) roles.
    name: String,
    #[serde(default)]
    prompt: String,
    // Serialized under the key "model"; the resolved Model lives in the
    // skipped `model` field below.
    #[serde(
        rename(serialize = "model", deserialize = "model"),
        skip_serializing_if = "Option::is_none"
    )]
    model_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    use_tools: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    use_mcp_servers: Option<String>,

    // Resolved at runtime; never (de)serialized.
    #[serde(skip)]
    model: Model,
}
|
||||
|
||||
impl Role {
    /// Parse a role from its file content: an optional `---`-fenced YAML
    /// metadata header (model/temperature/top_p/use_tools/use_mcp_servers)
    /// followed by the prompt body; `{{var}}`-style variables are
    /// interpolated into the prompt.
    pub fn new(name: &str, content: &str) -> Self {
        let mut metadata = "";
        let mut prompt = content.trim();
        if let Ok(Some(caps)) = RE_METADATA.captures(content) {
            if let (Some(metadata_value), Some(prompt_value)) = (caps.get(1), caps.get(2)) {
                metadata = metadata_value.as_str().trim();
                prompt = prompt_value.as_str().trim();
            }
        }
        let mut prompt = prompt.to_string();
        interpolate_variables(&mut prompt);
        let mut role = Self {
            name: name.to_string(),
            prompt,
            ..Default::default()
        };
        if !metadata.is_empty() {
            // Unknown metadata keys and non-matching value types are
            // silently ignored.
            if let Ok(value) = serde_yaml::from_str::<Value>(metadata) {
                if let Some(value) = value.as_object() {
                    for (key, value) in value {
                        match key.as_str() {
                            "model" => role.model_id = value.as_str().map(|v| v.to_string()),
                            "temperature" => role.temperature = value.as_f64(),
                            "top_p" => role.top_p = value.as_f64(),
                            "use_tools" => role.use_tools = value.as_str().map(|v| v.to_string()),
                            "use_mcp_servers" => {
                                role.use_mcp_servers = value.as_str().map(|v| v.to_string())
                            }
                            _ => (),
                        }
                    }
                }
            }
        }
        role
    }

    /// Load a built-in role by name from the embedded assets.
    ///
    /// # Errors
    /// Fails when no embedded `{name}.md` exists.
    pub fn builtin(name: &str) -> Result<Self> {
        let content = RolesAsset::get(&format!("{name}.md"))
            .ok_or_else(|| anyhow!("Unknown role `{name}`"))?;
        // SAFETY: embedded role assets are expected to be valid UTF-8;
        // NOTE(review): prefer from_utf8 with a fallback — confirm assets
        // are guaranteed UTF-8 at build time.
        let content = unsafe { std::str::from_utf8_unchecked(&content.data) };
        Ok(Role::new(name, content))
    }

    /// Names of all built-in roles (embedded `*.md` files, suffix stripped).
    pub fn list_builtin_role_names() -> Vec<String> {
        RolesAsset::iter()
            .filter_map(|v| v.strip_suffix(".md").map(|v| v.to_string()))
            .collect()
    }

    /// All built-in roles, fully parsed.
    pub fn list_builtin_roles() -> Vec<Self> {
        RolesAsset::iter()
            .filter_map(|v| Role::builtin(&v).ok())
            .collect()
    }

    /// True when the role name carries arguments (`name#arg` form).
    pub fn has_args(&self) -> bool {
        self.name.contains('#')
    }

    /// Serialize the role back to role-file form: metadata header (only
    /// for the fields that are set) followed by the prompt.
    pub fn export(&self) -> String {
        let mut metadata = vec![];
        if let Some(model) = self.model_id() {
            metadata.push(format!("model: {model}"));
        }
        if let Some(temperature) = self.temperature() {
            metadata.push(format!("temperature: {temperature}"));
        }
        if let Some(top_p) = self.top_p() {
            metadata.push(format!("top_p: {top_p}"));
        }
        if let Some(use_tools) = self.use_tools() {
            metadata.push(format!("use_tools: {use_tools}"));
        }
        if let Some(use_mcp_servers) = self.use_mcp_servers() {
            metadata.push(format!("use_mcp_servers: {use_mcp_servers}"));
        }
        if metadata.is_empty() {
            format!("{}\n", self.prompt)
        } else if self.prompt.is_empty() {
            format!("---\n{}\n---\n", metadata.join("\n"))
        } else {
            format!("---\n{}\n---\n\n{}\n", metadata.join("\n"), self.prompt)
        }
    }

    /// Write the role to `role_path`, creating parent directories, and
    /// adopt `role_name` when saving under a different name.
    ///
    /// # Errors
    /// Fails when the parent directory cannot be created or the file
    /// cannot be written.
    pub fn save(&mut self, role_name: &str, role_path: &Path, is_repl: bool) -> Result<()> {
        ensure_parent_exists(role_path)?;

        let content = self.export();
        std::fs::write(role_path, content).with_context(|| {
            format!(
                "Failed to write role {} to {}",
                self.name,
                role_path.display()
            )
        })?;

        if is_repl {
            println!("✓ Saved role to '{}'.", role_path.display());
        }

        if role_name != self.name {
            self.name = role_name.to_string();
        }

        Ok(())
    }

    /// Copy model and override settings from another role-like object.
    pub fn sync<T: RoleLike>(&mut self, role_like: &T) {
        let model = role_like.model();
        let temperature = role_like.temperature();
        let top_p = role_like.top_p();
        let use_tools = role_like.use_tools();
        let use_mcp_servers = role_like.use_mcp_servers();
        self.batch_set(model, temperature, top_p, use_tools, use_mcp_servers);
    }

    /// Apply the given settings; `None` values leave the existing
    /// overrides untouched (the model is always set).
    pub fn batch_set(
        &mut self,
        model: &Model,
        temperature: Option<f64>,
        top_p: Option<f64>,
        use_tools: Option<String>,
        use_mcp_servers: Option<String>,
    ) {
        self.set_model(model.clone());
        if temperature.is_some() {
            self.set_temperature(temperature);
        }
        if top_p.is_some() {
            self.set_top_p(top_p);
        }
        if use_tools.is_some() {
            self.set_use_tools(use_tools);
        }
        if use_mcp_servers.is_some() {
            self.set_use_mcp_servers(use_mcp_servers);
        }
    }

    /// True for anonymous roles derived at runtime (no name).
    pub fn is_derived(&self) -> bool {
        self.name.is_empty()
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn model_id(&self) -> Option<&str> {
        self.model_id.as_deref()
    }

    pub fn prompt(&self) -> &str {
        &self.prompt
    }

    pub fn is_empty_prompt(&self) -> bool {
        self.prompt.is_empty()
    }

    /// True when the prompt embeds the user input via `__INPUT__`.
    pub fn is_embedded_prompt(&self) -> bool {
        self.prompt.contains(INPUT_PLACEHOLDER)
    }

    /// Human-readable echo of what would be sent: the rendered input
    /// substituted into / appended after the prompt as appropriate.
    pub fn echo_messages(&self, input: &Input) -> String {
        let input_markdown = input.render();
        if self.is_empty_prompt() {
            input_markdown
        } else if self.is_embedded_prompt() {
            self.prompt.replace(INPUT_PLACEHOLDER, &input_markdown)
        } else {
            format!("{}\n\n{}", self.prompt, input_markdown)
        }
    }

    /// Build the message list for this role: either a single user message
    /// (empty/embedded prompt), or a system message plus few-shot
    /// INPUT/OUTPUT examples parsed from a structured prompt, followed by
    /// the user message. A pending continuation is appended as an
    /// assistant message.
    pub fn build_messages(&self, input: &Input) -> Vec<Message> {
        let mut content = input.message_content();
        let mut messages = if self.is_empty_prompt() {
            vec![Message::new(MessageRole::User, content)]
        } else if self.is_embedded_prompt() {
            content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v));
            vec![Message::new(MessageRole::User, content)]
        } else {
            let mut messages = vec![];
            let (system, cases) = parse_structure_prompt(&self.prompt);
            if !system.is_empty() {
                messages.push(Message::new(
                    MessageRole::System,
                    MessageContent::Text(system.to_string()),
                ));
            }
            if !cases.is_empty() {
                // Each (input, output) example becomes a user/assistant pair.
                messages.extend(cases.into_iter().flat_map(|(i, o)| {
                    vec![
                        Message::new(MessageRole::User, MessageContent::Text(i.to_string())),
                        Message::new(MessageRole::Assistant, MessageContent::Text(o.to_string())),
                    ]
                }));
            }
            messages.push(Message::new(MessageRole::User, content));
            messages
        };
        if let Some(text) = input.continue_output() {
            messages.push(Message::new(
                MessageRole::Assistant,
                MessageContent::Text(text.into()),
            ));
        }
        messages
    }
}
|
||||
|
||||
/// `Role` is the canonical `RoleLike`: it stores all settings directly
/// and converts to a role by cloning itself.
impl RoleLike for Role {
    fn to_role(&self) -> Role {
        self.clone()
    }

    fn model(&self) -> &Model {
        &self.model
    }

    fn temperature(&self) -> Option<f64> {
        self.temperature
    }

    fn top_p(&self) -> Option<f64> {
        self.top_p
    }

    fn use_tools(&self) -> Option<String> {
        self.use_tools.clone()
    }

    fn use_mcp_servers(&self) -> Option<String> {
        self.use_mcp_servers.clone()
    }

    fn set_model(&mut self, model: Model) {
        // Record the new model id only when a model was already resolved;
        // NOTE(review): presumably avoids pinning an id during default
        // initialization — confirm against callers.
        if !self.model().id().is_empty() {
            self.model_id = Some(model.id().to_string());
        }
        self.model = model;
    }

    fn set_temperature(&mut self, value: Option<f64>) {
        self.temperature = value;
    }

    fn set_top_p(&mut self, value: Option<f64>) {
        self.top_p = value;
    }

    fn set_use_tools(&mut self, value: Option<String>) {
        self.use_tools = value;
    }

    fn set_use_mcp_servers(&mut self, value: Option<String>) {
        self.use_mcp_servers = value;
    }
}
|
||||
|
||||
/// Split a role prompt written in `### INPUT:` / `### OUTPUT:` form into
/// a (trimmed) system message plus `(input, output)` few-shot pairs.
///
/// The text before the first marker is the system message; subsequent
/// segments alternate input/output. A prompt whose sections do not pair
/// up cleanly is returned unchanged with no examples.
fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
    const MARKERS: [&str; 2] = ["### INPUT:", "### OUTPUT:"];
    let mut remaining = prompt;
    let mut marker_idx = 0;
    // segments[0] is the system text; the rest alternate input/output.
    let mut segments: Vec<&str> = vec![];
    loop {
        match remaining.find(MARKERS[marker_idx]) {
            Some(pos) => {
                segments.push(&remaining[..pos]);
                remaining = &remaining[pos + MARKERS[marker_idx].len()..];
                marker_idx = 1 - marker_idx;
            }
            None => {
                // A trailing empty remainder is dropped, matching the
                // "unpaired sections" fallback below.
                if !remaining.is_empty() {
                    segments.push(remaining);
                }
                break;
            }
        }
    }
    if let Some((head, tail)) = segments.split_first() {
        if !tail.is_empty() && tail.len() % 2 == 0 {
            let cases: Vec<(&str, &str)> = tail
                .chunks_exact(2)
                .map(|pair| (pair[0].trim(), pair[1].trim()))
                .collect();
            return (head.trim(), cases);
        }
    }
    (prompt, vec![])
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Full structure: system message plus one INPUT/OUTPUT example pair.
    #[test]
    fn test_parse_structure_prompt1() {
        let prompt = r#"
System message
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
        assert_eq!(
            parse_structure_prompt(prompt),
            ("System message", vec![("Input 1", "Output 1")])
        );
    }

    // No system message: text before the first marker is empty.
    #[test]
    fn test_parse_structure_prompt2() {
        let prompt = r#"
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
        assert_eq!(
            parse_structure_prompt(prompt),
            ("", vec![("Input 1", "Output 1")])
        );
    }

    // Unpaired INPUT without OUTPUT: the prompt is returned unchanged.
    #[test]
    fn test_parse_structure_prompt3() {
        let prompt = r#"
System message
### INPUT:
Input 1
"#;
        assert_eq!(parse_structure_prompt(prompt), (prompt, vec![]));
    }
}
|
||||
@@ -0,0 +1,659 @@
|
||||
use super::input::*;
|
||||
use super::*;
|
||||
|
||||
use crate::client::{Message, MessageContent, MessageRole};
|
||||
use crate::render::MarkdownRender;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use fancy_regex::Regex;
|
||||
use inquire::{validator::Validation, Confirm, Text};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::{read_to_string, write};
|
||||
use std::path::Path;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
static RE_AUTONAME_PREFIX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\d{8}T\d{6}-").unwrap());
|
||||
|
||||
/// A chat session: persisted model settings plus message history.
/// Serialized to/from YAML; fields marked `#[serde(skip)]` are runtime-only
/// state that is rebuilt when the session is loaded.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Session {
    // Persisted model identifier ("model" key in YAML); the resolved `model`
    // below is rebuilt from it on load.
    #[serde(rename(serialize = "model", deserialize = "model"))]
    model_id: String,
    // Optional sampling overrides; omitted from YAML when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    // Tool / MCP-server selection strings (semantics defined by the config
    // layer; presumably comma-separated lists — verify against Config).
    #[serde(skip_serializing_if = "Option::is_none")]
    use_tools: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    use_mcp_servers: Option<String>,
    // Tri-state save preference: Some(true)/Some(false) explicit, None = ask.
    #[serde(skip_serializing_if = "Option::is_none")]
    save_session: Option<bool>,
    // Per-session token threshold that triggers compression; falls back to
    // the global threshold when None (see `need_compress`).
    #[serde(skip_serializing_if = "Option::is_none")]
    compress_threshold: Option<usize>,

    // Role the session was started with; `role_prompt` is re-resolved on load.
    #[serde(skip_serializing_if = "Option::is_none")]
    role_name: Option<String>,
    #[serde(default, skip_serializing_if = "IndexMap::is_empty")]
    agent_variables: AgentVariables,
    #[serde(default, skip_serializing_if = "String::is_empty")]
    agent_instructions: String,

    // History moved aside by `compress()`; `messages` is the live history.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    compressed_messages: Vec<Message>,
    messages: Vec<Message>,
    // Lookup used to resolve embedded data URLs when rendering messages.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    data_urls: HashMap<String, String>,

    // ---- runtime-only state (not serialized) ----
    #[serde(skip)]
    model: Model,
    #[serde(skip)]
    role_prompt: String,
    #[serde(skip)]
    name: String,
    // Filesystem location of the session file, if it has one.
    #[serde(skip)]
    path: Option<String>,
    // True when the session has unsaved changes.
    #[serde(skip)]
    dirty: bool,
    // One-shot "save on exit" override set by the user.
    #[serde(skip)]
    save_session_this_time: bool,
    // True while a compression request is in flight.
    #[serde(skip)]
    compressing: bool,
    // Auto-naming state for temporary sessions.
    #[serde(skip)]
    autoname: Option<AutoName>,
    // Cached token count of `messages` (see `update_tokens`).
    #[serde(skip)]
    tokens: usize,
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(config: &Config, name: &str) -> Self {
|
||||
let role = config.extract_role();
|
||||
let mut session = Self {
|
||||
name: name.to_string(),
|
||||
save_session: config.save_session,
|
||||
..Default::default()
|
||||
};
|
||||
session.set_role(role);
|
||||
session.dirty = false;
|
||||
session
|
||||
}
|
||||
|
||||
    /// Load a session named `name` from the YAML file at `path`, rebuilding
    /// all runtime-only state (resolved model, role prompt, token count).
    ///
    /// Names beginning with `"_/"` denote temporary sessions: they are given
    /// the reserved temp name, lose their path, and — when the file name
    /// carries a timestamp prefix — recover their auto-generated name.
    ///
    /// # Errors
    /// Fails when the file cannot be read, is not valid session YAML, or the
    /// persisted model id cannot be resolved.
    pub fn load(config: &Config, name: &str, path: &Path) -> Result<Self> {
        let content = read_to_string(path)
            .with_context(|| format!("Failed to load session {} at {}", name, path.display()))?;
        let mut session: Self =
            serde_yaml::from_str(&content).with_context(|| format!("Invalid session {name}"))?;

        // Re-resolve the concrete model from the persisted id.
        session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?;

        if let Some(autoname) = name.strip_prefix("_/") {
            session.name = TEMP_SESSION_NAME.to_string();
            session.path = None;
            if let Ok(true) = RE_AUTONAME_PREFIX.is_match(autoname) {
                // The `\d{8}T\d{6}-` prefix is exactly 16 ASCII bytes; the
                // remainder is the previously generated auto-name.
                session.autoname = Some(AutoName::new(autoname[16..].to_string()));
            }
        } else {
            session.name = name.to_string();
            session.path = Some(path.display().to_string());
        }

        // Refresh the role prompt; silently keep the old one if the role is gone.
        if let Some(role_name) = &session.role_name {
            if let Ok(role) = config.retrieve_role(role_name) {
                session.role_prompt = role.prompt().to_string();
            }
        }

        session.update_tokens();

        Ok(session)
    }
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.messages.is_empty() && self.compressed_messages.is_empty()
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn role_name(&self) -> Option<&str> {
|
||||
self.role_name.as_deref()
|
||||
}
|
||||
|
||||
pub fn dirty(&self) -> bool {
|
||||
self.dirty
|
||||
}
|
||||
|
||||
pub fn save_session(&self) -> Option<bool> {
|
||||
self.save_session
|
||||
}
|
||||
|
||||
pub fn tokens(&self) -> usize {
|
||||
self.tokens
|
||||
}
|
||||
|
||||
pub fn update_tokens(&mut self) {
|
||||
self.tokens = self.model().total_tokens(&self.messages);
|
||||
}
|
||||
|
||||
pub fn has_user_messages(&self) -> bool {
|
||||
self.messages.iter().any(|v| v.role.is_user())
|
||||
}
|
||||
|
||||
pub fn user_messages_len(&self) -> usize {
|
||||
self.messages.iter().filter(|v| v.role.is_user()).count()
|
||||
}
|
||||
|
||||
pub fn export(&self) -> Result<String> {
|
||||
let mut data = json!({
|
||||
"path": self.path,
|
||||
"model": self.model().id(),
|
||||
});
|
||||
if let Some(temperature) = self.temperature() {
|
||||
data["temperature"] = temperature.into();
|
||||
}
|
||||
if let Some(top_p) = self.top_p() {
|
||||
data["top_p"] = top_p.into();
|
||||
}
|
||||
if let Some(use_tools) = self.use_tools() {
|
||||
data["use_tools"] = use_tools.into();
|
||||
}
|
||||
if let Some(use_mcp_servers) = self.use_mcp_servers() {
|
||||
data["use_mcp_servers"] = use_mcp_servers.into();
|
||||
}
|
||||
if let Some(save_session) = self.save_session() {
|
||||
data["save_session"] = save_session.into();
|
||||
}
|
||||
let (tokens, percent) = self.tokens_usage();
|
||||
data["total_tokens"] = tokens.into();
|
||||
if let Some(max_input_tokens) = self.model().max_input_tokens() {
|
||||
data["max_input_tokens"] = max_input_tokens.into();
|
||||
}
|
||||
if percent != 0.0 {
|
||||
data["total/max"] = format!("{percent}%").into();
|
||||
}
|
||||
data["messages"] = json!(self.messages);
|
||||
|
||||
let output = serde_yaml::to_string(&data)
|
||||
.with_context(|| format!("Unable to show info about session '{}'", &self.name))?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn render(
|
||||
&self,
|
||||
render: &mut MarkdownRender,
|
||||
agent_info: &Option<(String, Vec<String>)>,
|
||||
) -> Result<String> {
|
||||
let mut items = vec![];
|
||||
|
||||
if let Some(path) = &self.path {
|
||||
items.push(("path", path.to_string()));
|
||||
}
|
||||
|
||||
if let Some(autoname) = self.autoname() {
|
||||
items.push(("autoname", autoname.to_string()));
|
||||
}
|
||||
|
||||
items.push(("model", self.model().id()));
|
||||
|
||||
if let Some(temperature) = self.temperature() {
|
||||
items.push(("temperature", temperature.to_string()));
|
||||
}
|
||||
if let Some(top_p) = self.top_p() {
|
||||
items.push(("top_p", top_p.to_string()));
|
||||
}
|
||||
|
||||
if let Some(use_tools) = self.use_tools() {
|
||||
items.push(("use_tools", use_tools));
|
||||
}
|
||||
|
||||
if let Some(use_mcp_servers) = self.use_mcp_servers() {
|
||||
items.push(("use_mcp_servers", use_mcp_servers));
|
||||
}
|
||||
|
||||
if let Some(save_session) = self.save_session() {
|
||||
items.push(("save_session", save_session.to_string()));
|
||||
}
|
||||
|
||||
if let Some(compress_threshold) = self.compress_threshold {
|
||||
items.push(("compress_threshold", compress_threshold.to_string()));
|
||||
}
|
||||
|
||||
if let Some(max_input_tokens) = self.model().max_input_tokens() {
|
||||
items.push(("max_input_tokens", max_input_tokens.to_string()));
|
||||
}
|
||||
|
||||
let mut lines: Vec<String> = items
|
||||
.iter()
|
||||
.map(|(name, value)| format!("{name:<20}{value}"))
|
||||
.collect();
|
||||
|
||||
lines.push(String::new());
|
||||
|
||||
if !self.is_empty() {
|
||||
let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string());
|
||||
|
||||
for message in &self.messages {
|
||||
match message.role {
|
||||
MessageRole::System => {
|
||||
lines.push(
|
||||
render
|
||||
.render(&message.content.render_input(resolve_url_fn, agent_info)),
|
||||
);
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
if let MessageContent::Text(text) = &message.content {
|
||||
lines.push(render.render(text));
|
||||
}
|
||||
lines.push("".into());
|
||||
}
|
||||
MessageRole::User => {
|
||||
lines.push(format!(
|
||||
">> {}",
|
||||
message.content.render_input(resolve_url_fn, agent_info)
|
||||
));
|
||||
}
|
||||
MessageRole::Tool => {
|
||||
lines.push(message.content.render_input(resolve_url_fn, agent_info));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
|
||||
pub fn tokens_usage(&self) -> (usize, f32) {
|
||||
let tokens = self.tokens();
|
||||
let max_input_tokens = self.model().max_input_tokens().unwrap_or_default();
|
||||
let percent = if max_input_tokens == 0 {
|
||||
0.0
|
||||
} else {
|
||||
let percent = tokens as f32 / max_input_tokens as f32 * 100.0;
|
||||
(percent * 100.0).round() / 100.0
|
||||
};
|
||||
(tokens, percent)
|
||||
}
|
||||
|
||||
pub fn set_role(&mut self, role: Role) {
|
||||
self.model_id = role.model().id();
|
||||
self.temperature = role.temperature();
|
||||
self.top_p = role.top_p();
|
||||
self.use_tools = role.use_tools();
|
||||
self.use_mcp_servers = role.use_mcp_servers();
|
||||
self.model = role.model().clone();
|
||||
self.role_name = convert_option_string(role.name());
|
||||
self.role_prompt = role.prompt().to_string();
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
}
|
||||
|
||||
pub fn clear_role(&mut self) {
|
||||
self.role_name = None;
|
||||
self.role_prompt.clear();
|
||||
}
|
||||
|
||||
pub fn sync_agent(&mut self, agent: &Agent) {
|
||||
self.role_name = None;
|
||||
self.role_prompt = agent.interpolated_instructions();
|
||||
self.agent_variables = agent.variables().clone();
|
||||
self.agent_instructions = self.role_prompt.clone();
|
||||
}
|
||||
|
||||
pub fn agent_variables(&self) -> &AgentVariables {
|
||||
&self.agent_variables
|
||||
}
|
||||
|
||||
pub fn agent_instructions(&self) -> &str {
|
||||
&self.agent_instructions
|
||||
}
|
||||
|
||||
pub fn set_save_session(&mut self, value: Option<bool>) {
|
||||
if self.save_session != value {
|
||||
self.save_session = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_save_session_this_time(&mut self) {
|
||||
self.save_session_this_time = true;
|
||||
}
|
||||
|
||||
pub fn set_compress_threshold(&mut self, value: Option<usize>) {
|
||||
if self.compress_threshold != value {
|
||||
self.compress_threshold = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn need_compress(&self, global_compress_threshold: usize) -> bool {
|
||||
if self.compressing {
|
||||
return false;
|
||||
}
|
||||
let threshold = self.compress_threshold.unwrap_or(global_compress_threshold);
|
||||
if threshold < 1 {
|
||||
return false;
|
||||
}
|
||||
self.tokens() > threshold
|
||||
}
|
||||
|
||||
pub fn compressing(&self) -> bool {
|
||||
self.compressing
|
||||
}
|
||||
|
||||
pub fn set_compressing(&mut self, compressing: bool) {
|
||||
self.compressing = compressing;
|
||||
}
|
||||
|
||||
    /// Collapse the current history into a single system message containing
    /// `prompt` (typically a model-generated summary). If the history starts
    /// with a non-empty system message, that text is prepended to `prompt` so
    /// the original instructions survive compression. The replaced messages
    /// are moved into `compressed_messages`.
    pub fn compress(&mut self, mut prompt: String) {
        if let Some(system_prompt) = self.messages.first().and_then(|v| {
            if MessageRole::System == v.role {
                let content = v.content.to_text();
                if !content.is_empty() {
                    return Some(content);
                }
            }
            None
        }) {
            prompt = format!("{system_prompt}\n\n{prompt}",);
        }
        // Move the whole live history aside, then restart it with the summary.
        self.compressed_messages.append(&mut self.messages);
        self.messages.push(Message::new(
            MessageRole::System,
            MessageContent::Text(prompt),
        ));
        self.dirty = true;
        self.update_tokens();
    }
|
||||
|
||||
pub fn need_autoname(&self) -> bool {
|
||||
self.autoname.as_ref().map(|v| v.need()).unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn set_autonaming(&mut self, naming: bool) {
|
||||
if let Some(v) = self.autoname.as_mut() {
|
||||
v.naming = naming;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chat_history_for_autonaming(&self) -> Option<String> {
|
||||
self.autoname.as_ref().and_then(|v| v.chat_history.clone())
|
||||
}
|
||||
|
||||
pub fn autoname(&self) -> Option<&str> {
|
||||
self.autoname.as_ref().and_then(|v| v.name.as_deref())
|
||||
}
|
||||
|
||||
pub fn set_autoname(&mut self, value: &str) {
|
||||
let name = value
|
||||
.chars()
|
||||
.map(|v| if v.is_alphanumeric() { v } else { '-' })
|
||||
.collect();
|
||||
self.autoname = Some(AutoName::new(name));
|
||||
}
|
||||
|
||||
    /// Finalize the session on shutdown: decide whether — and under what
    /// name/path — to save it.
    ///
    /// Saving only happens when there are unsaved changes and saving is not
    /// explicitly disabled. With no stored preference the user is asked
    /// interactively (REPL only), and a temporary session is also prompted
    /// for a real name. With an explicit `Some(true)`, a temporary session is
    /// auto-saved under `<session_dir>/_/` with a timestamped (and possibly
    /// auto-named) file name.
    pub fn exit(&mut self, session_dir: &Path, is_repl: bool) -> Result<()> {
        let mut save_session = self.save_session();
        // A one-shot save request overrides the stored preference.
        if self.save_session_this_time {
            save_session = Some(true);
        }
        if self.dirty && save_session != Some(false) {
            let mut session_dir = session_dir.to_path_buf();
            let mut session_name = self.name().to_string();
            if save_session.is_none() {
                // No preference recorded: only ask when running interactively.
                if !is_repl {
                    return Ok(());
                }
                let ans = Confirm::new("Save session?").with_default(false).prompt()?;
                if !ans {
                    return Ok(());
                }
                if session_name == TEMP_SESSION_NAME {
                    // Temp sessions need a real, non-reserved, non-empty name.
                    session_name = Text::new("Session name:")
                        .with_validator(|input: &str| {
                            let input = input.trim();
                            if input.is_empty() {
                                Ok(Validation::Invalid("This name is required".into()))
                            } else if input == TEMP_SESSION_NAME {
                                Ok(Validation::Invalid("This name is reserved".into()))
                            } else {
                                Ok(Validation::Valid)
                            }
                        })
                        .prompt()?;
                }
            } else if save_session == Some(true) && session_name == TEMP_SESSION_NAME {
                // Auto-save: `<session_dir>/_/<YYYYmmddTHHMMSS>[-autoname].yaml`.
                session_dir = session_dir.join("_");
                ensure_parent_exists(&session_dir).with_context(|| {
                    format!("Failed to create directory '{}'", session_dir.display())
                })?;

                let now = chrono::Local::now();
                session_name = now.format("%Y%m%dT%H%M%S").to_string();
                if let Some(autoname) = self.autoname() {
                    session_name = format!("{session_name}-{autoname}")
                }
            }
            let session_path = session_dir.join(format!("{session_name}.yaml"));
            self.save(&session_name, &session_path, is_repl)?;
        }
        Ok(())
    }
|
||||
|
||||
pub fn save(&mut self, session_name: &str, session_path: &Path, is_repl: bool) -> Result<()> {
|
||||
ensure_parent_exists(session_path)?;
|
||||
|
||||
self.path = Some(session_path.display().to_string());
|
||||
|
||||
let content = serde_yaml::to_string(&self)
|
||||
.with_context(|| format!("Failed to serde session '{}'", self.name))?;
|
||||
write(session_path, content).with_context(|| {
|
||||
format!(
|
||||
"Failed to write session '{}' to '{}'",
|
||||
self.name,
|
||||
session_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if is_repl {
|
||||
println!("✓ Saved the session to '{}'.", session_path.display());
|
||||
}
|
||||
|
||||
if self.name() != session_name {
|
||||
self.name = session_name.to_string()
|
||||
}
|
||||
|
||||
self.dirty = false;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn guard_empty(&self) -> Result<()> {
|
||||
if !self.is_empty() {
|
||||
bail!("Cannot perform this operation because the session has messages, please `.empty session` first.");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
    /// Record one completed exchange in the session history.
    ///
    /// Three modes, driven by `input`:
    /// - continuation: append `output` to the last stored message's text;
    /// - regeneration: replace the last stored message's text with `output`;
    /// - normal: push the user message (plus an optional tool-call message)
    ///   followed by a new assistant message containing `output`.
    ///
    /// Always marks the session dirty and refreshes the cached token count.
    pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
        if input.continue_output().is_some() {
            // Continuation: extend the previous assistant text in place.
            if let Some(message) = self.messages.last_mut() {
                if let MessageContent::Text(text) = &mut message.content {
                    *text = format!("{text}{output}");
                }
            }
        } else if input.regenerate() {
            // Regeneration: overwrite the previous assistant text.
            if let Some(message) = self.messages.last_mut() {
                if let MessageContent::Text(text) = &mut message.content {
                    *text = output.to_string();
                }
            }
        } else {
            if self.messages.is_empty() {
                // First exchange of an auto-saved temp session: seed the
                // auto-naming machinery with the opening transcript.
                if self.name == TEMP_SESSION_NAME && self.save_session == Some(true) {
                    let raw_input = input.raw();
                    let chat_history = format!("USER: {raw_input}\nASSISTANT: {output}\n");
                    self.autoname = Some(AutoName::new_from_chat_history(chat_history));
                }
                // Let the role expand the first input into its message form.
                self.messages.extend(input.role().build_messages(input));
            } else {
                self.messages
                    .push(Message::new(MessageRole::User, input.message_content()));
            }
            self.data_urls.extend(input.data_urls());
            if let Some(tool_calls) = input.tool_calls() {
                self.messages.push(Message::new(
                    MessageRole::Tool,
                    MessageContent::ToolCalls(tool_calls.clone()),
                ))
            }
            self.messages.push(Message::new(
                MessageRole::Assistant,
                MessageContent::Text(output.to_string()),
            ));
        }
        self.dirty = true;
        self.update_tokens();
        Ok(())
    }
|
||||
|
||||
pub fn clear_messages(&mut self) {
|
||||
self.messages.clear();
|
||||
self.compressed_messages.clear();
|
||||
self.data_urls.clear();
|
||||
self.autoname = None;
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
}
|
||||
|
||||
pub fn echo_messages(&self, input: &Input) -> String {
|
||||
let messages = self.build_messages(input);
|
||||
serde_yaml::to_string(&messages).unwrap_or_else(|_| "Unable to echo message".into())
|
||||
}
|
||||
|
||||
    /// Assemble the message list to send to the model for `input`.
    ///
    /// - continuation: the stored history as-is;
    /// - regeneration: the history with trailing non-user messages popped, so
    ///   the model re-answers the last user message;
    /// - otherwise: start from the history (or from the role's expansion of
    ///   `input` when empty), optionally re-attach the tail of the compressed
    ///   history right after a fresh summary, then append the new user message.
    pub fn build_messages(&self, input: &Input) -> Vec<Message> {
        let mut messages = self.messages.clone();
        if input.continue_output().is_some() {
            return messages;
        } else if input.regenerate() {
            // Drop trailing assistant/tool messages back to the last user turn.
            while let Some(last) = messages.last() {
                if !last.role.is_user() {
                    messages.pop();
                } else {
                    break;
                }
            }
            return messages;
        }
        let mut need_add_msg = true;
        let len = messages.len();
        if len == 0 {
            // Empty history: the role builds the opening messages from input.
            messages = input.role().build_messages(input);
            need_add_msg = false;
        } else if len == 1 && self.compressed_messages.len() >= 2 {
            // Just-compressed history (single summary message): replay the
            // compressed tail from its last user message for continuity.
            if let Some(index) = self
                .compressed_messages
                .iter()
                .rposition(|v| v.role == MessageRole::User)
            {
                messages.extend(self.compressed_messages[index..].to_vec());
            }
        }
        if need_add_msg {
            messages.push(Message::new(MessageRole::User, input.message_content()));
        }
        messages
    }
|
||||
}
|
||||
|
||||
impl RoleLike for Session {
|
||||
fn to_role(&self) -> Role {
|
||||
let role_name = self.role_name.as_deref().unwrap_or_default();
|
||||
let mut role = Role::new(role_name, &self.role_prompt);
|
||||
role.sync(self);
|
||||
role
|
||||
}
|
||||
|
||||
fn model(&self) -> &Model {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn temperature(&self) -> Option<f64> {
|
||||
self.temperature
|
||||
}
|
||||
|
||||
fn top_p(&self) -> Option<f64> {
|
||||
self.top_p
|
||||
}
|
||||
|
||||
fn use_tools(&self) -> Option<String> {
|
||||
self.use_tools.clone()
|
||||
}
|
||||
|
||||
fn use_mcp_servers(&self) -> Option<String> {
|
||||
self.use_mcp_servers.clone()
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: Model) {
|
||||
if self.model().id() != model.id() {
|
||||
self.model_id = model.id();
|
||||
self.model = model;
|
||||
self.dirty = true;
|
||||
self.update_tokens();
|
||||
}
|
||||
}
|
||||
|
||||
fn set_temperature(&mut self, value: Option<f64>) {
|
||||
if self.temperature != value {
|
||||
self.temperature = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_top_p(&mut self, value: Option<f64>) {
|
||||
if self.top_p != value {
|
||||
self.top_p = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_use_tools(&mut self, value: Option<String>) {
|
||||
if self.use_tools != value {
|
||||
self.use_tools = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_use_mcp_servers(&mut self, value: Option<String>) {
|
||||
if self.use_mcp_servers != value {
|
||||
self.use_mcp_servers = value;
|
||||
self.dirty = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State machine for auto-naming a temporary session: it is seeded with the
/// opening chat transcript, a name is requested at most once (`naming` guards
/// in-flight requests), and the final name ends up in `name`.
#[derive(Debug, Clone, Default)]
struct AutoName {
    naming: bool,
    chat_history: Option<String>,
    name: Option<String>,
}

impl AutoName {
    /// An AutoName that already carries its final name.
    pub fn new(name: String) -> Self {
        let mut state = Self::default();
        state.name = Some(name);
        state
    }

    /// An AutoName awaiting generation, seeded with the chat transcript.
    pub fn new_from_chat_history(chat_history: String) -> Self {
        let mut state = Self::default();
        state.chat_history = Some(chat_history);
        state
    }

    /// True when a name still has to be generated: we have source material,
    /// no name yet, and no generation request currently running.
    pub fn need(&self) -> bool {
        self.name.is_none() && self.chat_history.is_some() && !self.naming
    }
}
|
||||
+825
@@ -0,0 +1,825 @@
|
||||
use crate::{
|
||||
config::{Agent, Config, GlobalConfig},
|
||||
utils::*,
|
||||
};
|
||||
|
||||
use crate::mcp::{MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX};
|
||||
use crate::parsers::{bash, python};
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use indexmap::IndexMap;
|
||||
use indoc::formatdoc;
|
||||
use rust_embed::Embed;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::ffi::OsStr;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
env, fs, io,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use strum_macros::AsRefStr;
|
||||
|
||||
#[derive(Embed)]
|
||||
#[folder = "assets/functions/"]
|
||||
struct FunctionAsset;
|
||||
|
||||
#[cfg(windows)]
|
||||
const PATH_SEP: &str = ";";
|
||||
#[cfg(not(windows))]
|
||||
const PATH_SEP: &str = ":";
|
||||
|
||||
#[derive(AsRefStr)]
|
||||
enum BinaryType {
|
||||
Tool,
|
||||
Agent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, AsRefStr)]
|
||||
enum Language {
|
||||
Bash,
|
||||
Python,
|
||||
Javascript,
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
impl From<&String> for Language {
|
||||
fn from(s: &String) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"sh" => Language::Bash,
|
||||
"py" => Language::Python,
|
||||
"js" => Language::Javascript,
|
||||
_ => Language::Unsupported,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(not(windows), expect(dead_code))]
|
||||
impl Language {
|
||||
fn to_cmd(self) -> &'static str {
|
||||
match self {
|
||||
Language::Bash => "bash",
|
||||
Language::Python => "python",
|
||||
Language::Javascript => "node",
|
||||
Language::Unsupported => "sh",
|
||||
}
|
||||
}
|
||||
|
||||
fn to_extension(self) -> &'static str {
|
||||
match self {
|
||||
Language::Bash => "sh",
|
||||
Language::Python => "py",
|
||||
Language::Javascript => "js",
|
||||
_ => "sh",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a batch of model-requested tool calls and collect their results.
///
/// Duplicate calls are removed first; if deduplication empties the batch the
/// model is assumed to be stuck in a call loop and the request is aborted.
/// Null results are replaced with the literal `"DONE"`, and when *every* call
/// returned null the whole result set is dropped (nothing useful to feed back).
pub async fn eval_tool_calls(
    config: &GlobalConfig,
    mut calls: Vec<ToolCall>,
) -> Result<Vec<ToolResult>> {
    let mut output = vec![];
    if calls.is_empty() {
        return Ok(output);
    }
    calls = ToolCall::dedup(calls);
    if calls.is_empty() {
        bail!("The request was aborted because an infinite loop of function calls was detected.")
    }
    let mut is_all_null = true;
    for call in calls {
        let mut result = call.eval(config).await?;
        if result.is_null() {
            result = json!("DONE");
        } else {
            is_all_null = false;
        }
        output.push(ToolResult::new(call, result));
    }
    if is_all_null {
        output = vec![];
    }
    Ok(output)
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolResult {
|
||||
pub call: ToolCall,
|
||||
pub output: Value,
|
||||
}
|
||||
|
||||
impl ToolResult {
|
||||
pub fn new(call: ToolCall, output: Value) -> Self {
|
||||
Self { call, output }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Functions {
|
||||
declarations: Vec<FunctionDeclaration>,
|
||||
}
|
||||
|
||||
impl Functions {
|
||||
pub fn init() -> Result<Self> {
|
||||
info!(
|
||||
"Initializing global functions from {}",
|
||||
Config::global_tools_file().display()
|
||||
);
|
||||
|
||||
let declarations = Self {
|
||||
declarations: Self::build_global_tool_declarations_from_path(
|
||||
&Config::global_tools_file(),
|
||||
)?,
|
||||
};
|
||||
|
||||
info!(
|
||||
"Building global function binaries in {}",
|
||||
Config::functions_bin_dir().display()
|
||||
);
|
||||
Self::build_global_function_binaries_from_path(Config::global_tools_file())?;
|
||||
|
||||
Ok(declarations)
|
||||
}
|
||||
|
||||
pub fn init_agent(name: &str, global_tools: &[String]) -> Result<Self> {
|
||||
let global_tools_declarations = if !global_tools.is_empty() {
|
||||
let enabled_tools = global_tools.join("\n");
|
||||
info!("Loading global tools for agent: {name}: {enabled_tools}");
|
||||
let tools_declarations = Self::build_global_tool_declarations(&enabled_tools)?;
|
||||
|
||||
info!(
|
||||
"Building global function binaries required by agent: {name} in {}",
|
||||
Config::functions_bin_dir().display()
|
||||
);
|
||||
Self::build_global_function_binaries(&enabled_tools)?;
|
||||
tools_declarations
|
||||
} else {
|
||||
debug!("No global tools found for agent: {}", name);
|
||||
Vec::new()
|
||||
};
|
||||
let agent_script_declarations = match Config::agent_functions_file(name) {
|
||||
Ok(path) if path.exists() => {
|
||||
info!(
|
||||
"Loading functions script for agent: {name} from {}",
|
||||
path.display()
|
||||
);
|
||||
let script_declarations = Self::generate_declarations(&path)?;
|
||||
debug!("agent_declarations: {:#?}", script_declarations);
|
||||
|
||||
info!(
|
||||
"Building function binary for agent: {name} in {}",
|
||||
Config::agent_bin_dir(name).display()
|
||||
);
|
||||
Self::build_agent_tool_binaries(name)?;
|
||||
script_declarations
|
||||
}
|
||||
_ => {
|
||||
debug!("No functions script found for agent: {}", name);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
let declarations = [global_tools_declarations, agent_script_declarations].concat();
|
||||
|
||||
Ok(Self { declarations })
|
||||
}
|
||||
|
||||
pub fn find(&self, name: &str) -> Option<&FunctionDeclaration> {
|
||||
self.declarations.iter().find(|v| v.name == name)
|
||||
}
|
||||
|
||||
pub fn contains(&self, name: &str) -> bool {
|
||||
self.declarations.iter().any(|v| v.name == name)
|
||||
}
|
||||
|
||||
pub fn declarations(&self) -> &[FunctionDeclaration] {
|
||||
&self.declarations
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.declarations.is_empty()
|
||||
}
|
||||
|
||||
pub fn has_mcp_functions(&self) -> bool {
|
||||
self.declarations.iter().any(|d| {
|
||||
d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX)
|
||||
|| d.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear_mcp_meta_functions(&mut self) {
|
||||
self.declarations.retain(|d| {
|
||||
!d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX)
|
||||
&& !d.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX)
|
||||
});
|
||||
}
|
||||
|
||||
pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec<String>) {
|
||||
let mut invoke_function_properties = IndexMap::new();
|
||||
invoke_function_properties.insert(
|
||||
"server".to_string(),
|
||||
JsonSchema {
|
||||
type_value: Some("string".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
invoke_function_properties.insert(
|
||||
"tool".to_string(),
|
||||
JsonSchema {
|
||||
type_value: Some("string".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
invoke_function_properties.insert(
|
||||
"arguments".to_string(),
|
||||
JsonSchema {
|
||||
type_value: Some("object".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
for server in mcp_servers {
|
||||
let invoke_function_name = format!("{}_{server}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX);
|
||||
let invoke_function_declaration = FunctionDeclaration {
|
||||
name: invoke_function_name.clone(),
|
||||
description: formatdoc!(
|
||||
r#"
|
||||
Invoke the specified tool on the {server} MCP server. Always call {invoke_function_name} first to find the
|
||||
correct names of tools before calling '{invoke_function_name}'.
|
||||
"#
|
||||
),
|
||||
parameters: JsonSchema {
|
||||
type_value: Some("object".to_string()),
|
||||
properties: Some(invoke_function_properties.clone()),
|
||||
required: Some(vec!["server".to_string(), "tool".to_string()]),
|
||||
..Default::default()
|
||||
},
|
||||
agent: false,
|
||||
};
|
||||
let list_functions_declaration = FunctionDeclaration {
|
||||
name: format!("{}_{}", MCP_LIST_META_FUNCTION_NAME_PREFIX, server),
|
||||
description: format!("List all the available tools for the {server} MCP server"),
|
||||
parameters: JsonSchema::default(),
|
||||
agent: false,
|
||||
};
|
||||
self.declarations.push(invoke_function_declaration);
|
||||
self.declarations.push(list_functions_declaration);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_global_tool_declarations(enabled_tools: &str) -> Result<Vec<FunctionDeclaration>> {
|
||||
let global_tools_directory = Config::global_tools_dir();
|
||||
let mut function_declarations = Vec::new();
|
||||
|
||||
for line in enabled_tools.lines() {
|
||||
if line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
let declaration = Self::generate_declarations(&global_tools_directory.join(line))?;
|
||||
function_declarations.extend(declaration);
|
||||
}
|
||||
|
||||
Ok(function_declarations)
|
||||
}
|
||||
|
||||
fn build_global_tool_declarations_from_path(
|
||||
tools_txt_path: &PathBuf,
|
||||
) -> Result<Vec<FunctionDeclaration>> {
|
||||
let enabled_tools = fs::read_to_string(tools_txt_path)
|
||||
.with_context(|| format!("failed to load functions at {}", tools_txt_path.display()))?;
|
||||
|
||||
Self::build_global_tool_declarations(&enabled_tools)
|
||||
}
|
||||
|
||||
    /// Parse a single tool script file into function declarations, dispatching
    /// on the file extension (bash and python are supported; javascript is
    /// recognized but currently rejected by the final match arm).
    ///
    /// # Errors
    /// Fails when the file name/extension cannot be extracted, the file does
    /// not exist or cannot be opened, or the language is unsupported.
    fn generate_declarations(tools_file_path: &Path) -> Result<Vec<FunctionDeclaration>> {
        info!(
            "Loading tool definitions from {}",
            tools_file_path.display()
        );
        let file_name = tools_file_path
            .file_stem()
            .and_then(|s| s.to_str())
            .ok_or_else(|| {
                anyhow::format_err!("Unable to extract file name from path: {tools_file_path:?}")
            })?;

        match File::open(tools_file_path) {
            Ok(tool_file) => {
                // NOTE(review): the extension is lowercased here AND again
                // inside `Language::from` — redundant but harmless.
                let language = Language::from(
                    &tools_file_path
                        .extension()
                        .and_then(OsStr::to_str)
                        .map(|s| s.to_lowercase())
                        .ok_or_else(|| {
                            anyhow!("Unable to extract language from tool file: {file_name}")
                        })?,
                );

                match language {
                    Language::Bash => {
                        bash::generate_bash_declarations(tool_file, tools_file_path, file_name)
                    }
                    Language::Python => python::generate_python_declarations(
                        tool_file,
                        file_name,
                        tools_file_path.parent(),
                    ),
                    Language::Unsupported => {
                        bail!("Unsupported tool file extension: {}", language.as_ref())
                    }
                    // Recognized languages without a parser (javascript).
                    _ => bail!("Unsupported tool language: {}", language.as_ref()),
                }
            }
            Err(err) if err.kind() == io::ErrorKind::NotFound => {
                bail!(
                    "Tool definition file not found: {}",
                    tools_file_path.display()
                );
            }
            Err(err) => bail!("Unable to open tool definition file. {}", err),
        }
    }
|
||||
|
||||
fn build_global_function_binaries(enabled_tools: &str) -> Result<()> {
|
||||
let bin_dir = Config::functions_bin_dir();
|
||||
if !bin_dir.exists() {
|
||||
fs::create_dir_all(&bin_dir)?;
|
||||
}
|
||||
info!(
|
||||
"Clearing existing function binaries in {}",
|
||||
bin_dir.display()
|
||||
);
|
||||
clear_dir(&bin_dir)?;
|
||||
|
||||
for line in enabled_tools.lines() {
|
||||
if line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
let language = Language::from(
|
||||
&Path::new(line)
|
||||
.extension()
|
||||
.and_then(OsStr::to_str)
|
||||
.map(|s| s.to_lowercase())
|
||||
.ok_or_else(|| {
|
||||
anyhow::format_err!("Unable to extract file extension from path: {line:?}")
|
||||
})?,
|
||||
);
|
||||
let binary_name = Path::new(line)
|
||||
.file_stem()
|
||||
.and_then(OsStr::to_str)
|
||||
.ok_or_else(|| {
|
||||
anyhow::format_err!("Unable to extract file name from path: {line:?}")
|
||||
})?;
|
||||
|
||||
if language == Language::Unsupported {
|
||||
bail!("Unsupported tool file extension: {}", language.as_ref());
|
||||
}
|
||||
|
||||
Self::build_binaries(binary_name, language, BinaryType::Tool)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_global_function_binaries_from_path(tools_txt_path: PathBuf) -> Result<()> {
|
||||
let enabled_tools = fs::read_to_string(&tools_txt_path)
|
||||
.with_context(|| format!("failed to load functions at {}", tools_txt_path.display()))?;
|
||||
|
||||
Self::build_global_function_binaries(&enabled_tools)
|
||||
}
|
||||
|
||||
/// Rebuilds the runner binary for the agent `name`'s tool functions.
///
/// The agent's bin directory is created if missing, otherwise emptied, so the
/// result never contains stale binaries.
///
/// # Errors
/// Fails when the directory cannot be created/cleared, when the agent's
/// functions file has no usable extension, when the extension maps to
/// `Language::Unsupported`, or when the binary build fails.
fn build_agent_tool_binaries(name: &str) -> Result<()> {
    let agent_bin_directory = Config::agent_bin_dir(name);
    if !agent_bin_directory.exists() {
        debug!(
            "Creating agent bin directory: {}",
            agent_bin_directory.display()
        );
        fs::create_dir_all(&agent_bin_directory)?;
    } else {
        debug!(
            "Clearing existing agent bin directory: {}",
            agent_bin_directory.display()
        );
        clear_dir(&agent_bin_directory)?;
    }

    // The implementation language is derived from the extension of the
    // agent's functions file (lowercased before matching).
    let language = Language::from(
        &Config::agent_functions_file(name)?
            .extension()
            .and_then(OsStr::to_str)
            .map(|s| s.to_lowercase())
            .ok_or_else(|| {
                anyhow::format_err!("Unable to extract file extension from path: {name:?}")
            })?,
    );

    if language == Language::Unsupported {
        bail!("Unsupported tool file extension: {}", language.as_ref());
    }

    Self::build_binaries(name, language, BinaryType::Agent)
}
|
||||
|
||||
/// Windows variant: materializes two files for a function/agent —
/// a `run-<name>.<ext>` script rendered from the embedded template, and a
/// `<name>.cmd` wrapper that sets up the interpreter and invokes the script.
///
/// # Errors
/// Fails when the embedded template is missing or not valid UTF-8, when an
/// interpreter cannot be located, or on any filesystem error.
#[cfg(windows)]
fn build_binaries(
    binary_name: &str,
    language: Language,
    binary_type: BinaryType,
) -> Result<()> {
    use native::runtime;
    // Destination paths differ only in the base directory (global functions
    // bin dir vs the agent's own bin dir).
    let (binary_file, binary_script_file) = match binary_type {
        BinaryType::Tool => (
            Config::functions_bin_dir().join(format!("{binary_name}.cmd")),
            Config::functions_bin_dir()
                .join(format!("run-{binary_name}.{}", language.to_extension())),
        ),
        BinaryType::Agent => (
            Config::agent_bin_dir(binary_name).join(format!("{binary_name}.cmd")),
            Config::agent_bin_dir(binary_name)
                .join(format!("run-{binary_name}.{}", language.to_extension())),
        ),
    };
    info!(
        "Building binary runner for function: {} ({})",
        binary_name,
        binary_script_file.display(),
    );
    let embedded_file = FunctionAsset::get(&format!(
        "scripts/run-{}.{}",
        binary_type.as_ref().to_lowercase(),
        language.to_extension()
    ))
    .ok_or_else(|| {
        anyhow!(
            "Failed to load embedded script for run-{}.{}",
            binary_type.as_ref().to_lowercase(),
            language.to_extension()
        )
    })?;
    // Validate the embedded bytes instead of `from_utf8_unchecked`: the
    // previous unchecked conversion was undefined behavior if an asset was
    // ever non-UTF-8; a checked conversion turns that into a plain error.
    let content_template = std::str::from_utf8(&embedded_file.data)
        .context("Embedded run script is not valid UTF-8")?;
    let content = match binary_type {
        BinaryType::Tool => content_template.replace("{function_name}", binary_name),
        BinaryType::Agent => content_template.replace("{agent_name}", binary_name),
    }
    .replace("{config_dir}", &Config::config_dir().to_string_lossy());
    if binary_script_file.exists() {
        fs::remove_file(&binary_script_file)?;
    }
    let mut script_file = File::create(&binary_script_file)?;
    script_file.write_all(content.as_bytes())?;

    info!(
        "Building binary for function: {} ({})",
        binary_name,
        binary_file.display()
    );

    // Pick the command line that launches the script's interpreter.
    let run = match language {
        Language::Bash => {
            let shell = runtime::bash_path().ok_or_else(|| anyhow!("Shell not found"))?;
            format!("{shell} --noprofile --norc")
        }
        Language::Python if Path::new(".venv").exists() => {
            // Activate the local virtualenv before running python.
            let executable_path = env::current_dir()?
                .join(".venv")
                .join("Scripts")
                .join("activate.bat");
            let canonicalized_path = fs::canonicalize(&executable_path)?;
            format!(
                "call \"{}\" && {}",
                canonicalized_path.to_string_lossy(),
                language.to_cmd()
            )
        }
        Language::Javascript => runtime::which(language.to_cmd())
            .ok_or_else(|| anyhow!("Unable to find {} in PATH", language.to_cmd()))?,
        _ => bail!("Unsupported language: {}", language.as_ref()),
    };
    let bin_dir = binary_file
        .parent()
        .expect("Failed to get parent directory of binary file")
        .canonicalize()?
        .to_string_lossy()
        .into_owned();
    let wrapper_binary = binary_script_file
        .canonicalize()?
        .to_string_lossy()
        .into_owned();
    let content = formatdoc!(
        r#"
        @echo off
        setlocal

        set "bin_dir={bin_dir}"

        {run} "{wrapper_binary}" %*"#,
    );

    let mut file = File::create(&binary_file)?;
    file.write_all(content.as_bytes())?;

    Ok(())
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
fn build_binaries(
|
||||
binary_name: &str,
|
||||
language: Language,
|
||||
binary_type: BinaryType,
|
||||
) -> Result<()> {
|
||||
use std::os::unix::prelude::PermissionsExt;
|
||||
|
||||
let binary_file = match binary_type {
|
||||
BinaryType::Tool => Config::functions_bin_dir().join(binary_name),
|
||||
BinaryType::Agent => Config::agent_bin_dir(binary_name).join(binary_name),
|
||||
};
|
||||
info!(
|
||||
"Building binary for function: {} ({})",
|
||||
binary_name,
|
||||
binary_file.display()
|
||||
);
|
||||
let embedded_file = FunctionAsset::get(&format!(
|
||||
"scripts/run-{}.{}",
|
||||
binary_type.as_ref().to_lowercase(),
|
||||
language.to_extension()
|
||||
))
|
||||
.ok_or_else(|| {
|
||||
anyhow!(
|
||||
"Failed to load embedded script for run-{}.{}",
|
||||
binary_type.as_ref().to_lowercase(),
|
||||
language.to_extension()
|
||||
)
|
||||
})?;
|
||||
let content_template = unsafe { std::str::from_utf8_unchecked(&embedded_file.data) };
|
||||
let content = match binary_type {
|
||||
BinaryType::Tool => content_template.replace("{function_name}", binary_name),
|
||||
BinaryType::Agent => content_template.replace("{agent_name}", binary_name),
|
||||
}
|
||||
.replace("{config_dir}", &Config::config_dir().to_string_lossy());
|
||||
if binary_file.exists() {
|
||||
fs::remove_file(&binary_file)?;
|
||||
}
|
||||
let mut file = File::create(&binary_file)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
|
||||
fs::set_permissions(&binary_file, fs::Permissions::from_mode(0o755))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A function/tool declaration as advertised to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    // Function name the model uses to call the tool.
    pub name: String,
    // Human/model-readable description of what the tool does.
    pub description: String,
    // JSON-Schema description of the tool's parameters.
    pub parameters: JsonSchema,
    // Local-only flag: true when this function belongs to an agent.
    // Never serialized, so it is invisible to the model.
    #[serde(skip_serializing, default)]
    pub agent: bool,
}
|
||||
|
||||
/// A minimal JSON-Schema subset used for tool parameter declarations.
/// All fields are optional; absent fields are omitted from serialization.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct JsonSchema {
    // Schema "type" keyword (e.g. "object", "string").
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub type_value: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    // Ordered property map (IndexMap preserves declaration order).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<IndexMap<String, JsonSchema>>,
    // Item schema for arrays; boxed because the type is recursive.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<JsonSchema>>,
    #[serde(rename = "anyOf", skip_serializing_if = "Option::is_none")]
    pub any_of: Option<Vec<JsonSchema>>,
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enum_value: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<Value>,
    // Names of required properties, per the JSON-Schema "required" keyword.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
|
||||
|
||||
impl JsonSchema {
|
||||
pub fn is_empty_properties(&self) -> bool {
|
||||
match &self.properties {
|
||||
Some(v) => v.is_empty(),
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single tool call requested by the model.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ToolCall {
    // Name of the function to invoke.
    pub name: String,
    // Raw arguments: either a JSON object or a JSON-encoded string.
    pub arguments: Value,
    // Provider-assigned call id; not all providers supply one.
    pub id: Option<String>,
}
|
||||
|
||||
// Resolved invocation plan for a tool call:
// (display name, command to run, leading cli args, extra env vars).
type CallConfig = (String, String, Vec<String>, HashMap<String, String>);
|
||||
|
||||
impl ToolCall {
    /// Removes duplicate calls that share the same `id`, keeping the LAST
    /// occurrence of each id. Calls without an id are always kept.
    /// The original relative order of the surviving calls is preserved.
    pub fn dedup(calls: Vec<Self>) -> Vec<Self> {
        let mut new_calls = vec![];
        let mut seen_ids = HashSet::new();

        // Iterate back-to-front so the last occurrence of each id wins,
        // then reverse once at the end to restore the original order.
        for call in calls.into_iter().rev() {
            if let Some(id) = &call.id {
                if !seen_ids.contains(id) {
                    seen_ids.insert(id.clone());
                    new_calls.push(call);
                }
            } else {
                new_calls.push(call);
            }
        }

        new_calls.reverse();
        new_calls
    }

    /// Constructs a tool call from its raw parts.
    pub fn new(name: String, arguments: Value, id: Option<String>) -> Self {
        Self {
            name,
            arguments,
            id,
        }
    }

    /// Executes this tool call and returns its output as JSON.
    ///
    /// Resolution order: agent-scoped functions (when an agent is active),
    /// then globally configured functions. MCP meta-functions (`mcp_list*` /
    /// `mcp_invoke*`) are dispatched to the MCP registry; everything else is
    /// run as an external command via `run_llm_function`.
    ///
    /// # Errors
    /// Fails when the call is unknown, its arguments are not valid JSON,
    /// MCP is required but not configured, or the underlying command fails.
    pub async fn eval(&self, config: &GlobalConfig) -> Result<Value> {
        let (call_name, cmd_name, mut cmd_args, envs) = match &config.read().agent {
            Some(agent) => self.extract_call_config_from_agent(config, agent)?,
            None => self.extract_call_config_from_config(config)?,
        };

        // Normalize arguments to a JSON value: accept either an object or a
        // string containing JSON; anything else is rejected.
        let json_data = if self.arguments.is_object() {
            self.arguments.clone()
        } else if let Some(arguments) = self.arguments.as_str() {
            let arguments: Value = serde_json::from_str(arguments).map_err(|_| {
                anyhow!("The call '{call_name}' has invalid arguments: {arguments}")
            })?;
            arguments
        } else {
            bail!(
                "The call '{call_name}' has invalid arguments: {}",
                self.arguments
            );
        };

        // The serialized arguments become the final cli argument.
        cmd_args.push(json_data.to_string());

        let prompt = format!("Call {cmd_name} {}", cmd_args.join(" "));

        // Echo the invocation when attached to a terminal.
        if *IS_STDOUT_TERMINAL {
            println!("{}", dimmed_text(&prompt));
        }

        let output = match cmd_name.as_str() {
            _ if cmd_name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) => {
                // Clone the registry Arc inside a short read-lock scope so
                // the lock is not held across the await below.
                let registry_arc = {
                    let cfg = config.read();
                    cfg.mcp_registry
                        .clone()
                        .with_context(|| "MCP is not configured")?
                };

                registry_arc.catalog().await?
            }
            _ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => {
                // mcp_invoke expects {"server": ..., "tool": ..., "arguments": {...}}.
                let server = json_data
                    .get("server")
                    .ok_or_else(|| anyhow!("Missing 'server' in arguments"))?
                    .as_str()
                    .ok_or_else(|| anyhow!("Invalid 'server' in arguments"))?;
                let tool = json_data
                    .get("tool")
                    .ok_or_else(|| anyhow!("Missing 'tool' in arguments"))?
                    .as_str()
                    .ok_or_else(|| anyhow!("Invalid 'tool' in arguments"))?;
                let arguments = json_data
                    .get("arguments")
                    .cloned()
                    .unwrap_or_else(|| json!({}));
                // Same lock-scoping pattern as the list branch above.
                let registry_arc = {
                    let cfg = config.read();
                    cfg.mcp_registry
                        .clone()
                        .with_context(|| "MCP is not configured")?
                };
                let result = registry_arc.invoke(server, tool, arguments).await?;
                serde_json::to_value(result)?
            }
            _ => match run_llm_function(cmd_name, cmd_args, envs)? {
                // Prefer parsing the command output as JSON; fall back to
                // wrapping the raw text as {"output": ...}.
                Some(contents) => serde_json::from_str(&contents)
                    .ok()
                    .unwrap_or_else(|| json!({"output": contents})),
                None => Value::Null,
            },
        };

        Ok(output)
    }

    /// Resolves the call against the active agent's functions; falls back to
    /// the global config when the agent does not define this function.
    fn extract_call_config_from_agent(
        &self,
        config: &GlobalConfig,
        agent: &Agent,
    ) -> Result<CallConfig> {
        let function_name = self.name.clone();
        match agent.functions().find(&function_name) {
            Some(function) => {
                let agent_name = agent.name().to_string();
                if function.agent {
                    // Agent-owned function: run the agent binary with the
                    // function name as its first argument, plus the agent's
                    // variables exported as environment variables.
                    Ok((
                        format!("{agent_name}-{function_name}"),
                        agent_name,
                        vec![function_name],
                        agent.variable_envs(),
                    ))
                } else {
                    // Shared function: run its own binary directly.
                    Ok((
                        function_name.clone(),
                        function_name,
                        vec![],
                        Default::default(),
                    ))
                }
            }
            None => self.extract_call_config_from_config(config),
        }
    }

    /// Resolves the call against the globally configured functions; rejects
    /// any call the configuration does not know about.
    fn extract_call_config_from_config(&self, config: &GlobalConfig) -> Result<CallConfig> {
        let function_name = self.name.clone();
        match config.read().functions.contains(&function_name) {
            true => Ok((
                function_name.clone(),
                function_name,
                vec![],
                Default::default(),
            )),
            false => bail!("Unexpected call: {function_name} {}", self.arguments),
        }
    }
}
|
||||
|
||||
/// Runs an external tool command and returns its output, if any.
///
/// The function bin directories are prepended to `PATH`, and `LLM_OUTPUT` is
/// set to a fresh temp file path; by convention the tool writes its result
/// there. Returns `Ok(None)` when the tool produced no (non-empty) output.
///
/// # Errors
/// Fails when `PATH` is unset, the command cannot be spawned, exits non-zero,
/// or the output file cannot be read.
pub fn run_llm_function(
    cmd_name: String,
    cmd_args: Vec<String>,
    mut envs: HashMap<String, String>,
) -> Result<Option<String>> {
    let mut bin_dirs: Vec<PathBuf> = vec![];
    // More than one arg means an agent invocation (agent binary + function
    // name + json args), so the agent's own bin dir gets priority.
    if cmd_args.len() > 1 {
        let dir = Config::agent_bin_dir(&cmd_name);
        if dir.exists() {
            bin_dirs.push(dir);
        }
    }
    bin_dirs.push(Config::functions_bin_dir());
    let current_path = env::var("PATH").context("No PATH environment variable")?;
    let prepend_path = bin_dirs
        .iter()
        .map(|v| format!("{}{PATH_SEP}", v.display()))
        .collect::<Vec<_>>()
        .join("");
    envs.insert("PATH".into(), format!("{prepend_path}{current_path}"));

    // The tool writes its result into this file rather than stdout.
    let temp_file = temp_file("-eval-", "");
    envs.insert("LLM_OUTPUT".into(), temp_file.display().to_string());

    // Windows needs the PATHEXT suffix resolved explicitly.
    #[cfg(windows)]
    let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);

    let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
        .map_err(|err| anyhow!("Unable to run {cmd_name}, {err}"))?;
    if exit_code != 0 {
        bail!("Tool call exited with {exit_code}");
    }
    // Absent or empty output file means the tool had nothing to return.
    let mut output = None;
    if temp_file.exists() {
        let contents =
            fs::read_to_string(temp_file).context("Failed to retrieve tool call output")?;
        if !contents.is_empty() {
            debug!("Tool {cmd_name} output: {}", contents);
            output = Some(contents);
        }
    };
    Ok(output)
}
|
||||
|
||||
/// Resolves a bare command name to an executable name on Windows by probing
/// each `PATHEXT` suffix against the given bin directories, mimicking
/// cmd.exe's lookup. Falls back to the unmodified name when nothing matches.
#[cfg(windows)]
fn polyfill_cmd_name<T: AsRef<Path>>(cmd_name: &str, bin_dir: &[T]) -> String {
    if let Ok(exts) = env::var("PATHEXT") {
        for name in exts.split(';').map(|ext| format!("{cmd_name}{ext}")) {
            for dir in bin_dir {
                if dir.as_ref().join(&name).exists() {
                    // `name` is already owned; the previous `name.to_string()`
                    // (and the up-front `cmd_name.to_string()`) were
                    // redundant allocations.
                    return name;
                }
            }
        }
    }
    cmd_name.to_string()
}
|
||||
+496
@@ -0,0 +1,496 @@
|
||||
mod cli;
|
||||
mod client;
|
||||
mod config;
|
||||
mod function;
|
||||
mod rag;
|
||||
mod render;
|
||||
mod repl;
|
||||
mod serve;
|
||||
#[macro_use]
|
||||
mod utils;
|
||||
mod mcp;
|
||||
mod parsers;
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
use crate::cli::Cli;
|
||||
use crate::client::{
|
||||
call_chat_completions, call_chat_completions_streaming, list_models, ModelType,
|
||||
};
|
||||
use crate::config::{
|
||||
ensure_parent_exists, list_agents, load_env_file, macro_execute, Agent, Config, GlobalConfig,
|
||||
Input, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, TEMP_SESSION_NAME,
|
||||
};
|
||||
use crate::render::render_error;
|
||||
use crate::repl::Repl;
|
||||
use crate::utils::*;
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use clap::{CommandFactory, Parser};
|
||||
use clap_complete::CompleteEnv;
|
||||
use inquire::Text;
|
||||
use log::LevelFilter;
|
||||
use log4rs::append::console::ConsoleAppender;
|
||||
use log4rs::append::file::FileAppender;
|
||||
use log4rs::config::{Appender, Logger, Root};
|
||||
use log4rs::encode::pattern::PatternEncoder;
|
||||
use parking_lot::RwLock;
|
||||
use std::path::PathBuf;
|
||||
use std::{env, mem, process, sync::Arc};
|
||||
|
||||
#[tokio::main]
async fn main() -> Result<()> {
    // Load .env before anything reads environment configuration.
    load_env_file()?;
    // Handle shell-completion requests (exits the process when triggered).
    CompleteEnv::with_factory(Cli::command).complete();
    let cli = Cli::parse();

    // --tail-logs bypasses normal startup entirely.
    if cli.tail_logs {
        tail_logs(cli.disable_log_colors).await;
        return Ok(());
    }

    let text = cli.text()?;
    // Mode: server when --serve, REPL when no input at all, one-shot otherwise.
    let working_mode = if cli.serve.is_some() {
        WorkingMode::Serve
    } else if text.is_none() && cli.file.is_empty() {
        WorkingMode::Repl
    } else {
        WorkingMode::Cmd
    };

    // Any informational flag puts Config::init into lightweight mode.
    let info_flag = cli.info
        || cli.sync_models
        || cli.list_models
        || cli.list_roles
        || cli.list_agents
        || cli.list_rags
        || cli.list_macros
        || cli.list_sessions;
    let log_path = setup_logger(working_mode.is_serve())?;
    let abort_signal = create_abort_signal();
    // Agents/roles manage their own MCP server selection later, so only
    // auto-start MCP servers when neither is requested.
    let start_mcp_servers = cli.agent.is_none() && cli.role.is_none();
    let config = Arc::new(RwLock::new(
        Config::init(
            working_mode,
            info_flag,
            start_mcp_servers,
            log_path,
            abort_signal.clone(),
        )
        .await?,
    ));
    // Render errors ourselves and exit(1) instead of letting anyhow print.
    if let Err(err) = run(config, cli, text, abort_signal).await {
        render_error(err);
        process::exit(1);
    }
    Ok(())
}
|
||||
|
||||
/// Dispatches the parsed CLI into the requested action: informational
/// listings, agent/role/session setup, serve mode, macro execution, shell
/// command generation, one-shot completion, or the interactive REPL.
/// The order of the checks below is load-bearing — informational flags
/// short-circuit before any state is mutated.
async fn run(
    config: GlobalConfig,
    cli: Cli,
    text: Option<String>,
    abort_signal: AbortSignal,
) -> Result<()> {
    if cli.sync_models {
        let url = config.read().sync_models_url();
        return Config::sync_models(&url, abort_signal.clone()).await;
    }

    // Pure listing flags: print and return without touching state.
    if cli.list_models {
        for model in list_models(&config.read(), ModelType::Chat) {
            println!("{}", model.id());
        }
        return Ok(());
    }
    if cli.list_roles {
        let roles = Config::list_roles(true).join("\n");
        println!("{roles}");
        return Ok(());
    }
    if cli.list_agents {
        let agents = list_agents().join("\n");
        println!("{agents}");
        return Ok(());
    }
    if cli.list_rags {
        let rags = Config::list_rags().join("\n");
        println!("{rags}");
        return Ok(());
    }
    if cli.list_macros {
        let macros = Config::list_macros().join("\n");
        println!("{macros}");
        return Ok(());
    }

    if cli.dry_run {
        config.write().dry_run = true;
    }

    if let Some(agent) = &cli.agent {
        // --build-tools with an agent only (re)builds that agent's tools.
        if cli.build_tools {
            info!("Building tools for agent '{agent}'...");
            Agent::init(&config, agent, abort_signal.clone()).await?;
            return Ok(());
        }

        // `-s` without a value means the temporary session.
        let session = cli.session.as_ref().map(|v| match v {
            Some(v) => v.as_str(),
            None => TEMP_SESSION_NAME,
        });
        // agent_variable arrives as a flat [NAME, VALUE, ...] list; pair it up.
        if !cli.agent_variable.is_empty() {
            config.write().agent_variables = Some(
                cli.agent_variable
                    .chunks(2)
                    .map(|v| (v[0].to_string(), v[1].to_string()))
                    .collect(),
            );
        }

        // Clear the variables even when use_agent fails, then propagate.
        let ret = Config::use_agent(&config, agent, session, abort_signal.clone()).await;
        config.write().agent_variables = None;
        ret?;
    } else {
        // Role selection: explicit prompt > named role > shell role > code role.
        if let Some(prompt) = &cli.prompt {
            config.write().use_prompt(prompt)?;
        } else if let Some(name) = &cli.role {
            Config::use_role_safely(&config, name, abort_signal.clone()).await?;
        } else if cli.execute {
            Config::use_role_safely(&config, SHELL_ROLE, abort_signal.clone()).await?;
        } else if cli.code {
            Config::use_role_safely(&config, CODE_ROLE, abort_signal.clone()).await?;
        }
        if let Some(session) = &cli.session {
            config
                .write()
                .use_session(session.as_ref().map(|v| v.as_str()))?;
        }
        if let Some(rag) = &cli.rag {
            Config::use_rag(&config, Some(rag), abort_signal.clone()).await?;
        }
    }

    // --build-tools without an agent: global tools were built in Config::init.
    if cli.build_tools {
        return Ok(());
    }

    // Listed after session setup because the session list is config state.
    if cli.list_sessions {
        let sessions = config.read().list_sessions().join("\n");
        println!("{sessions}");
        return Ok(());
    }
    if let Some(model_id) = &cli.model {
        config.write().set_model(model_id)?;
    }
    if cli.no_stream {
        config.write().stream = false;
    }
    if cli.empty_session {
        config.write().empty_session()?;
    }
    if cli.save_session {
        config.write().set_save_session_this_time()?;
    }
    if cli.info {
        let info = config.read().info()?;
        println!("{info}");
        return Ok(());
    }
    if let Some(addr) = cli.serve {
        return serve::run(config, addr).await;
    }
    let is_repl = config.read().working_mode.is_repl();
    if cli.rebuild_rag {
        Config::rebuild_rag(&config, abort_signal.clone()).await?;
        // In REPL mode the rebuild is the whole action; otherwise continue.
        if is_repl {
            return Ok(());
        }
    }
    if let Some(name) = &cli.macro_name {
        macro_execute(&config, name, text.as_deref(), abort_signal.clone()).await?;
        return Ok(());
    }
    if cli.execute && !is_repl {
        let input = create_input(&config, text, &cli.file, abort_signal.clone()).await?;
        shell_execute(&config, &SHELL, input, abort_signal.clone()).await?;
        return Ok(());
    }

    apply_prelude_safely(&config, abort_signal.clone()).await?;

    match is_repl {
        false => {
            let mut input = create_input(&config, text, &cli.file, abort_signal.clone()).await?;
            input.use_embeddings(abort_signal.clone()).await?;
            start_directive(&config, input, cli.code, abort_signal).await
        }
        true => {
            // A REPL needs a real terminal on stdout.
            if !*IS_STDOUT_TERMINAL {
                bail!("No TTY for REPL")
            }
            start_interactive(&config).await
        }
    }
}
|
||||
|
||||
/// Applies the configured prelude to the config without holding the lock
/// across an await point.
///
/// `apply_prelude` is async, and the `parking_lot::RwLock` guard must not
/// live across `.await` — so the config is moved out with `mem::take`
/// (leaving a `Default` placeholder), mutated outside the lock, and written
/// back afterwards.
///
/// NOTE(review): between take and restore, concurrent readers observe the
/// default Config — presumably acceptable at this point in startup; confirm.
async fn apply_prelude_safely(config: &RwLock<Config>, abort_signal: AbortSignal) -> Result<()> {
    let mut cfg = {
        let mut guard = config.write();
        mem::take(&mut *guard)
    };

    cfg.apply_prelude(abort_signal.clone()).await?;

    {
        let mut guard = config.write();
        *guard = cfg;
    }

    Ok(())
}
|
||||
|
||||
/// Runs a single non-interactive chat completion, recursing to feed tool-call
/// results back to the model until no further tool calls are produced, then
/// exits the session.
///
/// Streaming is used unless the input disables it or code extraction is
/// requested (piped stdout in code mode).
#[async_recursion::async_recursion]
async fn start_directive(
    config: &GlobalConfig,
    input: Input,
    code_mode: bool,
    abort_signal: AbortSignal,
) -> Result<()> {
    let client = input.create_client()?;
    // Only strip to the code block when output is piped AND code mode is on.
    let extract_code = !*IS_STDOUT_TERMINAL && code_mode;
    config.write().before_chat_completion(&input)?;
    let (output, tool_results) = if !input.stream() || extract_code {
        call_chat_completions(
            &input,
            true,
            extract_code,
            client.as_ref(),
            abort_signal.clone(),
        )
        .await?
    } else {
        call_chat_completions_streaming(&input, client.as_ref(), abort_signal.clone()).await?
    };
    config
        .write()
        .after_chat_completion(&input, &output, &tool_results)?;

    // Tool calls produce results that must be fed back to the model.
    if !tool_results.is_empty() {
        start_directive(
            config,
            input.merge_tool_results(output, tool_results),
            code_mode,
            abort_signal,
        )
        .await?;
    }

    config.write().exit_session()?;
    Ok(())
}
|
||||
|
||||
async fn start_interactive(config: &GlobalConfig) -> Result<()> {
|
||||
let mut repl: Repl = Repl::init(config)?;
|
||||
repl.run().await
|
||||
}
|
||||
|
||||
/// Generates a shell command from the input and, on a terminal, offers an
/// execute / revise / describe / copy / quit menu; recursing on "revise"
/// with the user's amendment appended. Without a TTY the command is just
/// printed. Executing replaces this process's exit code with the command's.
#[async_recursion::async_recursion]
async fn shell_execute(
    config: &GlobalConfig,
    shell: &Shell,
    mut input: Input,
    abort_signal: AbortSignal,
) -> Result<()> {
    let client = input.create_client()?;
    config.write().before_chat_completion(&input)?;
    // Non-streaming, code-extracting call: we need the bare command text.
    let (eval_str, _) =
        call_chat_completions(&input, false, true, client.as_ref(), abort_signal.clone()).await?;

    config
        .write()
        .after_chat_completion(&input, &eval_str, &[])?;
    if eval_str.is_empty() {
        bail!("No command generated");
    }
    // Dry-run: show the command, never execute.
    if config.read().dry_run {
        config.read().print_markdown(&eval_str)?;
        return Ok(());
    }
    if *IS_STDOUT_TERMINAL {
        let options = ["execute", "revise", "describe", "copy", "quit"];
        // Orange command text; menu options with cyan first letter as the hotkey.
        let command = color_text(eval_str.trim(), nu_ansi_term::Color::Rgb(255, 165, 0));
        let first_letter_color = nu_ansi_term::Color::Cyan;
        let prompt_text = options
            .iter()
            .map(|v| format!("{}{}", color_text(&v[0..1], first_letter_color), &v[1..]))
            .collect::<Vec<String>>()
            .join(&dimmed_text(" | "));
        loop {
            println!("{command}");
            // 'e' (execute) is the default on plain Enter.
            let answer_char =
                read_single_key(&['e', 'r', 'd', 'c', 'q'], 'e', &format!("{prompt_text}: "))?;

            match answer_char {
                'e' => {
                    debug!("{} {:?}", shell.cmd, &[&shell.arg, &eval_str]);
                    let code = run_command(&shell.cmd, &[&shell.arg, &eval_str], None)?;
                    // Only successful commands are appended to shell history.
                    if code == 0 && config.read().save_shell_history {
                        let _ = append_to_shell_history(&shell.name, &eval_str, code);
                    }
                    // Propagate the command's exit code as our own.
                    process::exit(code);
                }
                'r' => {
                    // Append the revision to the original request and retry.
                    let revision = Text::new("Enter your revision:").prompt()?;
                    let text = format!("{}\n{revision}", input.text());
                    input.set_text(text);
                    return shell_execute(config, shell, input, abort_signal.clone()).await;
                }
                'd' => {
                    // Explain the generated command, then re-show the menu.
                    let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?;
                    let input = Input::from_str(config, &eval_str, Some(role));
                    if input.stream() {
                        call_chat_completions_streaming(
                            &input,
                            client.as_ref(),
                            abort_signal.clone(),
                        )
                        .await?;
                    } else {
                        call_chat_completions(
                            &input,
                            true,
                            false,
                            client.as_ref(),
                            abort_signal.clone(),
                        )
                        .await?;
                    }
                    println!();
                    continue;
                }
                'c' => {
                    // Copy to the system clipboard.
                    set_text(&eval_str)?;
                    println!("{}", dimmed_text("✓ Copied the command."));
                }
                _ => {}
            }
            break;
        }
    } else {
        // No TTY: emit the command for the caller (e.g. a pipe) to use.
        println!("{eval_str}");
    }
    Ok(())
}
|
||||
|
||||
async fn create_input(
|
||||
config: &GlobalConfig,
|
||||
text: Option<String>,
|
||||
file: &[String],
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Input> {
|
||||
let input = if file.is_empty() {
|
||||
Input::from_str(config, &text.unwrap_or_default(), None)
|
||||
} else {
|
||||
Input::from_files_with_spinner(
|
||||
config,
|
||||
&text.unwrap_or_default(),
|
||||
file.to_vec(),
|
||||
None,
|
||||
abort_signal,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
if input.is_empty() {
|
||||
bail!("No input");
|
||||
}
|
||||
Ok(input)
|
||||
}
|
||||
|
||||
/// Initializes log4rs from the configured level/path and returns the log file
/// path (if file logging is active). Returns `Ok(None)` when logging is off.
///
/// Falls back to console logging when the file appender cannot be created.
/// A log filter (from the env or, in serve mode, the serve module) restricts
/// logging to the matching module.
fn setup_logger(is_serve: bool) -> Result<Option<PathBuf>> {
    let (log_level, log_path) = Config::log_config(is_serve)?;
    if log_level == LevelFilter::Off {
        return Ok(None);
    }
    let encoder = Box::new(PatternEncoder::new(
        "{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}",
    ));
    // Env override wins; serve mode defaults to filtering to the serve module.
    let log_filter = match env::var(get_env_name("log_filter")) {
        Ok(v) => Some(v),
        Err(_) => match is_serve {
            true => Some(format!("{}::serve", env!("CARGO_CRATE_NAME"))),
            false => None,
        },
    };

    match log_path.clone() {
        None => {
            let console_appender = ConsoleAppender::builder().encoder(encoder).build();
            log4rs::init_config(init_console_logger(log_level, log_filter, console_appender))?;
        }
        Some(path) => {
            ensure_parent_exists(&path)?;
            let file_appender = FileAppender::builder().encoder(encoder.clone()).build(path);

            match file_appender {
                Ok(appender) => {
                    log4rs::init_config(init_file_logger(log_level, log_filter, appender))?
                }
                Err(_) => {
                    // File appender failed (e.g. unwritable path): degrade to console.
                    let console_appender = ConsoleAppender::builder().encoder(encoder).build();
                    log4rs::init_config(init_console_logger(
                        log_level,
                        log_filter,
                        console_appender,
                    ))?
                }
            };
        }
    }
    Ok(log_path)
}
|
||||
|
||||
fn init_file_logger(
|
||||
log_level: LevelFilter,
|
||||
log_filter: Option<String>,
|
||||
file_appender: FileAppender,
|
||||
) -> log4rs::Config {
|
||||
let root_log_level = if log_filter.is_some() {
|
||||
LevelFilter::Off
|
||||
} else {
|
||||
log_level
|
||||
};
|
||||
let mut config_builder = log4rs::Config::builder()
|
||||
.appender(Appender::builder().build("logfile", Box::new(file_appender)));
|
||||
|
||||
if let Some(filter) = log_filter {
|
||||
config_builder = config_builder.logger(Logger::builder().build(filter, log_level));
|
||||
}
|
||||
|
||||
config_builder
|
||||
.build(Root::builder().appender("logfile").build(root_log_level))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn init_console_logger(
|
||||
log_level: LevelFilter,
|
||||
log_filter: Option<String>,
|
||||
console_appender: ConsoleAppender,
|
||||
) -> log4rs::Config {
|
||||
let root_log_level = if log_filter.is_some() {
|
||||
LevelFilter::Off
|
||||
} else {
|
||||
log_level
|
||||
};
|
||||
let mut config_builder = log4rs::Config::builder()
|
||||
.appender(Appender::builder().build("console", Box::new(console_appender)));
|
||||
|
||||
if let Some(filter) = log_filter {
|
||||
config_builder = config_builder.logger(Logger::builder().build(filter, log_level));
|
||||
}
|
||||
|
||||
config_builder
|
||||
.build(Root::builder().appender("console").build(root_log_level))
|
||||
.unwrap()
|
||||
}
|
||||
+290
@@ -0,0 +1,290 @@
|
||||
use crate::config::Config;
|
||||
use crate::utils::{abortable_run_with_spinner, AbortSignal};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use futures_util::future::BoxFuture;
|
||||
use futures_util::{stream, StreamExt, TryStreamExt};
|
||||
use rmcp::model::{CallToolRequestParam, CallToolResult};
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::transport::TokioChildProcess;
|
||||
use rmcp::{RoleClient, ServiceExt};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs::OpenOptions;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use tokio::process::Command;
|
||||
|
||||
// Name prefix of the meta-function that invokes a tool on an MCP server.
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
// Name prefix of the meta-function that lists the MCP tool catalog.
pub const MCP_LIST_META_FUNCTION_NAME_PREFIX: &str = "mcp_list";

// A live client connection to a single MCP server.
type ConnectedServer = RunningService<RoleClient, ()>;
|
||||
|
||||
/// Top-level MCP config file shape: `{"mcpServers": {"<id>": {...}}}`.
#[derive(Debug, Clone, Deserialize)]
struct McpServersConfig {
    // Server definitions keyed by server id.
    #[serde(rename = "mcpServers")]
    mcp_servers: HashMap<String, McpServer>,
}
|
||||
|
||||
/// How to launch one MCP server as a child process.
#[derive(Debug, Clone, Deserialize)]
struct McpServer {
    // Executable to spawn.
    command: String,
    // Optional command-line arguments.
    args: Option<Vec<String>>,
    // Optional environment variables; values may be string, bool, or int
    // in the JSON config (see `JsonField`).
    env: Option<HashMap<String, JsonField>>,
    // Optional working directory for the child process.
    cwd: Option<String>,
}
|
||||
|
||||
/// A scalar config value that may appear as a JSON string, bool, or integer.
/// `untagged` lets serde pick the first variant that deserializes.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum JsonField {
    Str(String),
    Bool(bool),
    Int(i64),
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct McpRegistry {
|
||||
log_path: Option<PathBuf>,
|
||||
config: Option<McpServersConfig>,
|
||||
servers: HashMap<String, Arc<RunningService<RoleClient, ()>>>,
|
||||
}
|
||||
|
||||
impl McpRegistry {
|
||||
/// Creates the registry: loads the MCP config file (silently skipping all
/// MCP setup when it does not exist) and, when `start_mcp_servers` is set,
/// starts the servers selected by `use_mcp_servers` behind a spinner.
///
/// # Errors
/// Fails when the config file's existence cannot be checked, it cannot be
/// read or parsed, or starting the selected servers fails.
pub async fn init(
    log_path: Option<PathBuf>,
    start_mcp_servers: bool,
    use_mcp_servers: Option<String>,
    abort_signal: AbortSignal,
) -> Result<Self> {
    let mut registry = Self {
        log_path,
        ..Default::default()
    };
    // A missing config file is not an error: MCP is simply disabled.
    if !Config::mcp_config_file().try_exists().with_context(|| {
        format!(
            "Failed to check MCP config file at {}",
            Config::mcp_config_file().display()
        )
    })? {
        debug!(
            "MCP config file does not exist at {}, skipping MCP initialization",
            Config::mcp_config_file().display()
        );
        return Ok(registry);
    }
    // Shared context message for both the read and the parse below.
    let err = || {
        format!(
            "Failed to load MCP config file at {}",
            Config::mcp_config_file().display()
        )
    };
    let content = tokio::fs::read_to_string(Config::mcp_config_file())
        .await
        .with_context(err)?;
    let config: McpServersConfig = serde_json::from_str(&content).with_context(err)?;
    registry.config = Some(config);

    if start_mcp_servers {
        abortable_run_with_spinner(
            registry.start_select_mcp_servers(use_mcp_servers),
            "Loading MCP servers",
            abort_signal,
        )
        .await?;
    }

    Ok(registry)
}
|
||||
|
||||
pub async fn reinit(
|
||||
registry: McpRegistry,
|
||||
use_mcp_servers: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
debug!("Reinitializing MCP registry");
|
||||
debug!("Stopping all MCP servers");
|
||||
let mut new_registry = abortable_run_with_spinner(
|
||||
registry.stop_all_servers(),
|
||||
"Stopping MCP servers",
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
abortable_run_with_spinner(
|
||||
new_registry.start_select_mcp_servers(use_mcp_servers),
|
||||
"Loading MCP servers",
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(new_registry)
|
||||
}
|
||||
|
||||
async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option<String>) -> Result<()> {
|
||||
if self.config.is_none() {
|
||||
debug!("MCP config is not present; assuming MCP servers are disabled globally. skipping MCP initialization");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
debug!("Starting selected MCP servers: {:?}", use_mcp_servers);
|
||||
|
||||
if let Some(servers) = use_mcp_servers {
|
||||
let config = self
|
||||
.config
|
||||
.as_ref()
|
||||
.with_context(|| "MCP Config not defined. Cannot start servers")?;
|
||||
let mcp_servers = config.mcp_servers.clone();
|
||||
|
||||
let enabled_servers: HashSet<String> =
|
||||
servers.split(',').map(|s| s.trim().to_string()).collect();
|
||||
let server_ids: Vec<String> = if servers == "all" {
|
||||
mcp_servers.into_keys().collect()
|
||||
} else {
|
||||
mcp_servers
|
||||
.into_keys()
|
||||
.filter(|id| enabled_servers.contains(id))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let results: Vec<(String, Arc<_>)> = stream::iter(
|
||||
server_ids
|
||||
.into_iter()
|
||||
.map(|id| async { self.start_server(id).await }),
|
||||
)
|
||||
.buffer_unordered(num_cpus::get())
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
self.servers = results.into_iter().collect();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_server(&self, id: String) -> Result<(String, Arc<ConnectedServer>)> {
|
||||
let server = self
|
||||
.config
|
||||
.as_ref()
|
||||
.and_then(|c| c.mcp_servers.get(&id))
|
||||
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
||||
let mut cmd = Command::new(&server.command);
|
||||
if let Some(args) = &server.args {
|
||||
cmd.args(args);
|
||||
}
|
||||
if let Some(env) = &server.env {
|
||||
let env: HashMap<String, String> = env
|
||||
.iter()
|
||||
.map(|(k, v)| match v {
|
||||
JsonField::Str(s) => (k.clone(), s.clone()),
|
||||
JsonField::Bool(b) => (k.clone(), b.to_string()),
|
||||
JsonField::Int(i) => (k.clone(), i.to_string()),
|
||||
})
|
||||
.collect();
|
||||
cmd.envs(env);
|
||||
}
|
||||
if let Some(cwd) = &server.cwd {
|
||||
cmd.current_dir(cwd);
|
||||
}
|
||||
|
||||
let transport = if let Some(log_path) = self.log_path.as_ref() {
|
||||
cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
|
||||
|
||||
let log_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(log_path)?;
|
||||
let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
|
||||
transport
|
||||
} else {
|
||||
TokioChildProcess::new(cmd)?
|
||||
};
|
||||
|
||||
let service = Arc::new(
|
||||
().serve(transport)
|
||||
.await
|
||||
.with_context(|| format!("Failed to start MCP server: {}", &server.command))?,
|
||||
);
|
||||
debug!(
|
||||
"Available tools for MCP server {id}: {:?}",
|
||||
service.list_tools(None).await?
|
||||
);
|
||||
|
||||
info!("Started MCP server: {id}");
|
||||
|
||||
Ok((id.to_string(), service))
|
||||
}
|
||||
|
||||
pub async fn stop_all_servers(mut self) -> Result<Self> {
|
||||
for (id, server) in self.servers {
|
||||
Arc::try_unwrap(server)
|
||||
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
|
||||
.cancel()
|
||||
.await
|
||||
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
||||
info!("Stopped MCP server: {id}");
|
||||
}
|
||||
|
||||
self.servers = HashMap::new();
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn list_servers(&self) -> Vec<String> {
|
||||
self.servers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn catalog(&self) -> BoxFuture<'static, Result<Value>> {
|
||||
let servers: Vec<(String, Arc<ConnectedServer>)> = self
|
||||
.servers
|
||||
.iter()
|
||||
.map(|(id, s)| (id.clone(), s.clone()))
|
||||
.collect();
|
||||
|
||||
Box::pin(async move {
|
||||
let mut out = Vec::with_capacity(servers.len());
|
||||
for (id, server) in servers {
|
||||
let tools = server.list_tools(None).await?;
|
||||
let resources = server.list_resources(None).await.unwrap_or_default();
|
||||
// TODO implement prompt sampling for MCP servers
|
||||
// let prompts = server.service.list_prompts(None).await.unwrap_or_default();
|
||||
out.push(json!({
|
||||
"server": id,
|
||||
"tools": tools,
|
||||
"resources": resources,
|
||||
}));
|
||||
}
|
||||
Ok(Value::Array(out))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn invoke(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Value,
|
||||
) -> BoxFuture<'static, Result<CallToolResult>> {
|
||||
let server = self
|
||||
.servers
|
||||
.get(server)
|
||||
.cloned()
|
||||
.with_context(|| format!("Invoked MCP server does not exist: {server}"));
|
||||
|
||||
let tool = tool.to_owned();
|
||||
Box::pin(async move {
|
||||
let server = server?;
|
||||
let call_tool_request = CallToolRequestParam {
|
||||
name: Cow::Owned(tool.to_owned()),
|
||||
arguments: arguments.as_object().cloned(),
|
||||
};
|
||||
|
||||
let result = server.call_tool(call_tool_request).await?;
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
use crate::function::{FunctionDeclaration, JsonSchema};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use argc::{ChoiceValue, CommandValue, FlagOptionValue};
|
||||
use indexmap::IndexMap;
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::{env, fs};
|
||||
|
||||
pub fn generate_bash_declarations(
|
||||
mut tool_file: File,
|
||||
tools_file_path: &Path,
|
||||
file_name: &str,
|
||||
) -> Result<Vec<FunctionDeclaration>> {
|
||||
let mut src = String::new();
|
||||
tool_file
|
||||
.read_to_string(&mut src)
|
||||
.with_context(|| format!("Failed to load script at '{tool_file:?}'"))?;
|
||||
|
||||
debug!("Building script at '{tool_file:?}'");
|
||||
let build_script = argc::build(
|
||||
&src,
|
||||
"",
|
||||
env::var("TERM_WIDTH").ok().and_then(|v| v.parse().ok()),
|
||||
)?;
|
||||
fs::write(tools_file_path, &build_script)
|
||||
.with_context(|| format!("Failed to write built script to '{tools_file_path:?}'"))?;
|
||||
|
||||
let command_value = argc::export(&build_script, file_name)
|
||||
.with_context(|| format!("Failed to parse script at '{tool_file:?}'"))?;
|
||||
if command_value.subcommands.is_empty() {
|
||||
let function_declaration =
|
||||
command_to_function_declaration(&command_value).ok_or_else(|| {
|
||||
anyhow::format_err!("Tool definition missing or empty description: {file_name}")
|
||||
})?;
|
||||
Ok(vec![function_declaration])
|
||||
} else {
|
||||
let mut declarations = vec![];
|
||||
for subcommand in &command_value.subcommands {
|
||||
if subcommand.name.starts_with('_') && subcommand.name != "_instructions" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(mut function_declaration) = command_to_function_declaration(subcommand) {
|
||||
function_declaration.agent = true;
|
||||
declarations.push(function_declaration);
|
||||
} else {
|
||||
bail!(
|
||||
"Tool definition missing or empty description: {} {}",
|
||||
file_name,
|
||||
subcommand.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(declarations)
|
||||
}
|
||||
}
|
||||
|
||||
fn command_to_function_declaration(cmd: &CommandValue) -> Option<FunctionDeclaration> {
|
||||
if cmd.describe.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(FunctionDeclaration {
|
||||
name: underscore(&cmd.name),
|
||||
description: cmd.describe.clone(),
|
||||
parameters: parse_parameters_schema(&cmd.flag_options),
|
||||
agent: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Turns dashes into underscores so command names become valid identifiers.
fn underscore(s: &str) -> String {
    s.chars().map(|c| if c == '-' { '_' } else { c }).collect()
}
|
||||
|
||||
fn schema_ty(t: &str) -> JsonSchema {
|
||||
JsonSchema {
|
||||
type_value: Some(t.to_string()),
|
||||
description: None,
|
||||
properties: None,
|
||||
items: None,
|
||||
any_of: None,
|
||||
enum_value: None,
|
||||
default: None,
|
||||
required: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_description(mut schema: JsonSchema, describe: &str) -> JsonSchema {
|
||||
if !describe.is_empty() {
|
||||
schema.description = Some(describe.to_string());
|
||||
}
|
||||
schema
|
||||
}
|
||||
|
||||
fn with_enum(mut schema: JsonSchema, choice: &Option<ChoiceValue>) -> JsonSchema {
|
||||
if let Some(ChoiceValue::Values(values)) = choice {
|
||||
if !values.is_empty() {
|
||||
schema.enum_value = Some(values.clone());
|
||||
}
|
||||
}
|
||||
schema
|
||||
}
|
||||
|
||||
fn parse_property(flag: &FlagOptionValue) -> JsonSchema {
|
||||
let mut schema = if flag.flag {
|
||||
schema_ty("boolean")
|
||||
} else if flag.multiple_occurs {
|
||||
let mut arr = schema_ty("array");
|
||||
arr.items = Some(Box::new(schema_ty("string")));
|
||||
arr
|
||||
} else if flag.notations.first().map(|s| s.as_str()) == Some("INT") {
|
||||
schema_ty("integer")
|
||||
} else if flag.notations.first().map(|s| s.as_str()) == Some("NUM") {
|
||||
schema_ty("number")
|
||||
} else {
|
||||
schema_ty("string")
|
||||
};
|
||||
|
||||
schema = with_description(schema, &flag.describe);
|
||||
schema = with_enum(schema, &flag.choice);
|
||||
schema
|
||||
}
|
||||
|
||||
fn parse_parameters_schema(flags: &[FlagOptionValue]) -> JsonSchema {
|
||||
let filtered = flags.iter().filter(|f| f.id != "help" && f.id != "version");
|
||||
let mut props: IndexMap<String, JsonSchema> = IndexMap::new();
|
||||
let mut required: Vec<String> = Vec::new();
|
||||
|
||||
for f in filtered {
|
||||
let key = underscore(&f.id);
|
||||
if f.required {
|
||||
required.push(key.clone());
|
||||
}
|
||||
props.insert(key, parse_property(f));
|
||||
}
|
||||
|
||||
JsonSchema {
|
||||
type_value: Some("object".to_string()),
|
||||
description: None,
|
||||
properties: Some(props),
|
||||
items: None,
|
||||
any_of: None,
|
||||
enum_value: None,
|
||||
default: None,
|
||||
required: Some(required),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
pub(crate) mod bash;
|
||||
pub(crate) mod python;
|
||||
@@ -0,0 +1,420 @@
|
||||
use crate::function::{FunctionDeclaration, JsonSchema};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use ast::{Stmt, StmtFunctionDef};
|
||||
use indexmap::IndexMap;
|
||||
use rustpython_ast::{Constant, Expr, UnaryOp};
|
||||
use rustpython_parser::{ast, Mode};
|
||||
use serde_json::Value;
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
|
||||
/// Intermediate description of one Python function parameter, combining
/// signature information with docstring metadata.
#[derive(Debug)]
struct Param {
    // Parameter name as written in the signature.
    name: String,
    // Type derived from the annotation (e.g. "int", "list[str]", "literal:a|b").
    ty_hint: String,
    // Whether the caller must supply the parameter.
    required: bool,
    // `Some(Null)` when the parameter has a default; the value itself is not recorded.
    default: Option<Value>,
    // Type parsed from the docstring `Args:` section, if any.
    doc_type: Option<String>,
    // Description parsed from the docstring `Args:` section, if any.
    doc_desc: Option<String>,
}
|
||||
|
||||
pub fn generate_python_declarations(
|
||||
mut tool_file: File,
|
||||
file_name: &str,
|
||||
parent: Option<&Path>,
|
||||
) -> Result<Vec<FunctionDeclaration>> {
|
||||
let mut src = String::new();
|
||||
tool_file
|
||||
.read_to_string(&mut src)
|
||||
.with_context(|| format!("Failed to load script at '{tool_file:?}'"))?;
|
||||
let suite = parse_suite(&src, file_name)?;
|
||||
|
||||
let is_tool = parent
|
||||
.and_then(|p| p.file_name())
|
||||
.is_some_and(|n| n == "tools");
|
||||
let mut declarations = python_to_function_declarations(file_name, &suite, is_tool)?;
|
||||
|
||||
if is_tool {
|
||||
for d in &mut declarations {
|
||||
d.agent = true;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(declarations)
|
||||
}
|
||||
|
||||
fn parse_suite(src: &str, filename: &str) -> Result<ast::Suite> {
|
||||
let mod_ast =
|
||||
rustpython_parser::parse(src, Mode::Module, filename).context("failed to parse python")?;
|
||||
|
||||
let suite = match mod_ast {
|
||||
ast::Mod::Module(m) => m.body,
|
||||
ast::Mod::Interactive(m) => m.body,
|
||||
ast::Mod::Expression(_) => bail!("expected a module; got a single expression"),
|
||||
_ => bail!("unexpected parse mode/AST variant"),
|
||||
};
|
||||
|
||||
Ok(suite)
|
||||
}
|
||||
|
||||
/// Converts the top-level function definitions of a parsed module into
/// `FunctionDeclaration`s.
///
/// Underscore-prefixed functions are skipped (except `_instructions`). For
/// tool scripts (`is_tool`), only the `run` function is exported and it is
/// renamed after the file. Fails when a function lacks a docstring
/// description.
fn python_to_function_declarations(
    file_name: &str,
    module: &ast::Suite,
    is_tool: bool,
) -> Result<Vec<FunctionDeclaration>> {
    let mut out = Vec::new();

    for stmt in module {
        if let Stmt::FunctionDef(fd) = stmt {
            let func_name = fd.name.to_string();

            // Private helpers are not exported, `_instructions` excepted.
            if func_name.starts_with('_') && func_name != "_instructions" {
                continue;
            }

            // Tool scripts expose exactly one entry point named `run`.
            if is_tool && func_name != "run" {
                continue;
            }

            let description = get_docstring_from_body(&fd.body).unwrap_or_default();
            let params = collect_params(fd);
            let schema = build_parameters_schema(&params, &description);
            // A tool's `run` is published under the sanitized file name.
            let name = if is_tool && func_name == "run" {
                underscore(file_name)
            } else {
                underscore(&func_name)
            };
            let desc_trim = description.trim().to_string();
            if desc_trim.is_empty() {
                bail!("Missing or empty description on function: {func_name}");
            }

            out.push(FunctionDeclaration {
                name,
                description: desc_trim,
                parameters: schema,
                // NOTE(review): `generate_python_declarations` overwrites this
                // to `true` whenever `is_tool` holds, so the flag ends up true
                // either way — confirm the intended polarity.
                agent: !is_tool,
            });
        }
    }

    Ok(out)
}
|
||||
|
||||
fn get_docstring_from_body(body: &[Stmt]) -> Option<String> {
|
||||
let first = body.first()?;
|
||||
if let Stmt::Expr(expr_stmt) = first {
|
||||
if let Expr::Constant(constant) = &*expr_stmt.value {
|
||||
if let Constant::Str(s) = &constant.value {
|
||||
return Some(s.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn collect_params(fd: &StmtFunctionDef) -> Vec<Param> {
|
||||
let mut out = Vec::new();
|
||||
|
||||
for a in fd.args.posonlyargs.iter().chain(fd.args.args.iter()) {
|
||||
let name = a.def.arg.to_string();
|
||||
let mut ty = get_arg_type(a.def.annotation.as_deref());
|
||||
let mut required = a.default.is_none();
|
||||
|
||||
if ty.ends_with('?') {
|
||||
ty.pop();
|
||||
required = false;
|
||||
}
|
||||
|
||||
let default = if a.default.is_some() {
|
||||
Some(Value::Null)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: ty,
|
||||
required,
|
||||
default,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
for a in &fd.args.kwonlyargs {
|
||||
let name = a.def.arg.to_string();
|
||||
let mut ty = get_arg_type(a.def.annotation.as_deref());
|
||||
let mut required = a.default.is_none();
|
||||
|
||||
if ty.ends_with('?') {
|
||||
ty.pop();
|
||||
required = false;
|
||||
}
|
||||
|
||||
let default = if a.default.is_some() {
|
||||
Some(Value::Null)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: ty,
|
||||
required,
|
||||
default,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(vararg) = &fd.args.vararg {
|
||||
let name = vararg.arg.to_string();
|
||||
let inner = get_arg_type(vararg.annotation.as_deref());
|
||||
let ty = if inner.is_empty() {
|
||||
"list[str]".into()
|
||||
} else {
|
||||
format!("list[{inner}]")
|
||||
};
|
||||
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: ty,
|
||||
required: false,
|
||||
default: None,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(kwarg) = &fd.args.kwarg {
|
||||
let name = kwarg.arg.to_string();
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: "object".into(),
|
||||
required: false,
|
||||
default: None,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(doc) = get_docstring_from_body(&fd.body) {
|
||||
let meta = parse_docstring_args(&doc);
|
||||
for p in &mut out {
|
||||
if let Some((t, d)) = meta.get(&p.name) {
|
||||
if !t.is_empty() {
|
||||
p.doc_type = Some(t.clone());
|
||||
}
|
||||
|
||||
if !d.is_empty() {
|
||||
p.doc_desc = Some(d.clone());
|
||||
}
|
||||
|
||||
if t.ends_with('?') {
|
||||
p.required = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Converts a Python type annotation into this module's internal type string.
///
/// Produces e.g. "int", "int?" (for `Optional[int]`), "list[str]", or
/// "literal:a|b". Missing annotations yield "" and unsupported forms "any".
///
/// NOTE(review): lowercase builtin generics (`list[str]`, `typing` aliases
/// other than `Optional`/`List`/`Literal`) fall through to "any" — confirm
/// whether that is intended.
fn get_arg_type(annotation: Option<&Expr>) -> String {
    match annotation {
        None => "".to_string(),
        Some(Expr::Name(n)) => n.id.to_string(),
        Some(Expr::Subscript(sub)) => match &*sub.value {
            // Optional[T] is encoded as a trailing '?' that callers strip off.
            Expr::Name(name) if &name.id == "Optional" => {
                let inner = get_arg_type(Some(&sub.slice));
                format!("{inner}?")
            }
            Expr::Name(name) if &name.id == "List" => {
                let inner = get_arg_type(Some(&sub.slice));
                format!("list[{inner}]")
            }
            Expr::Name(name) if &name.id == "Literal" => {
                let vals = literal_members(&sub.slice);
                format!("literal:{}", vals.join("|"))
            }
            _ => "any".to_string(),
        },
        _ => "any".to_string(),
    }
}
|
||||
|
||||
/// Renders an AST expression as the string used for `Literal[...]` members.
///
/// Only simple value forms are supported; anything else becomes "any".
fn expr_to_str(e: &Expr) -> String {
    match e {
        Expr::Constant(c) => match &c.value {
            Constant::Str(s) => s.clone(),
            Constant::Int(i) => i.to_string(),
            Constant::Float(f) => f.to_string(),
            Constant::Bool(b) => b.to_string(),
            Constant::None => "None".to_string(),
            Constant::Ellipsis => "...".to_string(),
            // Byte literals are rendered lossily as UTF-8 text.
            Constant::Bytes(b) => String::from_utf8_lossy(b).into_owned(),
            Constant::Complex { real, imag } => format!("{real}+{imag}j"),
            _ => "any".to_string(),
        },

        Expr::Name(n) => n.id.to_string(),

        // Negative numeric literals parse as unary minus over a constant;
        // keep the sign only when the operand renders as a number.
        Expr::UnaryOp(u) => {
            if matches!(u.op, UnaryOp::USub) {
                let inner = expr_to_str(&u.operand);
                if inner.parse::<f64>().is_ok() || inner.chars().all(|c| c.is_ascii_digit()) {
                    return format!("-{inner}");
                }
            }
            "any".to_string()
        }

        // Tuples join their members with commas.
        Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect::<Vec<_>>().join(","),

        _ => "any".to_string(),
    }
}
|
||||
|
||||
fn literal_members(e: &Expr) -> Vec<String> {
|
||||
match e {
|
||||
Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect(),
|
||||
_ => vec![expr_to_str(e)],
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses the `Args:` section of a Google-style docstring into
/// `name -> (type, description)` entries.
///
/// Only lines after an `Args:` heading are considered; the section ends at
/// the first non-indented line. A parenthesized type like `(int, optional)`
/// is unwrapped, and "optional" appends a '?' marker to the type.
fn parse_docstring_args(doc: &str) -> IndexMap<String, (String, String)> {
    let mut out = IndexMap::new();
    let mut in_args = false;
    for line in doc.lines() {
        // Skip everything until the `Args:` heading.
        if !in_args {
            if line.trim_start().starts_with("Args:") {
                in_args = true;
            }
            continue;
        }
        // The section ends when indentation stops.
        if !(line.starts_with(' ') || line.starts_with('\t')) {
            break;
        }
        let s = line.trim();
        // Expected form: "name (type): description" or "name: description".
        if let Some((left, desc)) = s.split_once(':') {
            let left = left.trim();
            let mut name = left.to_string();
            let mut ty = String::new();
            if let Some((n, t)) = left.split_once(' ') {
                name = n.trim().to_string();
                ty = t.trim().to_string();
                if ty.starts_with('(') && ty.ends_with(')') {
                    let mut inner = ty[1..ty.len() - 1].to_string();
                    // "optional" inside the parens marks the parameter optional.
                    if inner.to_lowercase().contains("optional") && !inner.ends_with('?') {
                        inner.push('?');
                    }
                    ty = inner;
                }
            }
            out.insert(name, (ty, desc.trim().to_string()));
        }
    }
    out
}
|
||||
|
||||
/// Sanitizes a name into a lowercase identifier: non-alphanumeric runs become
/// a single underscore, with no leading/trailing separators.
fn underscore(s: &str) -> String {
    let tokens: Vec<String> = s
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|t| !t.is_empty())
        .map(|t| t.to_ascii_lowercase())
        .collect();
    tokens.join("_")
}
|
||||
|
||||
fn build_parameters_schema(params: &[Param], _description: &str) -> JsonSchema {
|
||||
let mut props: IndexMap<String, JsonSchema> = IndexMap::new();
|
||||
let mut req: Vec<String> = Vec::new();
|
||||
|
||||
for p in params {
|
||||
let name = p.name.replace('-', "_");
|
||||
let mut schema = JsonSchema::default();
|
||||
|
||||
let ty = if !p.ty_hint.is_empty() {
|
||||
p.ty_hint.as_str()
|
||||
} else if let Some(t) = &p.doc_type {
|
||||
t.as_str()
|
||||
} else {
|
||||
"str"
|
||||
};
|
||||
|
||||
if let Some(d) = &p.doc_desc {
|
||||
if !d.is_empty() {
|
||||
schema.description = Some(d.clone());
|
||||
}
|
||||
}
|
||||
|
||||
apply_type_to_schema(ty, &mut schema);
|
||||
|
||||
if p.default.is_none() && p.required {
|
||||
req.push(name.clone());
|
||||
}
|
||||
|
||||
props.insert(name, schema);
|
||||
}
|
||||
|
||||
JsonSchema {
|
||||
type_value: Some("object".into()),
|
||||
description: None,
|
||||
properties: Some(props),
|
||||
items: None,
|
||||
any_of: None,
|
||||
enum_value: None,
|
||||
default: None,
|
||||
required: if req.is_empty() { None } else { Some(req) },
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) {
|
||||
let t = ty.trim_end_matches('?');
|
||||
if let Some(rest) = t.strip_prefix("list[") {
|
||||
s.type_value = Some("array".into());
|
||||
let inner = rest.trim_end_matches(']');
|
||||
let mut item = JsonSchema::default();
|
||||
|
||||
apply_type_to_schema(inner, &mut item);
|
||||
|
||||
if item.type_value.is_none() {
|
||||
item.type_value = Some("string".into());
|
||||
}
|
||||
s.items = Some(Box::new(item));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(rest) = t.strip_prefix("literal:") {
|
||||
s.type_value = Some("string".into());
|
||||
let vals = rest
|
||||
.split('|')
|
||||
.map(|x| x.trim().trim_matches('"').trim_matches('\'').to_string())
|
||||
.collect::<Vec<_>>();
|
||||
if !vals.is_empty() {
|
||||
s.enum_value = Some(vals);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
s.type_value = Some(
|
||||
match t {
|
||||
"bool" => "boolean",
|
||||
"int" => "integer",
|
||||
"float" => "number",
|
||||
"str" | "any" | "" => "string",
|
||||
_ => "string",
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
+1013
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,66 @@
|
||||
use super::*;
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use serde::{de, Deserializer, Serializer};
|
||||
|
||||
pub fn serialize<S>(
|
||||
vectors: &IndexMap<DocumentId, Vec<f32>>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let encoded_map: IndexMap<String, String> = vectors
|
||||
.iter()
|
||||
.map(|(id, vec)| {
|
||||
let (h, l) = id.split();
|
||||
let byte_slice = unsafe {
|
||||
std::slice::from_raw_parts(vec.as_ptr() as *const u8, vec.len() * size_of::<f32>())
|
||||
};
|
||||
(format!("{h}-{l}"), STANDARD.encode(byte_slice))
|
||||
})
|
||||
.collect();
|
||||
|
||||
encoded_map.serialize(serializer)
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<IndexMap<DocumentId, Vec<f32>>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let encoded_map: IndexMap<String, String> =
|
||||
IndexMap::<String, String>::deserialize(deserializer)?;
|
||||
|
||||
let mut decoded_map = IndexMap::new();
|
||||
for (key, base64_str) in encoded_map {
|
||||
let decoded_key: DocumentId = key
|
||||
.split_once('-')
|
||||
.and_then(|(h, l)| {
|
||||
let h = h.parse::<usize>().ok()?;
|
||||
let l = l.parse::<usize>().ok()?;
|
||||
Some(DocumentId::new(h, l))
|
||||
})
|
||||
.ok_or_else(|| de::Error::custom(format!("Invalid key '{key}'")))?;
|
||||
|
||||
let decoded_data = STANDARD.decode(&base64_str).map_err(de::Error::custom)?;
|
||||
|
||||
if decoded_data.len() % size_of::<f32>() != 0 {
|
||||
return Err(de::Error::custom(format!("Invalid vector at '{key}'")));
|
||||
}
|
||||
|
||||
let num_f32s = decoded_data.len() / size_of::<f32>();
|
||||
|
||||
let mut vec_f32 = vec![0.0f32; num_f32s];
|
||||
unsafe {
|
||||
std::ptr::copy_nonoverlapping(
|
||||
decoded_data.as_ptr(),
|
||||
vec_f32.as_mut_ptr() as *mut u8,
|
||||
decoded_data.len(),
|
||||
);
|
||||
}
|
||||
|
||||
decoded_map.insert(decoded_key, vec_f32);
|
||||
}
|
||||
|
||||
Ok(decoded_map)
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
/// Languages with a dedicated separator list for recursive text splitting.
#[derive(PartialEq, Eq, Hash)]
pub enum Language {
    Cpp,
    Go,
    Java,
    Js,
    Php,
    Proto,
    Python,
    Rst,
    Ruby,
    Rust,
    Scala,
    Swift,
    Markdown,
    Latex,
    Html,
    Sol,
}

impl Language {
    /// Ordered split points for this language, from coarsest (declarations,
    /// blank lines) to finest (single space, empty string).
    ///
    /// Fix: the Rust list previously contained `"\nconst "` twice; the
    /// duplicate entry was dead (the first occurrence already matches) and
    /// has been removed.
    pub fn separators(&self) -> Vec<&str> {
        match self {
            Language::Cpp => vec![
                "\nclass ",
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Go => vec![
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Java => vec![
                "\nclass ",
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Js => vec![
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Php => vec![
                "\nfunction ",
                "\nclass ",
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Proto => vec![
                "\nmessage ",
                "\nservice ",
                "\nenum ",
                "\noption ",
                "\nimport ",
                "\nsyntax ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Python => vec!["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""],
            Language::Rst => vec![
                "\n===\n", "\n---\n", "\n***\n", "\n.. ", "\n\n", "\n", " ", "",
            ],
            Language::Ruby => vec![
                "\ndef ",
                "\nclass ",
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Rust => vec![
                "\nfn ", "\nconst ", "\nlet ", "\nif ", "\nwhile ", "\nfor ", "\nloop ",
                "\nmatch ", "\n\n", "\n", " ", "",
            ],
            Language::Scala => vec![
                "\nclass ",
                "\nobject ",
                "\ndef ",
                "\nval ",
                "\nvar ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Swift => vec![
                "\nfunc ",
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Markdown => vec![
                "\n## ",
                "\n### ",
                "\n#### ",
                "\n##### ",
                "\n###### ",
                "```\n\n",
                "\n\n***\n\n",
                "\n\n---\n\n",
                "\n\n___\n\n",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Latex => vec![
                "\n\\chapter{",
                "\n\\section{",
                "\n\\subsection{",
                "\n\\subsubsection{",
                "\n\\begin{enumerate}",
                "\n\\begin{itemize}",
                "\n\\begin{description}",
                "\n\\begin{list}",
                "\n\\begin{quote}",
                "\n\\begin{quotation}",
                "\n\\begin{verse}",
                "\n\\begin{verbatim}",
                "\n\\begin{align}",
                "$$",
                "$",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Html => vec![
                "<body>", "<div>", "<p>", "<br>", "<li>", "<h1>", "<h2>", "<h3>", "<h4>", "<h5>",
                "<h6>", "<span>", "<table>", "<tr>", "<td>", "<th>", "<ul>", "<ol>", "<header>",
                "<footer>", "<nav>", "<head>", "<style>", "<script>", "<meta>", "<title>", " ", "",
            ],
            Language::Sol => vec![
                "\npragma ",
                "\nusing ",
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
        }
    }
}
|
||||
@@ -0,0 +1,475 @@
|
||||
mod language;
|
||||
|
||||
pub use self::language::*;
|
||||
|
||||
use super::{DocumentMetadata, RagDocument};
|
||||
|
||||
/// Generic fallback separators: paragraph, line, word, then character level.
pub const DEFAULT_SEPARATORS: [&str; 4] = ["\n\n", "\n", " ", ""];

/// Maps a file extension to the language-aware separator list used for
/// recursive text splitting; unknown extensions fall back to
/// [`DEFAULT_SEPARATORS`].
pub fn get_separators(extension: &str) -> Vec<&'static str> {
    match extension {
        "c" | "cc" | "cpp" => Language::Cpp.separators(),
        "go" => Language::Go.separators(),
        "java" => Language::Java.separators(),
        "js" | "mjs" | "cjs" => Language::Js.separators(),
        "php" => Language::Php.separators(),
        "proto" => Language::Proto.separators(),
        "py" => Language::Python.separators(),
        "rst" => Language::Rst.separators(),
        "rb" => Language::Ruby.separators(),
        "rs" => Language::Rust.separators(),
        "scala" => Language::Scala.separators(),
        "swift" => Language::Swift.separators(),
        "md" | "mkd" => Language::Markdown.separators(),
        "tex" => Language::Latex.separators(),
        "htm" | "html" => Language::Html.separators(),
        "sol" => Language::Sol.separators(),
        _ => DEFAULT_SEPARATORS.to_vec(),
    }
}
|
||||
|
||||
/// Splits text into size-bounded chunks using an ordered separator list.
pub struct RecursiveCharacterTextSplitter {
    // Target maximum chunk length, measured by `length_function`.
    pub chunk_size: usize,
    // Amount of content shared between consecutive chunks.
    pub chunk_overlap: usize,
    // Separators tried in order, from coarsest to finest.
    pub separators: Vec<String>,
    // Length measure used for sizing (byte length by default).
    pub length_function: Box<dyn Fn(&str) -> usize + Send + Sync>,
}
|
||||
|
||||
impl Default for RecursiveCharacterTextSplitter {
    /// 1000-char chunks with 20-char overlap, the generic separator list, and
    /// plain byte length as the measure.
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            chunk_overlap: 20,
            separators: DEFAULT_SEPARATORS.iter().map(|v| v.to_string()).collect(),
            length_function: Box::new(|text| text.len()),
        }
    }
}
|
||||
|
||||
impl RecursiveCharacterTextSplitter {
    /// Convenience constructor: defaults plus explicit size, overlap, and separators.
    pub fn new(chunk_size: usize, chunk_overlap: usize, separators: &[&str]) -> Self {
        Self::default()
            .with_chunk_size(chunk_size)
            .with_chunk_overlap(chunk_overlap)
            .with_separators(separators)
    }

    /// Builder: set the maximum chunk size.
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Builder: set the overlap carried between consecutive chunks.
    pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
        self.chunk_overlap = chunk_overlap;
        self
    }

    /// Builder: replace the separator list (owned copies are stored).
    pub fn with_separators(mut self, separators: &[&str]) -> Self {
        self.separators = separators.iter().map(|v| v.to_string()).collect();
        self
    }

    /// Split each non-empty document into chunks, preserving each source
    /// document's metadata on every chunk derived from it.
    pub fn split_documents(
        &self,
        documents: &[RagDocument],
        chunk_header_options: &SplitterChunkHeaderOptions,
    ) -> Vec<RagDocument> {
        let mut texts: Vec<String> = Vec::new();
        let mut metadatas: Vec<DocumentMetadata> = Vec::new();
        // Documents with empty page_content are dropped entirely.
        documents.iter().for_each(|d| {
            if !d.page_content.is_empty() {
                texts.push(d.page_content.clone());
                metadatas.push(d.metadata.clone());
            }
        });

        self.create_documents(&texts, &metadatas, chunk_header_options)
    }

    /// Split each text into chunks and wrap them as `RagDocument`s.
    ///
    /// `metadatas[i]` is cloned onto every chunk produced from `texts[i]`,
    /// so the two slices must be index-aligned (panics otherwise).
    /// `chunk_header` is prepended to every chunk; `chunk_overlap_header`
    /// is additionally prepended to every chunk after the first of a text.
    pub fn create_documents(
        &self,
        texts: &[String],
        metadatas: &[DocumentMetadata],
        chunk_header_options: &SplitterChunkHeaderOptions,
    ) -> Vec<RagDocument> {
        let SplitterChunkHeaderOptions {
            chunk_header,
            chunk_overlap_header,
        } = chunk_header_options;

        let mut documents = Vec::new();
        for (i, text) in texts.iter().enumerate() {
            let mut prev_chunk: Option<String> = None;
            // Byte offset (as i32, -1 = not found) of the previous chunk in
            // `text`; used so repeated chunk text is located after the
            // previous occurrence rather than from the start again.
            let mut index_prev_chunk = -1;

            for chunk in self.split_text(text) {
                let mut page_content = chunk_header.clone();

                let index_chunk = if index_prev_chunk < 0 {
                    text.find(&chunk).map(|i| i as i32).unwrap_or(-1)
                } else {
                    // Advance one char past the previous match before
                    // searching, keeping the slice on a char boundary.
                    match text[(index_prev_chunk as usize)..].chars().next() {
                        Some(c) => {
                            let offset = (index_prev_chunk as usize) + c.len_utf8();
                            text[offset..]
                                .find(&chunk)
                                .map(|i| (i + offset) as i32)
                                .unwrap_or(-1)
                        }
                        None => -1,
                    }
                };

                // Every chunk after the first gets the overlap marker, if configured.
                if prev_chunk.is_some() {
                    if let Some(chunk_overlap_header) = chunk_overlap_header {
                        page_content += chunk_overlap_header;
                    }
                }

                let metadata = metadatas[i].clone();
                page_content += &chunk;
                documents.push(RagDocument {
                    page_content,
                    metadata,
                });

                prev_chunk = Some(chunk);
                index_prev_chunk = index_chunk;
            }
        }

        documents
    }

    /// Split `text` using the configured separator cascade.
    ///
    /// Separators are kept attached to the splits only when at least one
    /// separator contains a non-whitespace character (i.e. structural
    /// separators such as markdown headings are preserved in the output).
    pub fn split_text(&self, text: &str) -> Vec<String> {
        let keep_separator = self
            .separators
            .iter()
            .any(|v| v.chars().any(|v| !v.is_whitespace()));
        self.split_text_impl(text, &self.separators, keep_separator)
    }

    /// Recursive worker: pick the first separator present in `text`, split on
    /// it, merge small pieces back up to `chunk_size`, and recurse with the
    /// remaining (finer) separators on pieces that are still too large.
    fn split_text_impl(
        &self,
        text: &str,
        separators: &[String],
        keep_separator: bool,
    ) -> Vec<String> {
        let mut final_chunks = Vec::new();

        // Default to the last (finest) separator if none matches.
        let mut separator: String = separators.last().cloned().unwrap_or_default();
        let mut new_separators: Vec<String> = vec![];
        for (i, s) in separators.iter().enumerate() {
            if s.is_empty() {
                // Empty separator means per-character split; nothing finer remains.
                separator.clone_from(s);
                break;
            }
            if text.contains(s) {
                separator.clone_from(s);
                new_separators = separators[i + 1..].to_vec();
                break;
            }
        }

        // Now that we have the separator, split the text
        let splits = split_on_separator(text, &separator, keep_separator);

        // Now go merging things, recursively splitting longer texts.
        let mut good_splits = Vec::new();
        // When the separator is kept inline in the splits, don't re-insert it while joining.
        let _separator = if keep_separator { "" } else { &separator };
        for s in splits {
            if (self.length_function)(s) < self.chunk_size {
                good_splits.push(s.to_string());
            } else {
                // Flush accumulated small pieces before handling the oversized one.
                if !good_splits.is_empty() {
                    let merged_text = self.merge_splits(&good_splits, _separator);
                    final_chunks.extend(merged_text);
                    good_splits.clear();
                }
                if new_separators.is_empty() {
                    // No finer separator left: emit the oversized piece as-is.
                    final_chunks.push(s.to_string());
                } else {
                    let other_info = self.split_text_impl(s, &new_separators, keep_separator);
                    final_chunks.extend(other_info);
                }
            }
        }
        if !good_splits.is_empty() {
            let merged_text = self.merge_splits(&good_splits, _separator);
            final_chunks.extend(merged_text);
        }
        final_chunks
    }

    /// Greedily pack `splits` into chunks of at most `chunk_size`, keeping a
    /// sliding window so consecutive chunks share up to `chunk_overlap` of text.
    fn merge_splits(&self, splits: &[String], separator: &str) -> Vec<String> {
        let mut docs = Vec::new();
        let mut current_doc = Vec::new();
        // Running total of piece lengths (separator lengths accounted for
        // separately in the capacity checks below).
        let mut total = 0;
        for d in splits {
            let _len = (self.length_function)(d);
            if total + _len + current_doc.len() * separator.len() > self.chunk_size {
                if total > self.chunk_size {
                    // warn!("Warning: Created a chunk of size {}, which is longer than the specified {}", total, self.chunk_size);
                }
                if !current_doc.is_empty() {
                    let doc = self.join_docs(&current_doc, separator);
                    if let Some(doc) = doc {
                        docs.push(doc);
                    }
                    // Keep on popping if:
                    // - we have a larger chunk than in the chunk overlap
                    // - or if we still have any chunks and the length is long
                    while total > self.chunk_overlap
                        || (total + _len + current_doc.len() * separator.len() > self.chunk_size
                            && total > 0)
                    {
                        total -= (self.length_function)(&current_doc[0]);
                        current_doc.remove(0);
                    }
                }
            }
            current_doc.push(d.to_string());
            total += _len;
        }
        // Flush whatever remains in the window as the final chunk.
        let doc = self.join_docs(&current_doc, separator);
        if let Some(doc) = doc {
            docs.push(doc);
        }
        docs
    }

    /// Join pieces with `separator` and trim; returns None for an
    /// all-whitespace/empty result so callers can skip it.
    fn join_docs(&self, docs: &[String], separator: &str) -> Option<String> {
        let text = docs.join(separator).trim().to_string();
        if text.is_empty() {
            None
        } else {
            Some(text)
        }
    }
}
|
||||
|
||||
/// Header text attached to chunks produced by the splitter.
pub struct SplitterChunkHeaderOptions {
    /// Prepended to every chunk (e.g. a source banner).
    pub chunk_header: String,
    /// Additionally prepended to every chunk after the first one of a
    /// document (e.g. a continuation marker); None disables it.
    pub chunk_overlap_header: Option<String>,
}
|
||||
|
||||
impl Default for SplitterChunkHeaderOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
chunk_header: "".into(),
|
||||
chunk_overlap_header: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SplitterChunkHeaderOptions {
|
||||
// Set the value of chunk_header
|
||||
#[allow(unused)]
|
||||
pub fn with_chunk_header(mut self, header: &str) -> Self {
|
||||
self.chunk_header = header.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
// Set the value of chunk_overlap_header
|
||||
#[allow(unused)]
|
||||
pub fn with_chunk_overlap_header(mut self, overlap_header: &str) -> Self {
|
||||
self.chunk_overlap_header = Some(overlap_header.to_string());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Split `text` on `separator`, dropping empty pieces.
///
/// With `keep_separator`, each occurrence of the separator stays attached to
/// the *front* of the piece that follows it. An empty separator splits into
/// individual characters.
fn split_on_separator<'a>(text: &'a str, separator: &str, keep_separator: bool) -> Vec<&'a str> {
    let pieces: Vec<&str> = if separator.is_empty() {
        // `split("")` yields every char plus empty edge pieces; the
        // filter below removes the empties.
        text.split("").collect()
    } else if !keep_separator {
        text.split(separator).collect()
    } else {
        let sep_len = separator.len();
        let mut pieces = Vec::new();
        let mut start = 0;
        while let Some(found) = text[start..].find(separator) {
            // From the second match on, back up so the separator is included
            // at the front of this piece (saturating for the first match).
            pieces.push(&text[start.saturating_sub(sep_len)..start + found]);
            start += found + sep_len;
        }
        if start < text.len() {
            pieces.push(&text[start.saturating_sub(sep_len)..]);
        }
        pieces
    };
    pieces.into_iter().filter(|piece| !piece.is_empty()).collect()
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indexmap::IndexMap;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
fn build_metadata(source: &str) -> Value {
|
||||
json!({ "source": source })
|
||||
}
|
||||
#[test]
|
||||
fn test_split_text() {
|
||||
let splitter = RecursiveCharacterTextSplitter {
|
||||
chunk_size: 7,
|
||||
chunk_overlap: 3,
|
||||
separators: vec![" ".into()],
|
||||
..Default::default()
|
||||
};
|
||||
let output = splitter.split_text("foo bar baz 123");
|
||||
assert_eq!(output, vec!["foo bar", "bar baz", "baz 123"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_document() {
|
||||
let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
|
||||
let chunk_header_options = SplitterChunkHeaderOptions::default();
|
||||
let mut metadata1 = IndexMap::new();
|
||||
metadata1.insert("source".into(), "1".into());
|
||||
let mut metadata2 = IndexMap::new();
|
||||
metadata2.insert("source".into(), "2".into());
|
||||
let output = splitter.create_documents(
|
||||
&["foo bar".into(), "baz".into()],
|
||||
&[metadata1, metadata2],
|
||||
&chunk_header_options,
|
||||
);
|
||||
let output = json!(output);
|
||||
assert_eq!(
|
||||
output,
|
||||
json!([
|
||||
{
|
||||
"page_content": "foo",
|
||||
"metadata": build_metadata("1"),
|
||||
},
|
||||
{
|
||||
"page_content": "bar",
|
||||
"metadata": build_metadata("1"),
|
||||
},
|
||||
{
|
||||
"page_content": "baz",
|
||||
"metadata": build_metadata("2"),
|
||||
},
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_header() {
|
||||
let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
|
||||
let chunk_header_options = SplitterChunkHeaderOptions::default()
|
||||
.with_chunk_header("SOURCE NAME: testing\n-----\n")
|
||||
.with_chunk_overlap_header("(cont'd) ");
|
||||
let mut metadata1 = IndexMap::new();
|
||||
metadata1.insert("source".into(), "1".into());
|
||||
let mut metadata2 = IndexMap::new();
|
||||
metadata2.insert("source".into(), "2".into());
|
||||
let output = splitter.create_documents(
|
||||
&["foo bar".into(), "baz".into()],
|
||||
&[metadata1, metadata2],
|
||||
&chunk_header_options,
|
||||
);
|
||||
let output = json!(output);
|
||||
assert_eq!(
|
||||
output,
|
||||
json!([
|
||||
{
|
||||
"page_content": "SOURCE NAME: testing\n-----\nfoo",
|
||||
"metadata": build_metadata("1"),
|
||||
},
|
||||
{
|
||||
"page_content": "SOURCE NAME: testing\n-----\n(cont'd) bar",
|
||||
"metadata": build_metadata("1"),
|
||||
},
|
||||
{
|
||||
"page_content": "SOURCE NAME: testing\n-----\nbaz",
|
||||
"metadata": build_metadata("2"),
|
||||
},
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_markdown_splitter() {
|
||||
let text = r#"# 🦜️🔗 LangChain
|
||||
|
||||
⚡ Building applications with LLMs through composability ⚡
|
||||
|
||||
## Quick Install
|
||||
|
||||
```bash
|
||||
# Hopefully this code block isn't split
|
||||
pip install langchain
|
||||
```
|
||||
|
||||
As an open source project in a rapidly developing field, we are extremely open to contributions."#;
|
||||
let splitter =
|
||||
RecursiveCharacterTextSplitter::new(100, 0, &Language::Markdown.separators());
|
||||
let output = splitter.split_text(text);
|
||||
let expected_output = vec![
|
||||
"# 🦜️🔗 LangChain\n\n⚡ Building applications with LLMs through composability ⚡",
|
||||
"## Quick Install\n\n```bash\n# Hopefully this code block isn't split\npip install langchain",
|
||||
"```",
|
||||
"As an open source project in a rapidly developing field, we are extremely open to contributions.",
|
||||
];
|
||||
assert_eq!(output, expected_output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_html_splitter() {
|
||||
let text = r#"<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>🦜️🔗 LangChain</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
}
|
||||
h1 {
|
||||
color: darkblue;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<h1>🦜️🔗 LangChain</h1>
|
||||
<p>⚡ Building applications with LLMs through composability ⚡</p>
|
||||
</div>
|
||||
<div>
|
||||
As an open source project in a rapidly developing field, we are extremely open to contributions.
|
||||
</div>
|
||||
</body>
|
||||
</html>"#;
|
||||
let splitter = RecursiveCharacterTextSplitter::new(175, 20, &Language::Html.separators());
|
||||
let output = splitter.split_text(text);
|
||||
let expected_output = vec![
|
||||
"<!DOCTYPE html>\n<html>",
|
||||
"<head>\n <title>🦜️🔗 LangChain</title>",
|
||||
r#"<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
}
|
||||
h1 {
|
||||
color: darkblue;
|
||||
}
|
||||
</style>
|
||||
</head>"#,
|
||||
r#"<body>
|
||||
<div>
|
||||
<h1>🦜️🔗 LangChain</h1>
|
||||
<p>⚡ Building applications with LLMs through composability ⚡</p>
|
||||
</div>"#,
|
||||
r#"<div>
|
||||
As an open source project in a rapidly developing field, we are extremely open to contributions.
|
||||
</div>
|
||||
</body>
|
||||
</html>"#,
|
||||
];
|
||||
assert_eq!(output, expected_output);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,393 @@
|
||||
use crate::utils::decode_bin;
|
||||
|
||||
use ansi_colours::AsRGB;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use crossterm::style::{Color, Stylize};
|
||||
use crossterm::terminal;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::LazyLock;
|
||||
use syntect::highlighting::{Color as SyntectColor, FontStyle, Style, Theme};
|
||||
use syntect::parsing::SyntaxSet;
|
||||
use syntect::{easy::HighlightLines, parsing::SyntaxReference};
|
||||
|
||||
/// Comes from <https://github.com/sharkdp/bat/raw/5e77ca37e89c873e4490b42ff556370dc5c6ba4f/assets/syntaxes.bin>
|
||||
const SYNTAXES: &[u8] = include_bytes!("../../assets/syntaxes.bin");
|
||||
|
||||
/// Fence-language tokens whose syntect syntax name differs from the token
/// itself (looked up before the generic token/extension search).
static LANG_MAPS: LazyLock<HashMap<String, String>> = LazyLock::new(|| {
    HashMap::from([
        ("csharp".to_string(), "C#".to_string()),
        ("php".to_string(), "PHP Source".to_string()),
    ])
});
|
||||
|
||||
/// Stateful line-by-line markdown renderer for terminal output.
pub struct MarkdownRender {
    /// Rendering settings (theme, wrap mode, truecolor support).
    options: RenderOptions,
    /// Syntax definitions decoded from the bundled `syntaxes.bin`.
    syntax_set: SyntaxSet,
    /// Fallback color for code lines when no syntax is resolved; derived
    /// from the theme's "string" scope.
    code_color: Option<Color>,
    /// Markdown syntax used to highlight non-code lines.
    md_syntax: SyntaxReference,
    /// Syntax of the fenced code block currently open, if any.
    code_syntax: Option<SyntaxReference>,
    /// Classification of the previously rendered line (drives the
    /// code-fence state machine in `check_line`).
    prev_line_type: LineType,
    /// Wrap column resolved at init from options and terminal width;
    /// None disables wrapping.
    wrap_width: Option<u16>,
}
|
||||
|
||||
impl MarkdownRender {
    /// Build a renderer: decode the bundled syntax set, derive the fallback
    /// code color from the theme, and resolve the wrap width.
    ///
    /// # Errors
    /// Fails if the embedded syntax binary cannot be decoded or if the wrap
    /// option is neither "auto" nor a valid `u16`.
    pub fn init(options: RenderOptions) -> Result<Self> {
        let syntax_set: SyntaxSet =
            decode_bin(SYNTAXES).with_context(|| "MarkdownRender: invalid syntaxes binary")?;

        let code_color = options
            .theme
            .as_ref()
            .map(|theme| get_code_color(theme, options.truecolor));
        // The bundled syntax set is known to contain a markdown syntax, so
        // this unwrap expresses an invariant of the embedded asset.
        let md_syntax = syntax_set.find_syntax_by_extension("md").unwrap().clone();
        let line_type = LineType::Normal;
        let wrap_width = match options.wrap.as_deref() {
            None => None,
            Some(value) => match terminal::size() {
                Ok((columns, _)) => {
                    if value == "auto" {
                        Some(columns)
                    } else {
                        let value = value
                            .parse::<u16>()
                            .map_err(|_| anyhow!("Invalid wrap value"))?;
                        // Never wrap wider than the terminal itself.
                        Some(columns.min(value))
                    }
                }
                // If the terminal size is unavailable, wrapping is disabled.
                Err(_) => None,
            },
        };
        Ok(Self {
            syntax_set,
            code_color,
            md_syntax,
            code_syntax: None,
            prev_line_type: line_type,
            wrap_width,
            options,
        })
    }

    /// Render a multi-line text, updating the fence state machine per line.
    pub fn render(&mut self, text: &str) -> String {
        text.split('\n')
            .map(|line| self.render_line_mut(line))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Render one line without mutating renderer state (useful for
    /// re-rendering an in-progress line during streaming).
    pub fn render_line(&self, line: &str) -> String {
        let (_, code_syntax, is_code) = self.check_line(line);
        if is_code {
            self.highlight_code_line(line, &code_syntax)
        } else {
            self.highlight_line(line, &self.md_syntax, false)
        }
    }

    /// Render one line and commit the resulting fence state.
    fn render_line_mut(&mut self, line: &str) -> String {
        let (line_type, code_syntax, is_code) = self.check_line(line);
        let output = if is_code {
            self.highlight_code_line(line, &code_syntax)
        } else {
            self.highlight_line(line, &self.md_syntax, false)
        };
        self.prev_line_type = line_type;
        self.code_syntax = code_syntax;
        output
    }

    /// Classify `line` relative to fenced code blocks, returning the new
    /// line type, the (possibly updated) code syntax, and whether the line
    /// body should be highlighted as code.
    fn check_line(&self, line: &str) -> (LineType, Option<SyntaxReference>, bool) {
        let mut line_type = self.prev_line_type;
        let mut code_syntax = self.code_syntax.clone();
        let mut is_code = false;
        if let Some(lang) = detect_code_block(line) {
            match line_type {
                // A fence line outside a block opens one.
                LineType::Normal | LineType::CodeEnd => {
                    line_type = LineType::CodeBegin;
                    code_syntax = if lang.is_empty() {
                        None
                    } else {
                        self.find_syntax(&lang).cloned()
                    };
                }
                // A fence line inside a block closes it.
                LineType::CodeBegin | LineType::CodeInner => {
                    line_type = LineType::CodeEnd;
                    code_syntax = None;
                }
            }
        } else {
            match line_type {
                LineType::Normal => {}
                LineType::CodeEnd => {
                    line_type = LineType::Normal;
                }
                LineType::CodeBegin => {
                    // Untagged fence: try to guess the syntax from the first
                    // line of the block.
                    if code_syntax.is_none() {
                        if let Some(syntax) = self.syntax_set.find_syntax_by_first_line(line) {
                            code_syntax = Some(syntax.clone());
                        }
                    }
                    line_type = LineType::CodeInner;
                    is_code = true;
                }
                LineType::CodeInner => {
                    is_code = true;
                }
            }
        }
        (line_type, code_syntax, is_code)
    }

    /// Highlight one line with `syntax`, preserving its leading whitespace,
    /// then apply wrapping. Falls back to the raw line when no theme is set
    /// or highlighting fails.
    fn highlight_line(&self, line: &str, syntax: &SyntaxReference, is_code: bool) -> String {
        // Leading whitespace is stripped before highlighting and re-attached
        // afterwards so indentation is never colorized.
        let ws: String = line.chars().take_while(|c| c.is_whitespace()).collect();
        let trimmed_line: &str = &line[ws.len()..];
        let mut line_highlighted = None;
        if let Some(theme) = &self.options.theme {
            let mut highlighter = HighlightLines::new(syntax, theme);
            if let Ok(ranges) = highlighter.highlight_line(trimmed_line, &self.syntax_set) {
                line_highlighted = Some(format!(
                    "{ws}{}",
                    as_terminal_escaped(&ranges, self.options.truecolor)
                ))
            }
        }
        let line = line_highlighted.unwrap_or_else(|| line.into());
        self.wrap_line(line, is_code)
    }

    /// Highlight a code-block line: use its syntax when known, otherwise
    /// tint the whole line with the theme's fallback code color.
    fn highlight_code_line(&self, line: &str, code_syntax: &Option<SyntaxReference>) -> String {
        if let Some(syntax) = code_syntax {
            self.highlight_line(line, syntax, true)
        } else {
            let line = match self.code_color {
                Some(color) => line.with(color).to_string(),
                None => line.to_string(),
            };
            self.wrap_line(line, true)
        }
    }

    /// Apply word wrapping unless disabled, honoring the "don't wrap code"
    /// option for code lines.
    fn wrap_line(&self, line: String, is_code: bool) -> String {
        if let Some(width) = self.wrap_width {
            if is_code && !self.options.wrap_code {
                return line;
            }
            wrap(&line, width as usize)
        } else {
            line
        }
    }

    /// Resolve a fence language token to a syntect syntax: explicit alias
    /// map first, then token lookup, then file-extension lookup.
    fn find_syntax(&self, lang: &str) -> Option<&SyntaxReference> {
        if let Some(new_lang) = LANG_MAPS.get(&lang.to_ascii_lowercase()) {
            self.syntax_set.find_syntax_by_name(new_lang)
        } else {
            self.syntax_set
                .find_syntax_by_token(lang)
                .or_else(|| self.syntax_set.find_syntax_by_extension(lang))
        }
    }
}
|
||||
|
||||
fn wrap(text: &str, width: usize) -> String {
|
||||
let indent: usize = text.chars().take_while(|c| *c == ' ').count();
|
||||
let wrap_options = textwrap::Options::new(width)
|
||||
.wrap_algorithm(textwrap::WrapAlgorithm::FirstFit)
|
||||
.initial_indent(&text[0..indent]);
|
||||
textwrap::wrap(&text[indent..], wrap_options).join("\n")
|
||||
}
|
||||
|
||||
/// Options controlling how markdown is rendered to the terminal.
#[derive(Debug, Clone, Default)]
pub struct RenderOptions {
    /// Syntax-highlighting theme; None renders plain text.
    pub theme: Option<Theme>,
    /// Wrap setting: None (no wrap), "auto" (terminal width), or a column count.
    pub wrap: Option<String>,
    /// Whether lines inside code blocks are wrapped too.
    pub wrap_code: bool,
    /// Whether the terminal supports 24-bit color.
    pub truecolor: bool,
}
|
||||
|
||||
impl RenderOptions {
    /// Bundle rendering settings into a `RenderOptions` value.
    pub(crate) fn new(
        theme: Option<Theme>,
        wrap: Option<String>,
        wrap_code: bool,
        truecolor: bool,
    ) -> Self {
        Self {
            theme,
            wrap,
            wrap_code,
            truecolor,
        }
    }
}
|
||||
|
||||
/// Position of a rendered line relative to fenced code blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LineType {
    /// Ordinary markdown text.
    Normal,
    /// The opening ``` fence line.
    CodeBegin,
    /// A line inside a fenced code block.
    CodeInner,
    /// The closing ``` fence line.
    CodeEnd,
}
|
||||
|
||||
fn as_terminal_escaped(ranges: &[(Style, &str)], truecolor: bool) -> String {
|
||||
let mut output = String::new();
|
||||
for (style, text) in ranges {
|
||||
let fg = blend_fg_color(style.foreground, style.background);
|
||||
let mut text = text.with(convert_color(fg, truecolor));
|
||||
if style.font_style.contains(FontStyle::BOLD) {
|
||||
text = text.bold();
|
||||
}
|
||||
if style.font_style.contains(FontStyle::UNDERLINE) {
|
||||
text = text.underlined();
|
||||
}
|
||||
output.push_str(&text.to_string());
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
fn convert_color(c: SyntectColor, truecolor: bool) -> Color {
|
||||
if truecolor {
|
||||
Color::Rgb {
|
||||
r: c.r,
|
||||
g: c.g,
|
||||
b: c.b,
|
||||
}
|
||||
} else {
|
||||
let value = (c.r, c.g, c.b).to_ansi256();
|
||||
// lower contrast
|
||||
let value = match value {
|
||||
7 | 15 | 231 | 252..=255 => 252,
|
||||
_ => value,
|
||||
};
|
||||
Color::AnsiValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
fn blend_fg_color(fg: SyntectColor, bg: SyntectColor) -> SyntectColor {
|
||||
if fg.a == 0xff {
|
||||
return fg;
|
||||
}
|
||||
let ratio = u32::from(fg.a);
|
||||
let r = (u32::from(fg.r) * ratio + u32::from(bg.r) * (255 - ratio)) / 255;
|
||||
let g = (u32::from(fg.g) * ratio + u32::from(bg.g) * (255 - ratio)) / 255;
|
||||
let b = (u32::from(fg.b) * ratio + u32::from(bg.b) * (255 - ratio)) / 255;
|
||||
SyntectColor {
|
||||
r: u8::try_from(r).unwrap_or(u8::MAX),
|
||||
g: u8::try_from(g).unwrap_or(u8::MAX),
|
||||
b: u8::try_from(b).unwrap_or(u8::MAX),
|
||||
a: 255,
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect a markdown code-fence line (optionally indented), returning the
/// language token after the backticks (possibly empty). Non-fence lines
/// return None.
fn detect_code_block(line: &str) -> Option<String> {
    let rest = line.trim_start().strip_prefix("```")?;
    // The language token runs until the first whitespace character.
    let lang: String = rest.chars().take_while(|c| !c.is_whitespace()).collect();
    Some(lang)
}
|
||||
|
||||
fn get_code_color(theme: &Theme, truecolor: bool) -> Color {
|
||||
let scope = theme.scopes.iter().find(|v| {
|
||||
v.scope
|
||||
.selectors
|
||||
.iter()
|
||||
.any(|v| v.path.scopes.iter().any(|v| v.to_string() == "string"))
|
||||
});
|
||||
scope
|
||||
.and_then(|v| v.style.foreground)
|
||||
.map_or_else(|| Color::Yellow, |c| convert_color(c, truecolor))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const TEXT: &str = r#"
|
||||
To unzip a file in Rust, you can use the `zip` crate. Here's an example code that shows how to unzip a file:
|
||||
|
||||
```rust
|
||||
use std::fs::File;
|
||||
|
||||
fn unzip_file(path: &str, output_dir: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
todo!()
|
||||
}
|
||||
```
|
||||
"#;
|
||||
const TEXT_NO_WRAP_CODE: &str = r#"
|
||||
To unzip a file in Rust, you can use the `zip` crate. Here's an example code
|
||||
that shows how to unzip a file:
|
||||
|
||||
```rust
|
||||
use std::fs::File;
|
||||
|
||||
fn unzip_file(path: &str, output_dir: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
todo!()
|
||||
}
|
||||
```
|
||||
"#;
|
||||
|
||||
const TEXT_WRAP_ALL: &str = r#"
|
||||
To unzip a file in Rust, you can use the `zip` crate. Here's an example code
|
||||
that shows how to unzip a file:
|
||||
|
||||
```rust
|
||||
use std::fs::File;
|
||||
|
||||
fn unzip_file(path: &str, output_dir: &str) -> Result<(), Box<dyn
|
||||
std::error::Error>> {
|
||||
todo!()
|
||||
}
|
||||
```
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
fn test_render() {
|
||||
let options = RenderOptions::default();
|
||||
let render = MarkdownRender::init(options).unwrap();
|
||||
assert!(render.find_syntax("csharp").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_theme() {
|
||||
let options = RenderOptions::default();
|
||||
let mut render = MarkdownRender::init(options).unwrap();
|
||||
let output = render.render(TEXT);
|
||||
assert_eq!(TEXT, output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_wrap_code() {
|
||||
let options = RenderOptions::default();
|
||||
let mut render = MarkdownRender::init(options).unwrap();
|
||||
render.wrap_width = Some(80);
|
||||
let output = render.render(TEXT);
|
||||
assert_eq!(TEXT_NO_WRAP_CODE, output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrap_all() {
|
||||
let options = RenderOptions {
|
||||
wrap_code: true,
|
||||
..Default::default()
|
||||
};
|
||||
let mut render = MarkdownRender::init(options).unwrap();
|
||||
render.wrap_width = Some(80);
|
||||
let output = render.render(TEXT);
|
||||
assert_eq!(TEXT_WRAP_ALL, output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_code_block() {
|
||||
assert_eq!(detect_code_block("```rust"), Some("rust".into()));
|
||||
assert_eq!(detect_code_block("```c++"), Some("c++".into()));
|
||||
assert_eq!(detect_code_block(" ```rust"), Some("rust".into()));
|
||||
assert_eq!(detect_code_block("```"), Some("".into()));
|
||||
assert_eq!(detect_code_block("``rust"), None);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
mod markdown;
|
||||
mod stream;
|
||||
|
||||
pub use self::markdown::{MarkdownRender, RenderOptions};
|
||||
use self::stream::{markdown_stream, raw_stream};
|
||||
|
||||
use crate::utils::{error_text, pretty_error, AbortSignal, IS_STDOUT_TERMINAL};
|
||||
use crate::{client::SseEvent, config::GlobalConfig};
|
||||
|
||||
use anyhow::Result;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
|
||||
pub async fn render_stream(
|
||||
rx: UnboundedReceiver<SseEvent>,
|
||||
config: &GlobalConfig,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<()> {
|
||||
let ret = if *IS_STDOUT_TERMINAL && config.read().highlight {
|
||||
let render_options = config.read().render_options()?;
|
||||
let mut render = MarkdownRender::init(render_options)?;
|
||||
markdown_stream(rx, &mut render, &abort_signal).await
|
||||
} else {
|
||||
raw_stream(rx, &abort_signal).await
|
||||
};
|
||||
ret.map_err(|err| err.context("Failed to reader stream"))
|
||||
}
|
||||
|
||||
pub fn render_error(err: anyhow::Error) {
|
||||
eprintln!("{}", error_text(&pretty_error(&err)));
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
use super::{MarkdownRender, SseEvent};
|
||||
|
||||
use crate::utils::{poll_abort_signal, spawn_spinner, AbortSignal};
|
||||
|
||||
use anyhow::Result;
|
||||
use crossterm::{
|
||||
cursor, queue, style,
|
||||
terminal::{self, disable_raw_mode, enable_raw_mode},
|
||||
};
|
||||
use std::{
|
||||
io::{stdout, Stdout, Write},
|
||||
time::Duration,
|
||||
};
|
||||
use textwrap::core::display_width;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
|
||||
/// Stream SSE events to stdout with markdown highlighting.
///
/// Enters raw mode for the duration of the stream and always restores it
/// afterwards, even when the inner loop fails; a trailing newline is printed
/// on error so the shell prompt starts on a fresh line.
pub async fn markdown_stream(
    rx: UnboundedReceiver<SseEvent>,
    render: &mut MarkdownRender,
    abort_signal: &AbortSignal,
) -> Result<()> {
    enable_raw_mode()?;
    let mut stdout = stdout();

    let ret = markdown_stream_inner(rx, render, abort_signal, &mut stdout).await;

    // Raw mode must be disabled before reporting the result.
    disable_raw_mode()?;

    if ret.is_err() {
        println!();
    }
    ret
}
|
||||
|
||||
pub async fn raw_stream(
|
||||
mut rx: UnboundedReceiver<SseEvent>,
|
||||
abort_signal: &AbortSignal,
|
||||
) -> Result<()> {
|
||||
let mut spinner = Some(spawn_spinner("Generating"));
|
||||
|
||||
loop {
|
||||
if abort_signal.aborted() {
|
||||
break;
|
||||
}
|
||||
if let Some(evt) = rx.recv().await {
|
||||
if let Some(spinner) = spinner.take() {
|
||||
spinner.stop();
|
||||
}
|
||||
|
||||
match evt {
|
||||
SseEvent::Text(text) => {
|
||||
print!("{text}");
|
||||
stdout().flush()?;
|
||||
}
|
||||
SseEvent::Done => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(spinner) = spinner.take() {
|
||||
spinner.stop();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Core streaming render loop.
///
/// Keeps the last unterminated line in `buffer` and re-renders it on every
/// event (so partially-received markdown is re-highlighted as it grows),
/// while completed lines are rendered once and printed as a block.
/// `buffer_rows` tracks how many terminal rows the in-progress line
/// currently occupies so the cursor can be moved back over it.
async fn markdown_stream_inner(
    mut rx: UnboundedReceiver<SseEvent>,
    render: &mut MarkdownRender,
    abort_signal: &AbortSignal,
    writer: &mut Stdout,
) -> Result<()> {
    let mut buffer = String::new();
    let mut buffer_rows = 1;

    let columns = terminal::size()?.0;

    let mut spinner = Some(spawn_spinner("Generating"));

    'outer: loop {
        if abort_signal.aborted() {
            break;
        }
        for reply_event in gather_events(&mut rx).await {
            // First real event: remove the spinner.
            if let Some(spinner) = spinner.take() {
                spinner.stop();
            }

            match reply_event {
                SseEvent::Text(mut text) => {
                    // tab width hacking
                    text = text.replace('\t', "    ");

                    // Querying the cursor can transiently fail; retry a few
                    // times before giving up.
                    let mut attempts = 0;
                    let (col, mut row) = loop {
                        match cursor::position() {
                            Ok(pos) => break pos,
                            Err(_) if attempts < 3 => attempts += 1,
                            Err(e) => return Err(e.into()),
                        }
                    };

                    // Fix unexpected duplicate lines on kitty
                    if col == 0 && row > 0 && display_width(&buffer) == columns as usize {
                        row -= 1;
                    }

                    // Move the cursor back to the first row of the
                    // in-progress line, scrolling if it starts above the
                    // visible area.
                    if row + 1 >= buffer_rows {
                        queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows),)?;
                    } else {
                        let scroll_rows = buffer_rows - row - 1;
                        queue!(
                            writer,
                            terminal::ScrollUp(scroll_rows),
                            cursor::MoveTo(0, 0),
                        )?;
                    }

                    // No guarantee that text returned by render will not be re-layouted, so it is better to clear it.
                    queue!(writer, terminal::Clear(terminal::ClearType::FromCursorDown))?;

                    if text.contains('\n') {
                        // Completed lines: render them once and keep only the
                        // trailing unterminated fragment in the buffer.
                        let text = format!("{buffer}{text}");
                        let (head, tail) = split_line_tail(&text);
                        let output = render.render(head);
                        print_block(writer, &output, columns)?;
                        buffer = tail.to_string();
                    } else {
                        buffer = format!("{buffer}{text}");
                    }

                    // Re-render the in-progress line (stateless render).
                    let output = render.render_line(&buffer);
                    if output.contains('\n') {
                        let (head, tail) = split_line_tail(&output);
                        buffer_rows = print_block(writer, head, columns)?;
                        queue!(writer, style::Print(&tail),)?;

                        // No guarantee the buffer width of the buffer will not exceed the number of columns.
                        // So we calculate the number of rows needed, rather than setting it directly to 1.
                        buffer_rows += need_rows(tail, columns);
                    } else {
                        queue!(writer, style::Print(&output))?;
                        buffer_rows = need_rows(&output, columns);
                    }

                    writer.flush()?;
                }
                SseEvent::Done => {
                    break 'outer;
                }
            }
        }

        // Check for user-initiated abort (e.g. Ctrl-C) between batches.
        if poll_abort_signal(abort_signal)? {
            break;
        }
    }

    if let Some(spinner) = spinner.take() {
        spinner.stop();
    }
    Ok(())
}
|
||||
|
||||
/// Drain events from the channel for up to 50ms, coalescing consecutive
/// text events into a single one.
///
/// Returns at most two events: an optional merged `Text` followed by an
/// optional `Done`. Batching reduces re-render churn during fast streams.
/// NOTE(review): if the sender is dropped without `Done`, the inner loop
/// ends and this returns whatever was collected — presumably the caller's
/// abort signal covers that case; verify against callers.
async fn gather_events(rx: &mut UnboundedReceiver<SseEvent>) -> Vec<SseEvent> {
    let mut texts = vec![];
    let mut done = false;
    // Race the drain loop against a 50ms timer; whichever finishes first
    // cancels the other branch.
    tokio::select! {
        _ = async {
            while let Some(reply_event) = rx.recv().await {
                match reply_event {
                    SseEvent::Text(v) => texts.push(v),
                    SseEvent::Done => {
                        done = true;
                        break;
                    }
                }
            }
        } => {}
        _ = tokio::time::sleep(Duration::from_millis(50)) => {}
    }
    let mut events = vec![];
    if !texts.is_empty() {
        events.push(SseEvent::Text(texts.join("")))
    }
    if done {
        events.push(SseEvent::Done)
    }
    events
}
|
||||
|
||||
fn print_block(writer: &mut Stdout, text: &str, columns: u16) -> Result<u16> {
|
||||
let mut num = 0;
|
||||
for line in text.split('\n') {
|
||||
queue!(
|
||||
writer,
|
||||
style::Print(line),
|
||||
style::Print("\n"),
|
||||
cursor::MoveLeft(columns),
|
||||
)?;
|
||||
num += 1;
|
||||
}
|
||||
Ok(num)
|
||||
}
|
||||
|
||||
/// Split text at its last newline: `(everything before, trailing fragment)`.
/// Without a newline the whole text is the tail and the head is empty.
fn split_line_tail(text: &str) -> (&str, &str) {
    text.rsplit_once('\n').unwrap_or(("", text))
}
|
||||
|
||||
fn need_rows(text: &str, columns: u16) -> u16 {
|
||||
let buffer_width = display_width(text).max(1) as u16;
|
||||
buffer_width.div_ceil(columns)
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
use super::{ReplCommand, REPL_COMMANDS};
|
||||
|
||||
use crate::{config::GlobalConfig, utils::fuzzy_filter};
|
||||
|
||||
use reedline::{Completer, Span, Suggestion};
|
||||
use std::collections::HashMap;
|
||||
|
||||
impl Completer for ReplCompleter {
    /// Produce completion suggestions for the REPL input up to the cursor.
    ///
    /// Only dot-commands are completed. With arguments present, argument
    /// completions from the config take priority; command-name completions
    /// are offered only when no argument suggestion matched.
    fn complete(&mut self, line: &str, pos: usize) -> Vec<Suggestion> {
        let mut suggestions = vec![];
        // Only the text before the cursor matters for completion.
        let line = &line[0..pos];
        let mut parts = split_line(line);
        if parts.is_empty() {
            return suggestions;
        }
        // Drop a leading ":::" token — presumably the multi-line/editor
        // prefix; verify against the REPL parser.
        if parts[0].0 == r#":::"# {
            parts.remove(0);
        }

        let parts_len = parts.len();
        if parts_len == 0 {
            return suggestions;
        }
        let (cmd, cmd_start) = parts[0];

        // Completion applies to dot-commands only.
        if !cmd.starts_with('.') {
            return suggestions;
        }

        let state = self.config.read().state();

        // Match against the first two tokens so two-word commands
        // (e.g. ".info role") are filtered as a unit.
        let command_filter = parts
            .iter()
            .take(2)
            .map(|(v, _)| *v)
            .collect::<Vec<&str>>()
            .join(" ");
        let commands: Vec<_> = self
            .commands
            .iter()
            .filter(|cmd| {
                // NOTE(review): `[..2]` is a byte slice; safe while command
                // names and input start with ASCII '.', but would panic on a
                // multi-byte char at index 1 — confirm inputs are ASCII here.
                cmd.is_valid(state)
                    && (command_filter.len() == 1 || cmd.name.starts_with(&command_filter[..2]))
            })
            .collect();
        let commands = fuzzy_filter(commands, |v| v.name, &command_filter);

        if parts_len > 1 {
            // Replace only the token under the cursor.
            let span = Span::new(parts[parts_len - 1].1, pos);
            let args_line = &line[parts[1].1..];
            let args: Vec<&str> = parts.iter().skip(1).map(|(v, _)| *v).collect();
            suggestions.extend(
                self.config
                    .read()
                    .repl_complete(cmd, &args, args_line)
                    .iter()
                    .map(|(value, description)| {
                        let description = description.as_deref().unwrap_or_default();
                        create_suggestion(value, description, span)
                    }),
            )
        }

        // Fall back to command-name suggestions when nothing else matched.
        if suggestions.is_empty() {
            let span = Span::new(cmd_start, pos);
            suggestions.extend(commands.iter().map(|cmd| {
                let name = cmd.name;
                let description = cmd.description;
                // Commands sharing a first word form a "group"; don't append
                // the trailing space so the user can keep typing the subcommand.
                let has_group = self.groups.get(name).map(|v| *v > 1).unwrap_or_default();
                let name = if has_group {
                    name.to_string()
                } else {
                    format!("{name} ")
                };
                create_suggestion(&name, description, span)
            }))
        }
        suggestions
    }
}
|
||||
|
||||
/// Reedline completer for REPL dot-commands and their arguments.
pub struct ReplCompleter {
    // Shared application config; consulted for state-dependent completions.
    config: GlobalConfig,
    // Copy of the static REPL command table.
    commands: Vec<ReplCommand>,
    // Occurrence count per command name; a count > 1 marks a grouped command.
    groups: HashMap<&'static str, usize>,
}
|
||||
|
||||
impl ReplCompleter {
|
||||
pub fn new(config: &GlobalConfig) -> Self {
|
||||
let mut groups = HashMap::new();
|
||||
|
||||
let commands: Vec<ReplCommand> = REPL_COMMANDS.to_vec();
|
||||
|
||||
for cmd in REPL_COMMANDS.iter() {
|
||||
let name = cmd.name;
|
||||
*groups.entry(name).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
Self {
|
||||
config: config.clone(),
|
||||
commands,
|
||||
groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_suggestion(value: &str, description: &str, span: Span) -> Suggestion {
|
||||
let description = if description.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(description.to_string())
|
||||
};
|
||||
Suggestion {
|
||||
value: value.to_string(),
|
||||
description,
|
||||
style: None,
|
||||
extra: None,
|
||||
span,
|
||||
append_whitespace: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn split_line(line: &str) -> Vec<(&str, usize)> {
|
||||
let mut parts = vec![];
|
||||
let mut part_start = None;
|
||||
for (i, ch) in line.char_indices() {
|
||||
if ch == ' ' {
|
||||
if let Some(s) = part_start {
|
||||
parts.push((&line[s..i], s));
|
||||
part_start = None;
|
||||
}
|
||||
} else if part_start.is_none() {
|
||||
part_start = Some(i)
|
||||
}
|
||||
}
|
||||
if let Some(s) = part_start {
|
||||
parts.push((&line[s..], s));
|
||||
} else {
|
||||
parts.push(("", line.len()))
|
||||
}
|
||||
parts
|
||||
}
|
||||
|
||||
#[test]
fn test_split_line() {
    // Table-driven: (input, expected tokens with byte offsets).
    let cases: Vec<(&str, Vec<(&str, usize)>)> = vec![
        (".role coder", vec![(".role", 0), ("coder", 6)]),
        (" .role coder", vec![(".role", 1), ("coder", 9)]),
        (".set highlight ", vec![(".set", 0), ("highlight", 5), ("", 15)]),
        (".set highlight t", vec![(".set", 0), ("highlight", 5), ("t", 15)]),
    ];
    for (input, expected) in cases {
        assert_eq!(split_line(input), expected, "input: {input:?}");
    }
}
|
||||
@@ -0,0 +1,49 @@
|
||||
use super::REPL_COMMANDS;
|
||||
|
||||
use crate::{config::GlobalConfig, utils::NO_COLOR};
|
||||
|
||||
use nu_ansi_term::{Color, Style};
|
||||
use reedline::{Highlighter, StyledText};
|
||||
|
||||
// Color for ordinary (non-command) text.
const DEFAULT_COLOR: Color = Color::Default;
// Color for a recognized REPL command name.
const MATCH_COLOR: Color = Color::Green;
|
||||
|
||||
/// Highlighter that colorizes REPL command names in the input line.
pub struct ReplHighlighter;
|
||||
|
||||
impl ReplHighlighter {
    /// The config is currently unused; the parameter keeps the constructor
    /// signature consistent with the other REPL components.
    pub fn new(_config: &GlobalConfig) -> Self {
        Self
    }
}
|
||||
|
||||
impl Highlighter for ReplHighlighter {
|
||||
fn highlight(&self, line: &str, _cursor: usize) -> StyledText {
|
||||
let mut styled_text = StyledText::new();
|
||||
|
||||
if *NO_COLOR {
|
||||
styled_text.push((Style::default(), line.to_string()));
|
||||
} else if REPL_COMMANDS.iter().any(|cmd| line.contains(cmd.name)) {
|
||||
let matches: Vec<&str> = REPL_COMMANDS
|
||||
.iter()
|
||||
.filter(|cmd| line.contains(cmd.name))
|
||||
.map(|cmd| cmd.name)
|
||||
.collect();
|
||||
let longest_match = matches.iter().fold(String::new(), |acc, &item| {
|
||||
if item.len() > acc.len() {
|
||||
item.to_string()
|
||||
} else {
|
||||
acc
|
||||
}
|
||||
});
|
||||
let buffer_split: Vec<&str> = line.splitn(2, &longest_match).collect();
|
||||
|
||||
styled_text.push((Style::new().fg(DEFAULT_COLOR), buffer_split[0].to_string()));
|
||||
styled_text.push((Style::new().fg(MATCH_COLOR), longest_match));
|
||||
styled_text.push((Style::new().fg(DEFAULT_COLOR), buffer_split[1].to_string()));
|
||||
} else {
|
||||
styled_text.push((Style::new().fg(DEFAULT_COLOR), line.to_string()));
|
||||
}
|
||||
|
||||
styled_text
|
||||
}
|
||||
}
|
||||
+1014
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,51 @@
|
||||
use crate::config::GlobalConfig;
|
||||
|
||||
use reedline::{Prompt, PromptHistorySearch, PromptHistorySearchStatus};
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Reedline prompt rendered from the shared application config.
#[derive(Clone)]
pub struct ReplPrompt {
    // Shared, lockable config handle; read on every render.
    config: GlobalConfig,
}
|
||||
|
||||
impl ReplPrompt {
    /// Bind the prompt to the shared config (cheap clone of the handle).
    pub fn new(config: &GlobalConfig) -> Self {
        Self {
            config: config.clone(),
        }
    }
}
|
||||
|
||||
impl Prompt for ReplPrompt {
    /// Left prompt text is delegated to the live config.
    fn render_prompt_left(&self) -> Cow<'_, str> {
        Cow::Owned(self.config.read().render_prompt_left())
    }

    /// Right prompt text is delegated to the live config.
    fn render_prompt_right(&self) -> Cow<'_, str> {
        Cow::Owned(self.config.read().render_prompt_right())
    }

    /// No edit-mode indicator; the left prompt carries all state.
    fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<'_, str> {
        Cow::Borrowed("")
    }

    /// Continuation marker shown for multiline input.
    fn render_prompt_multiline_indicator(&self) -> Cow<'_, str> {
        Cow::Borrowed("... ")
    }

    /// Reverse-history-search indicator, e.g. "(failing reverse-search: term) ".
    fn render_prompt_history_search_indicator(
        &self,
        history_search: PromptHistorySearch,
    ) -> Cow<'_, str> {
        let prefix = match history_search.status {
            PromptHistorySearchStatus::Passing => "",
            PromptHistorySearchStatus::Failing => "failing ",
        };
        // NOTE: magic strings; given how these compose, extracting them into
        // static constants may not be worthwhile.
        Cow::Owned(format!(
            "({}reverse-search: {}) ",
            prefix, history_search.term
        ))
    }
}
|
||||
+935
@@ -0,0 +1,935 @@
|
||||
use crate::{client::*, config::*, function::*, rag::*, utils::*};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use bytes::Bytes;
|
||||
use chrono::{Timelike, Utc};
|
||||
use futures_util::StreamExt;
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
|
||||
use hyper::{
|
||||
body::{Frame, Incoming},
|
||||
service::service_fn,
|
||||
};
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use parking_lot::RwLock;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
net::IpAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tokio::{
|
||||
net::TcpListener,
|
||||
sync::{
|
||||
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
|
||||
oneshot,
|
||||
},
|
||||
};
|
||||
use tokio_graceful::Shutdown;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
|
||||
// Alias id under which the currently-configured model is exposed via the API.
const DEFAULT_MODEL_NAME: &str = "default";
// Browser UIs compiled into the binary.
const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html");
const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html");

// Every handler returns a boxed, infallible body.
type AppResponse = Response<BoxBody<Bytes, Infallible>>;
|
||||
|
||||
/// Start the local HTTP server and block until Ctrl-C.
///
/// `addr` accepts three shorthand forms: a bare port (`"8000"` ->
/// `127.0.0.1:8000`), a bare IP (`ip` -> `ip:8000`), or a full address string
/// used as-is; `None` falls back to the configured serve address.
pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
    let addr = match addr {
        Some(addr) => {
            if let Ok(port) = addr.parse::<u16>() {
                format!("127.0.0.1:{port}")
            } else if let Ok(ip) = addr.parse::<IpAddr>() {
                format!("{ip}:8000")
            } else {
                addr
            }
        }
        None => config.read().serve_addr(),
    };
    let server = Arc::new(Server::new(&config));
    let listener = TcpListener::bind(&addr).await?;
    // `run` returns immediately with a shutdown handle; the accept loop is a
    // background task.
    let stop_server = server.run(listener).await?;
    println!("Chat Completions API: http://{addr}/v1/chat/completions");
    println!("Embeddings API: http://{addr}/v1/embeddings");
    println!("Rerank API: http://{addr}/v1/rerank");
    println!("LLM Playground: http://{addr}/playground");
    println!("LLM Arena: http://{addr}/arena?num=2");
    shutdown_signal().await;
    // Dropping/sending the handle triggers graceful shutdown.
    let _ = stop_server.send(());
    Ok(())
}
|
||||
|
||||
/// Immutable per-process snapshot shared by all connections.
struct Server {
    // Config snapshot taken at startup (functions stripped in `new`).
    config: Config,
    // Pre-rendered OpenAI-style model objects for `/v1/models`.
    models: Vec<Value>,
    // All roles, served by `/v1/roles`.
    roles: Vec<Role>,
    // RAG names, served by `/v1/rags`.
    rags: Vec<String>,
}
|
||||
|
||||
impl Server {
|
||||
    /// Snapshot the shared config and precompute the `/v1/models` payload.
    ///
    /// Functions are cleared from the snapshot. The configured model is
    /// exposed twice: first under the reserved id "default", then under its
    /// real id in the full model list.
    fn new(config: &GlobalConfig) -> Self {
        let mut config = config.read().clone();
        config.functions = Functions::default();
        let mut models = list_all_models(&config);
        let mut default_model = config.model.clone();
        default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
        models.insert(0, &default_model);
        let models: Vec<Value> = models
            .into_iter()
            .enumerate()
            .map(|(i, model)| {
                // Index 0 is the alias entry inserted above.
                let id = if i == 0 {
                    DEFAULT_MODEL_NAME.into()
                } else {
                    model.id()
                };
                // Shape each entry like an OpenAI `/v1/models` object.
                let mut value = json!(model.data());
                if let Some(value_obj) = value.as_object_mut() {
                    value_obj.insert("id".into(), id.into());
                    value_obj.insert("object".into(), "model".into());
                    value_obj.insert("owned_by".into(), model.client_name().into());
                    value_obj.remove("name");
                }
                value
            })
            .collect();
        Self {
            config,
            models,
            roles: Config::all_roles(),
            rags: Config::list_rags(),
        }
    }
|
||||
|
||||
    /// Spawn the accept loop on `listener` as a background task.
    ///
    /// Returns a oneshot sender; sending (or dropping) it cancels the loop
    /// via the `Shutdown` guard, which also tracks in-flight connections.
    async fn run(self: Arc<Self>, listener: TcpListener) -> Result<oneshot::Sender<()>> {
        let (tx, rx) = oneshot::channel();
        tokio::spawn(async move {
            let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() });
            let guard = shutdown.guard_weak();

            loop {
                tokio::select! {
                    res = listener.accept() => {
                        // Skip transient accept errors and keep serving.
                        let Ok((cnx, _)) = res else {
                            continue;
                        };

                        let stream = TokioIo::new(cnx);
                        let server = self.clone();
                        // Spawn through the shutdown handle so the connection
                        // is awaited during graceful shutdown.
                        shutdown.spawn_task(async move {
                            let hyper_service = service_fn(move |request: hyper::Request<Incoming>| {
                                server.clone().handle(request)
                            });
                            let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                                .serve_connection_with_upgrades(stream, hyper_service)
                                .await;
                        });
                    }
                    _ = guard.cancelled() => {
                        break;
                    }
                }
            }
        });
        Ok(tx)
    }
|
||||
|
||||
    /// Route one HTTP request to the matching endpoint, log the outcome, and
    /// attach CORS headers to every response.
    ///
    /// Endpoint errors are converted into a JSON error body; the status is
    /// BAD_REQUEST unless a handler already set one (e.g. NOT_FOUND).
    async fn handle(
        self: Arc<Self>,
        req: hyper::Request<Incoming>,
    ) -> std::result::Result<AppResponse, hyper::Error> {
        let method = req.method().clone();
        let uri = req.uri().clone();
        let path = uri.path();

        // CORS preflight: reply 204 with the CORS headers, no body.
        if method == Method::OPTIONS {
            let mut res = Response::default();
            *res.status_mut() = StatusCode::NO_CONTENT;
            set_cors_header(&mut res);
            return Ok(res);
        }

        let mut status = StatusCode::OK;
        let res = if path == "/v1/chat/completions" {
            self.chat_completions(req).await
        } else if path == "/v1/embeddings" {
            self.embeddings(req).await
        } else if path == "/v1/rerank" {
            self.rerank(req).await
        } else if path == "/v1/models" {
            self.list_models()
        } else if path == "/v1/roles" {
            self.list_roles()
        } else if path == "/v1/rags" {
            self.list_rags()
        } else if path == "/v1/rags/search" {
            self.search_rag(req).await
        } else if path == "/playground" || path == "/playground.html" {
            self.playground_page()
        } else if path == "/arena" || path == "/arena.html" {
            self.arena_page()
        } else {
            status = StatusCode::NOT_FOUND;
            Err(anyhow!("Not Found"))
        };
        let mut res = match res {
            Ok(res) => {
                info!("{method} {uri} {}", status.as_u16());
                res
            }
            Err(err) => {
                // Handler errors default to 400 unless a status was already set.
                if status == StatusCode::OK {
                    status = StatusCode::BAD_REQUEST;
                }
                error!("{method} {uri} {} {err}", status.as_u16());
                ret_err(err)
            }
        };
        *res.status_mut() = status;
        set_cors_header(&mut res);
        Ok(res)
    }
|
||||
|
||||
fn playground_page(&self) -> Result<AppResponse> {
|
||||
let res = Response::builder()
|
||||
.header("Content-Type", "text/html; charset=utf-8")
|
||||
.body(Full::new(Bytes::from(PLAYGROUND_HTML)).boxed())?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn arena_page(&self) -> Result<AppResponse> {
|
||||
let res = Response::builder()
|
||||
.header("Content-Type", "text/html; charset=utf-8")
|
||||
.body(Full::new(Bytes::from(ARENA_HTML)).boxed())?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn list_models(&self) -> Result<AppResponse> {
|
||||
let data = json!({ "data": self.models });
|
||||
let res = Response::builder()
|
||||
.header("Content-Type", "application/json; charset=utf-8")
|
||||
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn list_roles(&self) -> Result<AppResponse> {
|
||||
let data = json!({ "data": self.roles });
|
||||
let res = Response::builder()
|
||||
.header("Content-Type", "application/json; charset=utf-8")
|
||||
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn list_rags(&self) -> Result<AppResponse> {
|
||||
let data = json!({ "data": self.rags });
|
||||
let res = Response::builder()
|
||||
.header("Content-Type", "application/json; charset=utf-8")
|
||||
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
    /// POST /v1/rags/search — run a RAG similarity search for `input` against
    /// the named RAG and return the matches.
    async fn search_rag(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
        let req_body = req.collect().await?.to_bytes();
        let req_body: Value = serde_json::from_slice(&req_body)
            .map_err(|err| anyhow!("Invalid request json, {err}"))?;

        debug!("search rag request: {req_body}");
        let SearchRagReqBody { name, input } = serde_json::from_value(req_body)
            .map_err(|err| anyhow!("Invalid request body, {err}"))?;

        // Wrap a per-request config clone so helpers expecting a shared
        // lockable config can be reused.
        let config = Arc::new(RwLock::new(self.config.clone()));

        let abort_signal = create_abort_signal();

        let rag_path = config.read().rag_file(&name);
        let rag = Rag::load(&config, &name, &rag_path)?;

        let rag_result = Config::search_rag(&config, &rag, &input, abort_signal).await?;

        let data = json!({ "data": rag_result });
        let res = Response::builder()
            .header("Content-Type", "application/json; charset=utf-8")
            .body(Full::new(Bytes::from(data.to_string())).boxed())?;
        Ok(res)
    }
|
||||
|
||||
    /// POST /v1/chat/completions — OpenAI-compatible chat endpoint.
    ///
    /// Parses the request, selects and configures the model (the id "default"
    /// is an alias for the configured model), then either streams SSE chunks
    /// or returns a single JSON completion depending on `stream`.
    async fn chat_completions(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
        let req_body = req.collect().await?.to_bytes();
        let req_body: Value = serde_json::from_slice(&req_body)
            .map_err(|err| anyhow!("Invalid request json, {err}"))?;

        debug!("chat completions request: {req_body}");
        let req_body = serde_json::from_value(req_body)
            .map_err(|err| anyhow!("Invalid request body, {err}"))?;

        let ChatCompletionsReqBody {
            model,
            messages,
            temperature,
            top_p,
            max_tokens,
            stream,
            tools,
        } = req_body;

        let mut messages =
            parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;

        let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;

        let config = self.config.clone();

        let default_model = config.model.clone();

        let config = Arc::new(RwLock::new(config));

        // `change` records whether the per-request config must switch models.
        let (model_name, change) = if model == DEFAULT_MODEL_NAME {
            (default_model.id(), true)
        } else if default_model.id() == model {
            (model, false)
        } else {
            (model, true)
        };

        if change {
            config.write().set_model(&model_name)?;
        }

        let mut client = init_client(&config, None)?;
        if max_tokens.is_some() {
            client.model_mut().set_max_tokens(max_tokens, true);
        }
        let abort_signal = create_abort_signal();
        let http_client = client.build_client()?;

        let completion_id = generate_completion_id();
        let created = Utc::now().timestamp();

        patch_messages(&mut messages, client.model());

        let data: ChatCompletionsData = ChatCompletionsData {
            messages,
            temperature,
            top_p,
            functions,
            stream,
        };

        if stream {
            // Worker task: drives the provider call and forwards normalized
            // `ResEvent`s over `tx`; the response body is built from `rx`.
            let (tx, mut rx) = unbounded_channel();
            tokio::spawn(async move {
                // `is_first` guarantees exactly one `First` event is sent,
                // carrying the startup error if the call failed immediately.
                let is_first = Arc::new(AtomicBool::new(true));
                let (sse_tx, sse_rx) = unbounded_channel();
                let mut handler = SseHandler::new(sse_tx, abort_signal);
                // Translate raw SSE events into `ResEvent`s.
                async fn map_event(
                    mut sse_rx: UnboundedReceiver<SseEvent>,
                    tx: &UnboundedSender<ResEvent>,
                    is_first: Arc<AtomicBool>,
                ) {
                    while let Some(reply_event) = sse_rx.recv().await {
                        if is_first.load(Ordering::SeqCst) {
                            let _ = tx.send(ResEvent::First(None));
                            is_first.store(false, Ordering::SeqCst)
                        }
                        match reply_event {
                            SseEvent::Text(text) => {
                                let _ = tx.send(ResEvent::Text(text));
                            }
                            SseEvent::Done => {
                                let _ = tx.send(ResEvent::Done);
                                sse_rx.close();
                            }
                        }
                    }
                }
                // Perform the provider call; falls back to a non-streaming
                // request when the model cannot stream.
                async fn chat_completions(
                    client: &dyn Client,
                    http_client: &reqwest::Client,
                    handler: &mut SseHandler,
                    mut data: ChatCompletionsData,
                    tx: &UnboundedSender<ResEvent>,
                    is_first: Arc<AtomicBool>,
                ) {
                    if client.model().no_stream() {
                        data.stream = false;
                        let ret = client.chat_completions_inner(http_client, data).await;
                        match ret {
                            Ok(output) => {
                                let ChatCompletionsOutput {
                                    text, tool_calls, ..
                                } = output;
                                let _ = tx.send(ResEvent::First(None));
                                is_first.store(false, Ordering::SeqCst);
                                let _ = tx.send(ResEvent::Text(text));
                                if !tool_calls.is_empty() {
                                    let _ = tx.send(ResEvent::ToolCalls(tool_calls));
                                }
                            }
                            Err(err) => {
                                let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
                                is_first.store(false, Ordering::SeqCst)
                            }
                        };
                    } else {
                        let ret = client
                            .chat_completions_streaming_inner(http_client, handler, data)
                            .await;
                        let first = match ret {
                            Ok(()) => None,
                            Err(err) => Some(format!("{err:?}")),
                        };
                        // Only send `First` if `map_event` has not already.
                        if is_first.load(Ordering::SeqCst) {
                            let _ = tx.send(ResEvent::First(first));
                            is_first.store(false, Ordering::SeqCst)
                        }
                        let tool_calls = handler.tool_calls().to_vec();
                        if !tool_calls.is_empty() {
                            let _ = tx.send(ResEvent::ToolCalls(tool_calls));
                        }
                    }
                    handler.done();
                }
                tokio::join!(
                    map_event(sse_rx, &tx, is_first.clone()),
                    chat_completions(
                        client.as_ref(),
                        &http_client,
                        &mut handler,
                        data,
                        &tx,
                        is_first
                    ),
                );
            });

            // Wait for the first event so immediate failures become an HTTP
            // error instead of a broken stream.
            let first_event = rx.recv().await;

            if let Some(ResEvent::First(Some(err))) = first_event {
                bail!("{err}");
            }

            // (completion_id, model_name, created, has_tool_calls) shared with
            // the stream closure.
            let shared: Arc<(String, String, i64, AtomicBool)> =
                Arc::new((completion_id, model_name, created, AtomicBool::new(false)));
            let stream = UnboundedReceiverStream::new(rx);
            let stream = stream.filter_map(move |res_event| {
                let shared = shared.clone();
                async move {
                    let (completion_id, model, created, has_tool_calls) = shared.as_ref();
                    match res_event {
                        ResEvent::Text(text) => {
                            Some(Ok(create_text_frame(completion_id, model, *created, &text)))
                        }
                        ResEvent::ToolCalls(tool_calls) => {
                            // Remembered so the final frame reports the right
                            // finish_reason.
                            has_tool_calls.store(true, Ordering::SeqCst);
                            Some(Ok(create_tool_calls_frame(
                                completion_id,
                                model,
                                *created,
                                &tool_calls,
                            )))
                        }
                        ResEvent::Done => Some(Ok(create_done_frame(
                            completion_id,
                            model,
                            *created,
                            has_tool_calls.load(Ordering::SeqCst),
                        ))),
                        _ => None,
                    }
                }
            });
            let res = Response::builder()
                .status(StatusCode::OK)
                .header("Content-Type", "text/event-stream")
                .header("Cache-Control", "no-cache")
                .header("Connection", "keep-alive")
                .body(BodyExt::boxed(StreamBody::new(stream)))?;
            Ok(res)
        } else {
            let output = client.chat_completions_inner(&http_client, data).await?;
            let res = Response::builder()
                .header("Content-Type", "application/json")
                .body(
                    Full::new(ret_non_stream(
                        &completion_id,
                        &model_name,
                        created,
                        &output,
                    ))
                    .boxed(),
                )?;
            Ok(res)
        }
    }
|
||||
|
||||
    /// POST /v1/embeddings — OpenAI-compatible embeddings endpoint.
    async fn embeddings(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
        let req_body = req.collect().await?.to_bytes();
        let req_body: Value = serde_json::from_slice(&req_body)
            .map_err(|err| anyhow!("Invalid request json, {err}"))?;

        debug!("embeddings request: {req_body}");
        let req_body = serde_json::from_value(req_body)
            .map_err(|err| anyhow!("Invalid request body, {err}"))?;

        let EmbeddingsReqBody {
            input,
            model: embedding_model_id,
        } = req_body;

        let config = Arc::new(RwLock::new(self.config.clone()));

        let embedding_model =
            Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;

        // Normalize single-string input into list form.
        let texts = match input {
            EmbeddingsReqBodyInput::Single(v) => vec![v],
            EmbeddingsReqBodyInput::Multiple(v) => v,
        };
        let client = init_client(&config, Some(embedding_model))?;
        let data = client
            .embeddings(&EmbeddingsData {
                query: false,
                texts,
            })
            .await?;
        let data: Vec<_> = data
            .into_iter()
            .enumerate()
            .map(|(i, v)| {
                json!({
                    "object": "embedding",
                    "embedding": v,
                    "index": i,
                })
            })
            .collect();
        // Token usage is not tracked here; zeros keep the shape compatible.
        let output = json!({
            "object": "list",
            "data": data,
            "model": embedding_model_id,
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            }
        });
        let res = Response::builder()
            .header("Content-Type", "application/json")
            .body(Full::new(Bytes::from(output.to_string())).boxed())?;
        Ok(res)
    }
|
||||
|
||||
    /// POST /v1/rerank — rerank `documents` against `query` and return scored
    /// results, echoing each original document next to its score.
    async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
        let req_body = req.collect().await?.to_bytes();
        let req_body: Value = serde_json::from_slice(&req_body)
            .map_err(|err| anyhow!("Invalid request json, {err}"))?;

        debug!("rerank request: {req_body}");
        let req_body = serde_json::from_value(req_body)
            .map_err(|err| anyhow!("Invalid request body, {err}"))?;

        let RerankReqBody {
            model: reranker_model_id,
            documents,
            query,
            top_n,
        } = req_body;

        // Default: score every document.
        let top_n = top_n.unwrap_or(documents.len());

        let config = Arc::new(RwLock::new(self.config.clone()));

        let reranker_model =
            Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?;

        let client = init_client(&config, Some(reranker_model))?;
        let data = client
            .rerank(&RerankData {
                query,
                documents: documents.clone(),
                top_n,
            })
            .await?;

        let results: Vec<_> = data
            .into_iter()
            .map(|v| {
                json!({
                    "index": v.index,
                    "relevance_score": v.relevance_score,
                    // Echo the original document; empty default if the index
                    // is out of range.
                    "document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
                })
            })
            .collect();
        let output = json!({
            "id": uuid::Uuid::new_v4().to_string(),
            "results": results,
        });
        let res = Response::builder()
            .header("Content-Type", "application/json")
            .body(Full::new(Bytes::from(output.to_string())).boxed())?;
        Ok(res)
    }
|
||||
}
|
||||
|
||||
/// Request body for POST /v1/rags/search.
#[derive(Debug, Deserialize)]
struct SearchRagReqBody {
    // Name of the RAG to search.
    name: String,
    // Query text.
    input: String,
}
|
||||
|
||||
/// Subset of the OpenAI chat-completions request body this server honors.
#[derive(Debug, Deserialize)]
struct ChatCompletionsReqBody {
    // Model id; "default" selects the configured model.
    model: String,
    // Raw message objects, validated later by `parse_messages`.
    messages: Vec<Value>,
    temperature: Option<f64>,
    top_p: Option<f64>,
    max_tokens: Option<isize>,
    // Defaults to false when omitted.
    #[serde(default)]
    stream: bool,
    // OpenAI-style tool definitions, if any.
    tools: Option<Vec<Value>>,
}
|
||||
|
||||
/// Request body for POST /v1/embeddings.
#[derive(Debug, Deserialize)]
struct EmbeddingsReqBody {
    // One string or a list of strings (see `EmbeddingsReqBodyInput`).
    input: EmbeddingsReqBodyInput,
    // Embedding model id.
    model: String,
}
|
||||
|
||||
/// `input` may be a single string or a list of strings, as in the OpenAI API.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EmbeddingsReqBodyInput {
    Single(String),
    Multiple(Vec<String>),
}
|
||||
|
||||
/// Request body for POST /v1/rerank.
#[derive(Debug, Deserialize)]
struct RerankReqBody {
    // Documents to score against the query.
    documents: Vec<String>,
    query: String,
    // Reranker model id.
    model: String,
    // How many results to return; defaults to all documents.
    top_n: Option<usize>,
}
|
||||
|
||||
/// Internal events bridging the provider stream to the SSE response stream.
#[derive(Debug)]
enum ResEvent {
    // First signal from the worker: None = started ok, Some(err) = failure.
    First(Option<String>),
    // A chunk of assistant text.
    Text(String),
    // Tool calls collected from the model output.
    ToolCalls(Vec<ToolCall>),
    // Stream finished.
    Done,
}
|
||||
|
||||
/// Future that completes when the process receives Ctrl-C; keeps the server
/// alive until interrupted.
async fn shutdown_signal() {
    tokio::signal::ctrl_c()
        .await
        // Installing the handler can only fail at startup; treat as fatal.
        .expect("Failed to install CTRL+C signal handler")
}
|
||||
|
||||
fn generate_completion_id() -> String {
|
||||
let random_id = Utc::now().nanosecond();
|
||||
format!("chatcmpl-{random_id}")
|
||||
}
|
||||
|
||||
fn set_cors_header(res: &mut AppResponse) {
|
||||
res.headers_mut().insert(
|
||||
hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
hyper::header::HeaderValue::from_static("*"),
|
||||
);
|
||||
res.headers_mut().insert(
|
||||
hyper::header::ACCESS_CONTROL_ALLOW_METHODS,
|
||||
hyper::header::HeaderValue::from_static("GET,POST,PUT,PATCH,DELETE"),
|
||||
);
|
||||
res.headers_mut().insert(
|
||||
hyper::header::ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
hyper::header::HeaderValue::from_static("Content-Type,Authorization"),
|
||||
);
|
||||
}
|
||||
|
||||
/// Build one SSE `data:` frame carrying a text delta chunk.
///
/// NOTE(review): the assistant `role` is attached only when `content` is
/// empty — presumably the first streamed chunk is an empty-content role
/// announcement; confirm against the emitting side.
fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame<Bytes> {
    let delta = if content.is_empty() {
        json!({ "role": "assistant", "content": content })
    } else {
        json!({ "content": content })
    };
    let choice = json!({
        "index": 0,
        "delta": delta,
        "finish_reason": null,
    });
    let value = build_chat_completion_chunk_json(id, model, created, &choice);
    Frame::data(Bytes::from(format!("data: {value}\n\n")))
}
|
||||
|
||||
/// Encode tool calls in the two-step OpenAI streaming shape: for each call,
/// first a chunk announcing it (id/name, empty arguments), then a chunk with
/// the full serialized arguments. All chunks are packed into a single frame.
fn create_tool_calls_frame(
    id: &str,
    model: &str,
    created: i64,
    tool_calls: &[ToolCall],
) -> Frame<Bytes> {
    let chunks = tool_calls
        .iter()
        .enumerate()
        .flat_map(|(i, call)| {
            // Chunk 1: declares tool call `i` with its id and function name.
            let choice1 = json!({
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [
                        {
                            "index": i,
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.name,
                                "arguments": ""
                            }
                        }
                    ]
                },
                "finish_reason": null
            });
            // Chunk 2: delivers the complete arguments string for call `i`.
            let choice2 = json!({
                "index": 0,
                "delta": {
                    "tool_calls": [
                        {
                            "index": i,
                            "function": {
                                "arguments": call.arguments.to_string(),
                            }
                        }
                    ]
                },
                "finish_reason": null
            });
            vec![
                build_chat_completion_chunk_json(id, model, created, &choice1),
                build_chat_completion_chunk_json(id, model, created, &choice2),
            ]
        })
        .map(|v| format!("data: {v}\n\n"))
        .collect::<Vec<String>>()
        .join("");
    Frame::data(Bytes::from(chunks))
}
|
||||
|
||||
fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame<Bytes> {
|
||||
let finish_reason = if has_tool_calls { "tool_calls" } else { "stop" };
|
||||
let choice = json!({
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": finish_reason,
|
||||
});
|
||||
let value = build_chat_completion_chunk_json(id, model, created, &choice);
|
||||
Frame::data(Bytes::from(format!("data: {value}\n\ndata: [DONE]\n\n")))
|
||||
}
|
||||
|
||||
/// Assemble one OpenAI `chat.completion.chunk` object wrapping `choice`.
fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value {
    json!({
        "id": id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [choice],
    })
}
|
||||
|
||||
/// Build the JSON body for a non-streaming chat completion response.
///
/// A provider-supplied response id takes precedence over the locally
/// generated `id`; token counts default to zero when the provider omits them.
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
    let id = output.id.as_deref().unwrap_or(id);
    let input_tokens = output.input_tokens.unwrap_or_default();
    let output_tokens = output.output_tokens.unwrap_or_default();
    let total_tokens = input_tokens + output_tokens;
    let choice = if output.tool_calls.is_empty() {
        // Plain text answer.
        json!({
            "index": 0,
            "message": {
                "role": "assistant",
                "content": output.text,
            },
            "logprobs": null,
            "finish_reason": "stop",
        })
    } else {
        // Tool-call answer: content is null when there is no text, matching
        // the OpenAI shape.
        let content = if output.text.is_empty() {
            Value::Null
        } else {
            output.text.clone().into()
        };
        let tool_calls: Vec<_> = output
            .tool_calls
            .iter()
            .map(|call| {
                json!({
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.name,
                        "arguments": call.arguments.to_string(),
                    }
                })
            })
            .collect();
        json!({
            "index": 0,
            "message": {
                "role": "assistant",
                "content": content,
                "tool_calls": tool_calls,
            },
            "logprobs": null,
            "finish_reason": "tool_calls",
        })
    };
    let res_body = json!({
        "id": id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [choice],
        "usage": {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": total_tokens,
        },
    });
    Bytes::from(res_body.to_string())
}
|
||||
|
||||
fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
|
||||
let data = json!({
|
||||
"error": {
|
||||
"message": err.to_string(),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
});
|
||||
Response::builder()
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Full::new(Bytes::from(data.to_string())).boxed())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Converts an OpenAI-style `messages` array into internal `Message`s.
///
/// An assistant message carrying `tool_calls` plus its following "tool"
/// result messages are buffered in `tool_results` and folded into a single
/// `MessageContent::ToolCalls` message once every call has a result.
/// Any structural problem fails with the index of the offending entry.
fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
    let mut output = vec![];
    // Pending (assistant_text, [(id, name, args)], [(result_value, id)])
    // while waiting for the matching "tool" messages.
    let mut tool_results = None;
    for (i, message) in message.into_iter().enumerate() {
        let err = || anyhow!("Failed to parse '.messages[{i}]'");
        let role = message["role"].as_str().ok_or_else(err)?;
        // "content" may be a string, an array of parts, null, or absent.
        let content = match message.get("content") {
            Some(value) => {
                if let Some(value) = value.as_str() {
                    MessageContent::Text(value.to_string())
                } else if value.is_array() {
                    let value = serde_json::from_value(value.clone()).map_err(|_| err())?;
                    MessageContent::Array(value)
                } else if value.is_null() {
                    MessageContent::Text(String::new())
                } else {
                    return Err(err());
                }
            }
            None => MessageContent::Text(String::new()),
        };
        match role {
            "system" | "user" => {
                let role = match role {
                    "system" => MessageRole::System,
                    "user" => MessageRole::User,
                    _ => unreachable!(),
                };
                output.push(Message::new(role, content))
            }
            "assistant" => {
                let role = MessageRole::Assistant;
                match message["tool_calls"].as_array() {
                    Some(tool_calls) => {
                        // A second tool-call batch before the first one's
                        // results arrived is malformed input.
                        if tool_results.is_some() {
                            return Err(err());
                        }
                        let mut list = vec![];
                        for tool_call in tool_calls {
                            // `id` is optional; `function.name` and
                            // `function.arguments` (a JSON string) are required.
                            if let (id, Some(name), Some(arguments)) = (
                                tool_call["id"].as_str().map(|v| v.to_string()),
                                tool_call["function"]["name"].as_str(),
                                tool_call["function"]["arguments"].as_str(),
                            ) {
                                let arguments =
                                    serde_json::from_str(arguments).map_err(|_| err())?;
                                list.push((id, name.to_string(), arguments));
                            } else {
                                return Err(err());
                            }
                        }
                        tool_results = Some((content.to_text(), list, vec![]));
                    }
                    None => output.push(Message::new(role, content)),
                }
            }
            "tool" => match tool_results.take() {
                Some((text, tool_calls, mut tool_values)) => {
                    let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string());
                    let content = content.to_text();
                    // Prefer a structured JSON result; fall back to raw text.
                    let value: Value = serde_json::from_str(&content)
                        .ok()
                        .unwrap_or_else(|| content.into());

                    tool_values.push((value, tool_call_id));

                    // Once each pending call has a result, pair them up
                    // positionally (ids must agree) and emit one message.
                    if tool_calls.len() == tool_values.len() {
                        let mut list = vec![];
                        for ((id, name, arguments), (value, tool_call_id)) in
                            tool_calls.into_iter().zip(tool_values.into_iter())
                        {
                            if id != tool_call_id {
                                return Err(err());
                            }
                            list.push(ToolResult::new(ToolCall::new(name, arguments, id), value))
                        }
                        output.push(Message::new(
                            MessageRole::Assistant,
                            MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)),
                        ));
                        tool_results = None;
                    } else {
                        tool_results = Some((text, tool_calls, tool_values));
                    }
                }
                None => return Err(err()),
            },
            _ => {
                return Err(err());
            }
        }
    }

    // Dangling tool calls with no (or too few) results.
    if tool_results.is_some() {
        bail!("Invalid messages");
    }

    Ok(output)
}
|
||||
|
||||
fn parse_tools(tools: Option<Vec<Value>>) -> Result<Option<Vec<FunctionDeclaration>>> {
|
||||
let tools = match tools {
|
||||
Some(v) => v,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let mut functions = vec![];
|
||||
for (i, tool) in tools.into_iter().enumerate() {
|
||||
if let (Some("function"), Some(function)) = (
|
||||
tool["type"].as_str(),
|
||||
tool["function"]
|
||||
.as_object()
|
||||
.and_then(|v| serde_json::from_value(json!(v)).ok()),
|
||||
) {
|
||||
functions.push(function);
|
||||
} else {
|
||||
bail!("Failed to parse '.tools[{i}]'")
|
||||
}
|
||||
}
|
||||
Ok(Some(functions))
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use anyhow::Result;
|
||||
use crossterm::event::{self, Event, KeyCode, KeyModifiers};
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
/// Shared, thread-safe abort handle.
pub type AbortSignal = Arc<AbortSignalInner>;

/// Records whether Ctrl-C or Ctrl-D has been pressed, each as its own
/// atomic flag so either can be queried independently.
pub struct AbortSignalInner {
    ctrlc: AtomicBool,
    ctrld: AtomicBool,
}

/// Convenience constructor for a fresh, un-aborted signal.
pub fn create_abort_signal() -> AbortSignal {
    AbortSignalInner::new()
}

impl AbortSignalInner {
    /// Creates a new signal wrapped in an `Arc`, with both flags cleared.
    pub fn new() -> AbortSignal {
        Arc::new(Self {
            ctrlc: AtomicBool::new(false),
            ctrld: AtomicBool::new(false),
        })
    }

    /// True when either abort flag has been raised.
    pub fn aborted(&self) -> bool {
        self.aborted_ctrlc() || self.aborted_ctrld()
    }

    /// True when Ctrl-C was recorded.
    pub fn aborted_ctrlc(&self) -> bool {
        self.ctrlc.load(Ordering::SeqCst)
    }

    /// True when Ctrl-D was recorded.
    pub fn aborted_ctrld(&self) -> bool {
        self.ctrld.load(Ordering::SeqCst)
    }

    /// Clears both flags so the signal can be reused.
    pub fn reset(&self) {
        self.ctrlc.store(false, Ordering::SeqCst);
        self.ctrld.store(false, Ordering::SeqCst);
    }

    /// Records a Ctrl-C press.
    pub fn set_ctrlc(&self) {
        self.ctrlc.store(true, Ordering::SeqCst);
    }

    /// Records a Ctrl-D press.
    pub fn set_ctrld(&self) {
        self.ctrld.store(true, Ordering::SeqCst);
    }
}
|
||||
|
||||
/// Asynchronously waits until the abort signal is raised, polling the
/// flags every 25 ms.
pub async fn wait_abort_signal(abort_signal: &AbortSignal) {
    loop {
        if abort_signal.aborted() {
            break;
        }
        // Async sleep keeps latency low without busy-spinning the runtime.
        tokio::time::sleep(Duration::from_millis(25)).await;
    }
}
|
||||
|
||||
pub fn poll_abort_signal(abort_signal: &AbortSignal) -> Result<bool> {
|
||||
if event::poll(Duration::from_millis(25))? {
|
||||
if let Event::Key(key) = event::read()? {
|
||||
match key.code {
|
||||
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
|
||||
abort_signal.set_ctrlc();
|
||||
return Ok(true);
|
||||
}
|
||||
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
|
||||
abort_signal.set_ctrld();
|
||||
return Ok(true);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
use anyhow::Context;
|
||||
|
||||
#[cfg(not(any(target_os = "android", target_os = "emscripten")))]
mod internal {
    use arboard::Clipboard;
    use base64::{engine::general_purpose::STANDARD, Engine as _};
    use std::sync::{LazyLock, Mutex};

    // Single process-wide clipboard handle; `None` when the platform
    // clipboard could not be initialized (e.g. headless Linux).
    static CLIPBOARD: LazyLock<Mutex<Option<Clipboard>>> =
        LazyLock::new(|| Mutex::new(Clipboard::new().ok()));

    /// Copies `text` to the system clipboard, falling back to the OSC52
    /// terminal escape sequence when no native clipboard is available.
    pub fn set_text(text: &str) -> anyhow::Result<()> {
        let mut clipboard = CLIPBOARD.lock().unwrap();
        match clipboard.as_mut() {
            Some(clipboard) => {
                clipboard.set_text(text)?;
                // NOTE(review): presumably gives the Linux clipboard
                // manager time to take ownership of the selection before
                // this process might exit — confirm the 50 ms is needed.
                #[cfg(target_os = "linux")]
                std::thread::sleep(std::time::Duration::from_millis(50));
                Ok(())
            }
            None => set_text_osc52(text),
        }
    }

    /// Attempts to set text to clipboard with OSC52 escape sequence
    /// Works in many modern terminals, including over SSH.
    fn set_text_osc52(text: &str) -> anyhow::Result<()> {
        let encoded = STANDARD.encode(text);
        // OSC 52, clipboard target "c", BEL-terminated.
        let seq = format!("\x1b]52;c;{encoded}\x07");
        if let Err(e) = std::io::Write::write_all(&mut std::io::stdout(), seq.as_bytes()) {
            return Err(anyhow::anyhow!("Failed to send OSC52 sequence").context(e));
        }
        if let Err(e) = std::io::Write::flush(&mut std::io::stdout()) {
            return Err(anyhow::anyhow!("Failed to flush OSC52 sequence").context(e));
        }
        Ok(())
    }
}

// Platforms with no clipboard support at all.
#[cfg(any(target_os = "android", target_os = "emscripten"))]
mod internal {
    pub fn set_text(_text: &str) -> anyhow::Result<()> {
        Err(anyhow::anyhow!("No clipboard available"))
    }
}

/// Public entry point: copies `text` to the clipboard (native or OSC52).
pub fn set_text(text: &str) -> anyhow::Result<()> {
    internal::set_text(text).context("Failed to copy")
}
|
||||
@@ -0,0 +1,242 @@
|
||||
use super::*;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
env,
|
||||
ffi::OsStr,
|
||||
fs::OpenOptions,
|
||||
io::{self, Write},
|
||||
path::{Path, PathBuf},
|
||||
process::Command,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use dirs::home_dir;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
pub static SHELL: LazyLock<Shell> = LazyLock::new(detect_shell);
|
||||
|
||||
/// Description of the user's shell: display name, executable path, and
/// the flag used to pass a command string (e.g. `-c`, `/C`, `-Command`).
pub struct Shell {
    pub name: String,
    pub cmd: String,
    pub arg: String,
}

impl Shell {
    /// Builds a `Shell` from borrowed parts, taking owned copies.
    pub fn new(name: &str, cmd: &str, arg: &str) -> Self {
        Self {
            name: name.to_owned(),
            cmd: cmd.to_owned(),
            arg: arg.to_owned(),
        }
    }
}
|
||||
|
||||
/// Determines the user's shell (name, executable, command flag).
///
/// Resolution order: the crate-specific `<CRATE>_SHELL` env var, then on
/// Windows a PowerShell heuristic based on `PSModulePath`, otherwise the
/// POSIX `SHELL` env var; falls back to `cmd.exe` / `/bin/sh`.
pub fn detect_shell() -> Shell {
    let cmd = env::var(get_env_name("shell")).ok().or_else(|| {
        if cfg!(windows) {
            // A user-scoped PSModulePath indicates the session was started
            // by PowerShell; the `\powershell\7\` component distinguishes
            // PowerShell 7 (pwsh) from Windows PowerShell.
            if let Ok(ps_module_path) = env::var("PSModulePath") {
                let ps_module_path = ps_module_path.to_lowercase();
                if ps_module_path.starts_with(r"c:\users") {
                    return if ps_module_path.contains(r"\powershell\7\") {
                        Some("pwsh.exe".to_string())
                    } else {
                        Some("powershell.exe".to_string())
                    };
                }
            }
            None
        } else {
            env::var("SHELL").ok()
        }
    });
    // Shell name = executable file stem, normalized ("nu" -> "nushell").
    let name = cmd
        .as_ref()
        .and_then(|v| Path::new(v).file_stem().and_then(|v| v.to_str()))
        .map(|v| {
            if v == "nu" {
                "nushell".into()
            } else {
                v.to_lowercase()
            }
        });
    let (cmd, name) = match (cmd.as_deref(), name.as_deref()) {
        (Some(cmd), Some(name)) => (cmd, name),
        _ => {
            // Platform default when nothing could be detected.
            if cfg!(windows) {
                ("cmd.exe", "cmd")
            } else {
                ("/bin/sh", "sh")
            }
        }
    };
    // Flag used to pass a command string to the shell.
    let shell_arg = match name {
        "powershell" => "-Command",
        "cmd" => "/C",
        _ => "-c",
    };
    Shell::new(name, cmd, shell_arg)
}
|
||||
|
||||
pub fn run_command<T: AsRef<OsStr>>(
|
||||
cmd: &str,
|
||||
args: &[T],
|
||||
envs: Option<HashMap<String, String>>,
|
||||
) -> Result<i32> {
|
||||
let status = Command::new(cmd)
|
||||
.args(args.iter())
|
||||
.envs(envs.unwrap_or_default())
|
||||
.status()?;
|
||||
Ok(status.code().unwrap_or_default())
|
||||
}
|
||||
|
||||
pub fn run_command_with_output<T: AsRef<OsStr>>(
|
||||
cmd: &str,
|
||||
args: &[T],
|
||||
envs: Option<HashMap<String, String>>,
|
||||
) -> Result<(bool, String, String)> {
|
||||
let output = Command::new(cmd)
|
||||
.args(args.iter())
|
||||
.envs(envs.unwrap_or_default())
|
||||
.output()?;
|
||||
let status = output.status;
|
||||
let stdout = std::str::from_utf8(&output.stdout).context("Invalid UTF-8 in stdout")?;
|
||||
let stderr = std::str::from_utf8(&output.stderr).context("Invalid UTF-8 in stderr")?;
|
||||
|
||||
if !status.success() {
|
||||
debug!("Command `{cmd}` exited with non-zero: {status}");
|
||||
}
|
||||
|
||||
if !stdout.is_empty() {
|
||||
debug!("Command `{cmd}` exited with non-zero. stderr: {stderr}");
|
||||
}
|
||||
|
||||
if !stderr.is_empty() {
|
||||
debug!("Command `{cmd}` executed successfully. stdout: {stdout}");
|
||||
}
|
||||
|
||||
Ok((status.success(), stdout.to_string(), stderr.to_string()))
|
||||
}
|
||||
|
||||
/// Runs a user-configured document-loader command for `path`.
///
/// `$1` in the command string is replaced by the input path; `$2` (if
/// present) by a generated temp output path, in which case the document is
/// read from that file instead of the command's stdout.
pub fn run_loader_command(path: &str, extension: &str, loader_command: &str) -> Result<String> {
    let cmd_args = shell_words::split(loader_command)
        .with_context(|| anyhow!("Invalid document loader '{extension}': `{loader_command}`"))?;
    let mut use_stdout = true;
    let outpath = temp_file("-output-", "").display().to_string();
    let cmd_args: Vec<_> = cmd_args
        .into_iter()
        .map(|mut v| {
            if v.contains("$1") {
                v = v.replace("$1", path);
            }
            if v.contains("$2") {
                // Loader writes to a file rather than stdout.
                use_stdout = false;
                v = v.replace("$2", &outpath);
            }
            v
        })
        .collect();
    let cmd_eval = shell_words::join(&cmd_args);
    debug!("run `{cmd_eval}`");
    // First token is the executable, the rest are its arguments.
    let (cmd, args) = cmd_args.split_at(1);
    let cmd = &cmd[0];
    if use_stdout {
        let (success, stdout, stderr) =
            run_command_with_output(cmd, args, None).with_context(|| {
                format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
            })?;
        if !success {
            // Prefer the loader's own stderr as the error message.
            let err = if !stderr.is_empty() {
                stderr
            } else {
                format!("The command `{cmd_eval}` exited with non-zero.")
            };
            bail!("{err}")
        }
        Ok(stdout)
    } else {
        let status = run_command(cmd, args, None).with_context(|| {
            format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
        })?;
        if status != 0 {
            bail!("The command `{cmd_eval}` exited with non-zero.")
        }
        let contents = std::fs::read_to_string(&outpath)
            .context("Failed to read file generated by the loader")?;
        Ok(contents)
    }
}
|
||||
|
||||
pub fn edit_file(editor: &str, path: &Path) -> Result<()> {
|
||||
let mut child = Command::new(editor).arg(path).spawn()?;
|
||||
child.wait()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Appends an executed command to the user's shell history file, using
/// that shell's native history entry format. Shells without a known
/// history file are silently skipped.
pub fn append_to_shell_history(shell: &str, command: &str, exit_code: i32) -> io::Result<()> {
    if let Some(history_file) = get_history_file(shell) {
        // History entries are single-line.
        let command = command.replace('\n', " ");
        let now = now_timestamp();
        let history_txt = if shell == "fish" {
            // fish uses a YAML-like entry.
            // NOTE(review): fish's own history indents "when:" with two
            // spaces — confirm this single-space form is accepted.
            format!("- cmd: {command}\n when: {now}")
        } else if shell == "zsh" {
            // zsh extended history: ": <start>:<elapsed>;<cmd>".
            // NOTE(review): exit_code occupies the "elapsed" field here —
            // confirm this is intended.
            format!(": {now}:{exit_code};{command}",)
        } else {
            command
        };
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&history_file)?;
        writeln!(file, "{history_txt}")?;
    }
    Ok(())
}
|
||||
|
||||
/// Maps a shell name to the location of its history file, honoring the
/// `HISTFILE` env var for bash/zsh. Returns `None` for unknown shells.
fn get_history_file(shell: &str) -> Option<PathBuf> {
    match shell {
        "bash" | "sh" => env::var("HISTFILE")
            .ok()
            .map(PathBuf::from)
            .or(Some(home_dir()?.join(".bash_history"))),
        "zsh" => env::var("HISTFILE")
            .ok()
            .map(PathBuf::from)
            .or(Some(home_dir()?.join(".zsh_history"))),
        "nushell" => Some(dirs::config_dir()?.join("nushell").join("history.txt")),
        "fish" => Some(
            home_dir()?
                .join(".local")
                .join("share")
                .join("fish")
                .join("fish_history"),
        ),
        // PSReadLine history; location differs by OS.
        "powershell" | "pwsh" => {
            #[cfg(not(windows))]
            {
                Some(
                    home_dir()?
                        .join(".local")
                        .join("share")
                        .join("powershell")
                        .join("PSReadLine")
                        .join("ConsoleHost_history.txt"),
                )
            }
            #[cfg(windows)]
            {
                Some(
                    dirs::data_dir()?
                        .join("Microsoft")
                        .join("Windows")
                        .join("PowerShell")
                        .join("PSReadLine")
                        .join("ConsoleHost_history.txt"),
                )
            }
        }
        "ksh" => Some(home_dir()?.join(".ksh_history")),
        "tcsh" => Some(home_dir()?.join(".history")),
        _ => None,
    }
}
|
||||
@@ -0,0 +1,35 @@
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
pub fn sha256(input: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(input);
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// HMAC-SHA256 of `msg` under `key`, returned as raw bytes.
pub fn hmac_sha256(key: &[u8], msg: &str) -> Vec<u8> {
    // HMAC accepts keys of any length, so this construction cannot fail.
    let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
    mac.update(msg.as_bytes());
    mac.finalize().into_bytes().to_vec()
}
|
||||
|
||||
/// Lowercase hex encoding of `bytes`.
pub fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    // Each byte renders as exactly two hex chars, so preallocate once.
    // (The previous fold built a fresh String per byte via `+ &format!`.)
    let mut out = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Writing into a String is infallible.
        let _ = write!(out, "{b:02x}");
    }
    out
}
|
||||
|
||||
pub fn encode_uri(uri: &str) -> String {
|
||||
uri.split('/')
|
||||
.map(|v| urlencoding::encode(v))
|
||||
.collect::<Vec<_>>()
|
||||
.join("/")
|
||||
}
|
||||
|
||||
/// Base64-encodes arbitrary bytes using the standard alphabet (padded).
pub fn base64_encode<T: AsRef<[u8]>>(input: T) -> String {
    STANDARD.encode(input)
}
/// Decodes standard-alphabet base64 back into raw bytes.
pub fn base64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
    STANDARD.decode(input)
}
|
||||
@@ -0,0 +1,18 @@
|
||||
use std::{cell::RefCell, rc::Rc};
|
||||
|
||||
use html_to_markdown::{markdown, TagHandler};
|
||||
|
||||
/// Converts an HTML fragment to Markdown using the standard handler set
/// (paragraphs, headings, lists, tables, inline styles, code blocks, and
/// a pass that strips common webpage chrome).
/// Falls back to returning the input unchanged if conversion fails.
pub fn html_to_md(html: &str) -> String {
    let mut handlers: Vec<TagHandler> = vec![
        Rc::new(RefCell::new(markdown::ParagraphHandler)),
        Rc::new(RefCell::new(markdown::HeadingHandler)),
        Rc::new(RefCell::new(markdown::ListHandler)),
        Rc::new(RefCell::new(markdown::TableHandler::new())),
        Rc::new(RefCell::new(markdown::StyledTextHandler)),
        Rc::new(RefCell::new(markdown::CodeHandler)),
        Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
    ];

    html_to_markdown::convert_html_to_markdown(html.as_bytes(), &mut handlers)
        .unwrap_or_else(|_| html.to_string())
}
|
||||
@@ -0,0 +1,47 @@
|
||||
use anyhow::Result;
|
||||
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
|
||||
use crossterm::terminal::{disable_raw_mode, enable_raw_mode};
|
||||
use std::io::{stdout, Write};
|
||||
|
||||
/// Reads a single character from stdin without requiring Enter
|
||||
/// Returns the character if it's one of the valid options, or the default if Enter is pressed
|
||||
/// Reads a single character from stdin without requiring Enter
/// Returns the character if it's one of the valid options, or the default if Enter is pressed
///
/// Ctrl-C aborts with an error. The terminal is switched into raw mode for
/// the duration of the read and restored afterwards.
/// NOTE(review): raw mode is not restored if a panic occurs between
/// enable/disable — confirm this is acceptable for the CLI.
pub fn read_single_key(valid_chars: &[char], default: char, prompt: &str) -> Result<char> {
    print!("{prompt}");
    stdout().flush()?;

    enable_raw_mode()?;

    let result = loop {
        // Read errors are ignored and the loop retries.
        if let Ok(Event::Key(KeyEvent {
            code, modifiers, ..
        })) = event::read()
        {
            match code {
                KeyCode::Char('c') if modifiers.contains(KeyModifiers::CONTROL) => {
                    break Err(anyhow::anyhow!("Interrupted"));
                }
                KeyCode::Char(c) => {
                    if valid_chars.contains(&c) {
                        break Ok(c);
                    }
                    // Invalid character, continue loop
                }
                KeyCode::Enter => {
                    break Ok(default);
                }
                _ => {
                    // Other keys are ignored, continue loop
                }
            }
        }
    };

    disable_raw_mode()?;

    // Print the chosen character and newline for clean output
    if let Ok(chosen) = &result {
        println!("{chosen}");
    }

    result
}
|
||||
@@ -0,0 +1,125 @@
|
||||
use super::*;
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use indexmap::IndexMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Metadata key under which a document's effective file extension is stored.
pub const EXTENSION_METADATA: &str = "__extension__";

/// Ordered key/value metadata attached to a loaded document.
pub type DocumentMetadata = IndexMap<String, String>;

/// A document obtained from disk, a URL, or a loader command, together
/// with its source path and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadedDocument {
    /// Source location (file path, URL, or "<protocol>:<path>").
    pub path: String,
    /// Full textual contents.
    pub contents: String,
    #[serde(default)]
    pub metadata: DocumentMetadata,
}

impl LoadedDocument {
    /// Bundles the parts into a `LoadedDocument`.
    pub fn new(path: String, contents: String, metadata: DocumentMetadata) -> Self {
        Self {
            path,
            contents,
            metadata,
        }
    }
}
|
||||
|
||||
/// Crawls a website (or delegates to a user-configured recursive-URL
/// loader command) and returns one `LoadedDocument` per page, each tagged
/// as Markdown via the extension metadata key.
pub async fn load_recursive_url(
    loaders: &HashMap<String, String>,
    path: &str,
) -> Result<Vec<LoadedDocument>> {
    let extension = RECURSIVE_URL_LOADER;
    let pages: Vec<Page> = match loaders.get(extension) {
        Some(loader_command) => {
            // External crawler: must print a JSON array of {path, text}.
            let contents = run_loader_command(path, extension, loader_command)?;
            serde_json::from_str(&contents).context(r#"The crawler response is invalid. It should follow the JSON format: `[{"path":"...", "text":"..."}]`."#)?
        }
        None => {
            // Built-in crawler with defaults derived from the start URL.
            let options = CrawlOptions::preset(path);
            crawl_website(path, options).await?
        }
    };
    let output = pages
        .into_iter()
        .map(|v| {
            let Page { path, text } = v;
            let mut metadata: DocumentMetadata = Default::default();
            metadata.insert(EXTENSION_METADATA.into(), "md".into());
            LoadedDocument::new(path, text, metadata)
        })
        .collect();
    Ok(output)
}
|
||||
|
||||
/// Loads a local file: uses a configured loader command for its extension
/// when one exists, otherwise reads it as plain UTF-8 text.
pub async fn load_file(loaders: &HashMap<String, String>, path: &str) -> Result<LoadedDocument> {
    let extension = get_patch_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into());
    match loaders.get(&extension) {
        Some(loader_command) => load_with_command(path, &extension, loader_command),
        None => load_plain(path, &extension).await,
    }
}

/// Fetches a URL (honoring URL loaders) and wraps the body in a
/// `LoadedDocument`, recording the detected extension in the metadata.
pub async fn load_url(loaders: &HashMap<String, String>, path: &str) -> Result<LoadedDocument> {
    let (contents, extension) = fetch_with_loaders(loaders, path, false).await?;
    let mut metadata: DocumentMetadata = Default::default();
    metadata.insert(EXTENSION_METADATA.into(), extension);
    Ok(LoadedDocument::new(path.into(), contents, metadata))
}
|
||||
|
||||
async fn load_plain(path: &str, extension: &str) -> Result<LoadedDocument> {
|
||||
let contents = tokio::fs::read_to_string(path).await?;
|
||||
let mut metadata: DocumentMetadata = Default::default();
|
||||
metadata.insert(EXTENSION_METADATA.into(), extension.to_string());
|
||||
Ok(LoadedDocument::new(path.into(), contents, metadata))
|
||||
}
|
||||
|
||||
/// Runs a loader command on `path`; the produced contents are tagged with
/// the default extension (the loader's output format).
fn load_with_command(path: &str, extension: &str, loader_command: &str) -> Result<LoadedDocument> {
    let contents = run_loader_command(path, extension, loader_command)?;
    let mut metadata: DocumentMetadata = Default::default();
    metadata.insert(EXTENSION_METADATA.into(), DEFAULT_EXTENSION.to_string());
    Ok(LoadedDocument::new(path.into(), contents, metadata))
}
|
||||
|
||||
/// True when `path` has a "<protocol>:" prefix with a configured loader.
pub fn is_loader_protocol(loaders: &HashMap<String, String>, path: &str) -> bool {
    path.split_once(':')
        .is_some_and(|(protocol, _)| loaders.contains_key(protocol))
}
|
||||
|
||||
/// Resolves a "<protocol>:<path>" input via its configured loader command.
///
/// If the loader prints a JSON array of documents, each document's path is
/// normalized to be fully qualified; otherwise the raw output becomes a
/// single document under the original input path.
pub fn load_protocol_path(
    loaders: &HashMap<String, String>,
    path: &str,
) -> Result<Vec<LoadedDocument>> {
    let (protocol, loader_command, new_path) = path
        .split_once(':')
        .and_then(|(protocol, path)| {
            let loader_command = loaders.get(protocol)?;
            Some((protocol, loader_command, path))
        })
        .ok_or_else(|| anyhow!("No document loader for '{}'", path))?;
    let contents = run_loader_command(new_path, protocol, loader_command)?;
    let output = if let Ok(list) = serde_json::from_str::<Vec<LoadedDocument>>(&contents) {
        list.into_iter()
            .map(|mut v| {
                if v.path.starts_with(path) {
                    // Already fully qualified ("protocol:...") — keep as-is.
                } else if v.path.starts_with(new_path) {
                    // Re-attach the protocol prefix.
                    v.path = format!("{}:{}", protocol, v.path);
                } else {
                    // Relative path: nest it under the original input.
                    v.path = format!("{}/{}", path, v.path);
                }
                v
            })
            .collect()
    } else {
        vec![LoadedDocument::new(
            path.into(),
            contents,
            Default::default(),
        )]
    };
    Ok(output)
}
|
||||
@@ -0,0 +1,63 @@
|
||||
use crate::config::Config;
|
||||
use colored::Colorize;
|
||||
use fancy_regex::Regex;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Seek, SeekFrom};
|
||||
use std::process;
|
||||
|
||||
pub async fn tail_logs(no_color: bool) {
|
||||
let re = Regex::new(r"^(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\s+<(?P<opid>[^\s>]+)>\s+\[(?P<level>[A-Z]+)\]\s+(?P<logger>[^:]+):(?P<line>\d+)\s+-\s+(?P<message>.*)$").unwrap();
|
||||
let file_path = Config::log_path();
|
||||
let file = File::open(&file_path).expect("Cannot open file");
|
||||
let mut reader = BufReader::new(file);
|
||||
|
||||
if let Err(e) = reader.seek(SeekFrom::End(0)) {
|
||||
eprintln!("Unable to tail log file: {e:?}");
|
||||
process::exit(1);
|
||||
};
|
||||
|
||||
let mut lines = reader.lines();
|
||||
|
||||
loop {
|
||||
if let Some(Ok(line)) = lines.next() {
|
||||
if no_color {
|
||||
println!("{line}");
|
||||
} else {
|
||||
let colored_line = colorize_log_line(&line, &re);
|
||||
println!("{colored_line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn colorize_log_line(line: &str, re: &Regex) -> String {
|
||||
if let Some(caps) = re.captures(line).expect("Failed to capture log line") {
|
||||
let level = &caps["level"];
|
||||
let message = &caps["message"];
|
||||
|
||||
let colored_message = match level {
|
||||
"ERROR" => message.red(),
|
||||
"WARN" => message.yellow(),
|
||||
"INFO" => message.green(),
|
||||
"DEBUG" => message.blue(),
|
||||
_ => message.normal(),
|
||||
};
|
||||
|
||||
let timestamp = &caps["timestamp"];
|
||||
let opid = &caps["opid"];
|
||||
let logger = &caps["logger"];
|
||||
let line_number = &caps["line"];
|
||||
|
||||
format!(
|
||||
"{} <{}> [{}] {}:{} - {}",
|
||||
timestamp.white(),
|
||||
opid.cyan(),
|
||||
level.bold(),
|
||||
logger.magenta(),
|
||||
line_number.bold(),
|
||||
colored_message
|
||||
)
|
||||
} else {
|
||||
line.to_string()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
mod abort_signal;
|
||||
mod clipboard;
|
||||
mod command;
|
||||
mod crypto;
|
||||
mod html_to_md;
|
||||
mod input;
|
||||
mod loader;
|
||||
mod logs;
|
||||
pub mod native;
|
||||
mod path;
|
||||
mod render_prompt;
|
||||
mod request;
|
||||
mod spinner;
|
||||
mod variables;
|
||||
|
||||
pub use self::abort_signal::*;
|
||||
pub use self::clipboard::set_text;
|
||||
pub use self::command::*;
|
||||
pub use self::crypto::*;
|
||||
pub use self::html_to_md::*;
|
||||
pub use self::input::*;
|
||||
pub use self::loader::*;
|
||||
pub use self::logs::*;
|
||||
pub use self::path::*;
|
||||
pub use self::render_prompt::render_prompt;
|
||||
pub use self::request::*;
|
||||
pub use self::spinner::*;
|
||||
pub use self::variables::*;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use fancy_regex::Regex;
|
||||
use fuzzy_matcher::{skim::SkimMatcherV2, FuzzyMatcher};
|
||||
use is_terminal::IsTerminal;
|
||||
use std::borrow::Cow;
|
||||
use std::sync::LazyLock;
|
||||
use std::{env, path::PathBuf, process};
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
pub static CODE_BLOCK_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"(?ms)```\w*(.*)```").unwrap());
|
||||
pub static THINK_TAG_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"(?s)^\s*<think>.*?</think>(\s*|$)").unwrap());
|
||||
pub static IS_STDOUT_TERMINAL: LazyLock<bool> = LazyLock::new(|| std::io::stdout().is_terminal());
|
||||
pub static NO_COLOR: LazyLock<bool> = LazyLock::new(|| {
|
||||
env::var("NO_COLOR")
|
||||
.ok()
|
||||
.and_then(|v| parse_bool(&v))
|
||||
.unwrap_or_default()
|
||||
|| !*IS_STDOUT_TERMINAL
|
||||
});
|
||||
|
||||
/// Current local time as an RFC 3339 string with seconds precision.
pub fn now() -> String {
    chrono::Local::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, false)
}

/// Current Unix timestamp in seconds.
pub fn now_timestamp() -> i64 {
    chrono::Local::now().timestamp()
}

/// Builds the crate-scoped env-var name for `key`, e.g. "<CRATE>_SHELL".
pub fn get_env_name(key: &str) -> String {
    format!("{}_{key}", env!("CARGO_CRATE_NAME"),).to_ascii_uppercase()
}

/// Normalizes a name into env-var form: dashes to underscores, uppercase.
pub fn normalize_env_name(value: &str) -> String {
    value.replace('-', "_").to_ascii_uppercase()
}

/// Parses the common boolean spellings; returns `None` for anything else.
pub fn parse_bool(value: &str) -> Option<bool> {
    match value {
        "1" | "true" => Some(true),
        "0" | "false" => Some(false),
        _ => None,
    }
}
|
||||
|
||||
pub fn estimate_token_length(text: &str) -> usize {
|
||||
let words: Vec<&str> = text.unicode_words().collect();
|
||||
let mut output: f32 = 0.0;
|
||||
for word in words {
|
||||
if word.is_ascii() {
|
||||
output += 1.3;
|
||||
} else {
|
||||
let count = word.chars().count();
|
||||
if count == 1 {
|
||||
output += 1.0
|
||||
} else {
|
||||
output += (count as f32) * 0.5;
|
||||
}
|
||||
}
|
||||
}
|
||||
output.ceil() as usize
|
||||
}
|
||||
|
||||
/// Removes a leading `<think>...</think>` block (reasoning-model output)
/// from `text`, borrowing the input when nothing matches.
pub fn strip_think_tag(text: &str) -> Cow<'_, str> {
    THINK_TAG_RE.replace_all(text, "")
}

/// Returns the trimmed contents of the first fenced code block in `text`,
/// or the whole text when no fence is found (or the regex engine errors).
pub fn extract_code_block(text: &str) -> &str {
    CODE_BLOCK_RE
        .captures(text)
        .ok()
        .and_then(|v| v?.get(1).map(|v| v.as_str().trim()))
        .unwrap_or(text)
}
|
||||
|
||||
/// Maps an empty string to `None`, anything else to an owned `Some`.
pub fn convert_option_string(value: &str) -> Option<String> {
    (!value.is_empty()).then(|| value.to_string())
}
|
||||
|
||||
/// Fuzzy-filters `values` by matching `pattern` against the string that
/// `get` extracts from each item, returning survivors sorted by
/// descending match score (SkimV2 scoring).
pub fn fuzzy_filter<T, F>(values: Vec<T>, get: F, pattern: &str) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let matcher = SkimMatcherV2::default();
    let mut list: Vec<(T, i64)> = values
        .into_iter()
        .filter_map(|v| {
            // Non-matching items are dropped entirely.
            let score = matcher.fuzzy_match(get(&v), pattern)?;
            Some((v, score))
        })
        .collect();
    // Highest score first.
    list.sort_unstable_by(|a, b| b.1.cmp(&a.1));
    list.into_iter().map(|(v, _)| v).collect()
}
|
||||
|
||||
/// Formats an `anyhow::Error` and its cause chain as
/// "Error: ...\n\nCaused by: ..." with indexed, indented causes.
pub fn pretty_error(err: &anyhow::Error) -> String {
    let mut output = vec![];
    output.push(format!("Error: {err}"));
    // Skip the top-level error itself; the remainder are its causes.
    let causes: Vec<_> = err.chain().skip(1).collect();
    let causes_len = causes.len();
    if causes_len > 0 {
        output.push("\nCaused by:".to_string());
        if causes_len == 1 {
            // Single cause: no index; continuation lines indented 4.
            output.push(format!("    {}", indent_text(causes[0], 4).trim()));
        } else {
            // Multiple causes: right-aligned index, continuation indent 7.
            for (i, cause) in causes.into_iter().enumerate() {
                output.push(format!("{i:5}: {}", indent_text(cause, 7).trim()));
            }
        }
    }
    output.join("\n")
}
|
||||
|
||||
/// Prefixes every line of `s` (split on `\n`) with `size` spaces.
pub fn indent_text<T: ToString>(s: T, size: usize) -> String {
    let pad = " ".repeat(size);
    let source = s.to_string();
    let mut indented = Vec::new();
    for line in source.split('\n') {
        indented.push(format!("{pad}{line}"));
    }
    indented.join("\n")
}
|
||||
|
||||
/// Red-colored text for error output (plain when colors are disabled).
pub fn error_text(input: &str) -> String {
    color_text(input, nu_ansi_term::Color::Red)
}

/// Yellow-colored text for warnings (plain when colors are disabled).
pub fn warning_text(input: &str) -> String {
    color_text(input, nu_ansi_term::Color::Yellow)
}

/// Wraps `input` in the given ANSI color, unless NO_COLOR is set or
/// stdout is not a terminal.
pub fn color_text(input: &str, color: nu_ansi_term::Color) -> String {
    if *NO_COLOR {
        return input.to_string();
    }
    nu_ansi_term::Style::new()
        .fg(color)
        .paint(input)
        .to_string()
}

/// Dimmed text, honoring the same NO_COLOR setting.
pub fn dimmed_text(input: &str) -> String {
    if *NO_COLOR {
        return input.to_string();
    }
    nu_ansi_term::Style::new().dimmed().paint(input).to_string()
}
|
||||
|
||||
/// Formats multi-line input for REPL display: the first line is kept
/// as-is and every following line is prefixed with ".. ".
pub fn multiline_text(input: &str) -> String {
    let mut out = String::new();
    for (i, line) in input.split('\n').enumerate() {
        if i > 0 {
            out.push_str("\n.. ");
        }
        out.push_str(line);
    }
    out
}
|
||||
|
||||
/// Builds a unique path in the OS temp directory shaped like
/// `<crate>-<pid>{prefix}<uuid>{suffix}`. The file itself is not created.
pub fn temp_file(prefix: &str, suffix: &str) -> PathBuf {
    env::temp_dir().join(format!(
        "{}-{}{prefix}{}{suffix}",
        env!("CARGO_CRATE_NAME").to_lowercase(),
        process::id(),
        uuid::Uuid::new_v4()
    ))
}
|
||||
|
||||
/// True when `path` looks like an http(s) URL.
pub fn is_url(path: &str) -> bool {
    ["http://", "https://"]
        .iter()
        .any(|scheme| path.starts_with(scheme))
}
|
||||
|
||||
/// Applies a proxy setting to a reqwest client builder.
/// An empty string or `-` disables all proxying (including env proxies).
pub fn set_proxy(
    mut builder: reqwest::ClientBuilder,
    proxy: &str,
) -> Result<reqwest::ClientBuilder> {
    // Clear environment-derived proxies first so the explicit value wins.
    builder = builder.no_proxy();
    if !proxy.is_empty() && proxy != "-" {
        builder = builder
            .proxy(reqwest::Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
    };
    Ok(builder)
}
|
||||
|
||||
/// Deserializes a value from bincode bytes using the legacy configuration.
pub fn decode_bin<T: serde::de::DeserializeOwned>(data: &[u8]) -> Result<T> {
    // decode_from_slice also returns the consumed byte count; unused here.
    let (v, _) = bincode::serde::decode_from_slice(data, bincode::config::legacy())?;
    Ok(v)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // safe_join_path must confine the joined result inside the base dir:
    // absolute components and `..` traversal are rejected.
    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_safe_join_path() {
        assert_eq!(
            safe_join_path("/home/user/dir1", "files/file1"),
            Some(PathBuf::from("/home/user/dir1/files/file1"))
        );
        assert!(safe_join_path("/home/user/dir1", "/files/file1").is_none());
        assert!(safe_join_path("/home/user/dir1", "../file1").is_none());
    }

    // Windows variant of the same checks, with backslash separators.
    #[test]
    #[cfg(target_os = "windows")]
    fn test_safe_join_path() {
        assert_eq!(
            safe_join_path("C:\\Users\\user\\dir1", "files/file1"),
            Some(PathBuf::from("C:\\Users\\user\\dir1\\files\\file1"))
        );
        assert!(safe_join_path("C:\\Users\\user\\dir1", "/files/file1").is_none());
        assert!(safe_join_path("C:\\Users\\user\\dir1", "../file1").is_none());
    }
}
|
||||
@@ -0,0 +1,46 @@
|
||||
#[cfg(windows)]
pub mod runtime {
    use std::path::Path;

    /// Locate a bash executable on Windows: prefer the standard Git for
    /// Windows install location, then probe relative to `git` on PATH
    /// (`<git>/../../bin/bash.exe`, then `<git dir>/bash.exe`).
    pub fn bash_path() -> Option<String> {
        const DEFAULT_BASH: &str = "C:\\Program Files\\Git\\bin\\bash.exe";
        if exist_path(DEFAULT_BASH) {
            return Some(DEFAULT_BASH.into());
        }
        let git_path = which("git")?;
        let git_parent_path = parent_path(&git_path)?;
        let candidates = [
            join_path(&parent_path(&git_parent_path)?, &["bin", "bash.exe"]),
            join_path(&git_parent_path, &["bash.exe"]),
        ];
        candidates.into_iter().find(|path| exist_path(path))
    }

    /// Whether `path` exists on disk.
    fn exist_path(path: &str) -> bool {
        Path::new(path).exists()
    }

    /// Resolve an executable name via PATH, returning its full path string.
    pub fn which(name: &str) -> Option<String> {
        let path = which::which(name).ok()?;
        Some(path.to_string_lossy().into())
    }

    /// The parent directory of `path` as a string, if any.
    fn parent_path(path: &str) -> Option<String> {
        let parent = Path::new(path).parent()?;
        Some(parent.to_string_lossy().into())
    }

    /// Join `parts` onto `path` using platform separators.
    fn join_path(path: &str, parts: &[&str]) -> String {
        parts
            .iter()
            .fold(Path::new(path).to_path_buf(), |acc, part| acc.join(part))
            .to_string_lossy()
            .into()
    }
}
|
||||
@@ -0,0 +1,356 @@
|
||||
use std::fs;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use fancy_regex::Regex;
|
||||
use indexmap::IndexSet;
|
||||
use path_absolutize::Absolutize;
|
||||
|
||||
type ParseGlobResult = (String, Option<Vec<String>>, bool, Option<usize>);
|
||||
|
||||
/// Join `sub_path` onto `base_path`, refusing absolute sub-paths and any
/// `..` component, so the result can never escape `base_path`.
pub fn safe_join_path<T1: AsRef<Path>, T2: AsRef<Path>>(
    base_path: T1,
    sub_path: T2,
) -> Option<PathBuf> {
    let base_path = base_path.as_ref();
    let sub_path = sub_path.as_ref();
    if sub_path.is_absolute() {
        return None;
    }

    let mut joined_path = base_path.to_path_buf();
    for component in sub_path.components() {
        match component {
            Component::ParentDir => return None,
            other => joined_path.push(other),
        }
    }

    // Defense in depth: keep the result only if it is still under the base.
    joined_path.starts_with(base_path).then_some(joined_path)
}
|
||||
|
||||
pub async fn expand_glob_paths<T: AsRef<str>>(
|
||||
paths: &[T],
|
||||
bail_non_exist: bool,
|
||||
) -> Result<IndexSet<String>> {
|
||||
let mut new_paths = IndexSet::new();
|
||||
for path in paths {
|
||||
let (path_str, suffixes, current_only, depth) = parse_glob(path.as_ref())?;
|
||||
list_files(
|
||||
&mut new_paths,
|
||||
Path::new(&path_str),
|
||||
suffixes.as_ref(),
|
||||
current_only,
|
||||
bail_non_exist,
|
||||
depth,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Ok(new_paths)
|
||||
}
|
||||
|
||||
pub fn clear_dir(dir: &Path) -> Result<()> {
|
||||
for entry in fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.is_dir() {
|
||||
fs::remove_dir_all(&path)?;
|
||||
} else {
|
||||
fs::remove_file(&path)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List the names of files in `dir` that end with `ext`, with the `ext`
/// suffix stripped, sorted ascending. An unreadable dir yields an empty list.
pub fn list_file_names<T: AsRef<Path>>(dir: T, ext: &str) -> Vec<String> {
    let Ok(rd) = fs::read_dir(dir.as_ref()) else {
        return vec![];
    };
    let mut names: Vec<String> = rd
        .flatten()
        .filter_map(|entry| {
            let name = entry.file_name();
            let name = name.to_string_lossy();
            name.strip_suffix(ext).map(str::to_string)
        })
        .collect();
    names.sort_unstable();
    names
}
|
||||
|
||||
/// The lowercased file extension of `path`, if it has one.
pub fn get_patch_extension(path: &str) -> Option<String> {
    let extension = Path::new(path).extension()?;
    Some(extension.to_string_lossy().to_lowercase())
}
|
||||
|
||||
pub fn to_absolute_path(path: &str) -> Result<String> {
|
||||
Ok(Path::new(&path).absolutize()?.display().to_string())
|
||||
}
|
||||
|
||||
pub fn resolve_home_dir(path: &str) -> String {
|
||||
let mut path = path.to_string();
|
||||
if path.starts_with("~/") || path.starts_with("~\\") {
|
||||
if let Some(home_dir) = dirs::home_dir() {
|
||||
path.replace_range(..1, &home_dir.display().to_string());
|
||||
}
|
||||
}
|
||||
path
|
||||
}
|
||||
|
||||
/// Split a path/glob string into `(base_path, extensions, current_only, depth)`
/// (see `ParseGlobResult`):
/// - `base_path`: the directory (or plain path) to start listing from
/// - `extensions`: `None` = accept everything; otherwise bare extensions
///   (e.g. "md") or whole file names (e.g. "test.md") to match
/// - `current_only`: `true` for single-star globs (`dir/*.md`) — no recursion
/// - `depth`: `Some(1)` for single-subdir globs (`dir/*/file.md`)
fn parse_glob(path_str: &str) -> Result<ParseGlobResult> {
    // e.g. `dir/*/test.md` — exactly one wildcard directory level.
    let globbed_single_subdir_regex = Regex::new(r"\*/[^/]+\.[^/]+$").expect("invalid regex");
    // e.g. `dir/**/test.md` — any depth, then a specific file name.
    let globbed_recursive_subdir_regex = Regex::new(r"\*\*/[^/]+\.[^/]+$").expect("invalid regex");
    // Tuple: (index where the glob starts, byte offset from that index to the
    // extension text, current_only, depth). `None` means "no glob detected".
    let glob_result =
        if let Some(start) = path_str.find("/**/*.").or_else(|| path_str.find(r"\**\*.")) {
            Some((start, 6, false, None))
        } else if let Some(start) = path_str.find("**/*.").or_else(|| path_str.find(r"**\*.")) {
            // Only valid at the very beginning (`**/*.md`).
            if start == 0 {
                Some((start, 5, false, None))
            } else {
                None
            }
        } else if let Some(m) = globbed_recursive_subdir_regex.find(path_str)? {
            Some((m.start(), 3, false, None))
        } else if let Some(m) = globbed_single_subdir_regex.find(path_str)? {
            Some((m.start(), 2, false, Some(1usize)))
        } else if let Some(start) = path_str.find("/*.").or_else(|| path_str.find(r"\*.")) {
            Some((start, 3, true, None))
        } else if let Some(start) = path_str.find("*.") {
            // Only valid at the very beginning (`*.md`).
            if start == 0 {
                Some((start, 2, true, None))
            } else {
                None
            }
        } else {
            None
        };
    if let Some((start, offset, current_only, depth)) = glob_result {
        let mut base_path = path_str[..start].to_string();
        // A glob at position 0 has no base; use "/" or "." depending on
        // whether the input was rooted.
        if base_path.is_empty() {
            base_path = if path_str
                .chars()
                .next()
                .map(|v| v == '/')
                .unwrap_or_default()
            {
                "/"
            } else {
                "."
            }
            .into();
        }

        // `{md,txt}` brace lists expand to multiple extensions.
        let extensions = if let Some(curly_brace_end) = path_str[start..].find('}') {
            let end = start + curly_brace_end;
            // Slice covers the braces themselves, which are validated below.
            let extensions_str = &path_str[start + offset..end + 1];
            if extensions_str.starts_with('{') && extensions_str.ends_with('}') {
                extensions_str[1..extensions_str.len() - 1]
                    .split(',')
                    .map(|s| s.to_string())
                    .collect::<Vec<String>>()
            } else {
                bail!("Invalid path '{path_str}'");
            }
        } else {
            // Single extension (or file name) after the glob marker.
            let extensions_str = &path_str[start + offset..];
            vec![extensions_str.to_string()]
        };
        let extensions = if extensions.is_empty() {
            None
        } else {
            Some(extensions)
        };
        Ok((base_path, extensions, current_only, depth))
    } else if path_str.ends_with("/**") || path_str.ends_with(r"\**") {
        // Bare recursive glob: list everything under the directory.
        Ok((
            path_str[0..path_str.len() - 3].to_string(),
            None,
            false,
            None,
        ))
    } else {
        // No glob at all: treat the input as a literal path.
        Ok((path_str.to_string(), None, false, None))
    }
}
|
||||
|
||||
/// Recursively collect files under `entry_path` into `files`.
///
/// - `suffixes`: optional extension/file-name filter (see `is_valid_extension`)
/// - `current_only`: do not descend into subdirectories at all
/// - `bail_non_exist`: error on a missing `entry_path` instead of skipping it
/// - `depth`: remaining directory levels to descend; `None` = unlimited
#[async_recursion::async_recursion]
async fn list_files(
    files: &mut IndexSet<String>,
    entry_path: &Path,
    suffixes: Option<&Vec<String>>,
    current_only: bool,
    bail_non_exist: bool,
    depth: Option<usize>,
) -> Result<()> {
    if !entry_path.exists() {
        if bail_non_exist {
            bail!("Not found '{}'", entry_path.display());
        } else {
            // Silently skip missing paths when the caller tolerates them.
            return Ok(());
        }
    }
    if entry_path.is_dir() {
        let mut reader = tokio::fs::read_dir(entry_path).await?;
        while let Some(entry) = reader.next_entry().await? {
            let path = entry.path();
            if path.is_dir() {
                if !current_only {
                    if let Some(remaining_depth) = depth {
                        // Bounded recursion: stop once the depth budget is spent.
                        if remaining_depth > 0 {
                            list_files(
                                files,
                                &path,
                                suffixes,
                                current_only,
                                bail_non_exist,
                                Some(remaining_depth - 1),
                            )
                            .await?;
                        }
                    } else {
                        // Unbounded recursion.
                        list_files(files, &path, suffixes, current_only, bail_non_exist, None)
                            .await?;
                    }
                }
            } else {
                add_file(files, suffixes, &path);
            }
        }
    } else {
        // `entry_path` itself is a file.
        add_file(files, suffixes, entry_path);
    }
    Ok(())
}
|
||||
|
||||
fn add_file(files: &mut IndexSet<String>, suffixes: Option<&Vec<String>>, path: &Path) {
|
||||
if is_valid_extension(suffixes, path) {
|
||||
let path = path.display().to_string();
|
||||
if !files.contains(&path) {
|
||||
files.insert(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether `path` passes the suffix filter produced by `parse_glob`.
///
/// `None` or an empty list accepts every file. A suffix containing a dot
/// (e.g. "test.md", from globs like `dir/**/test.md`) must equal the whole
/// file name; a bare suffix (e.g. "md", from `*.md`) must equal the file's
/// extension.
///
/// Fix: the previous check probed `suffixes.join(",")` with the regex
/// `^.+\.*`, which matches ANY non-empty string (`\.*` is zero-or-more
/// literal dots). Bare extensions were therefore always compared against
/// full file names and never matched, so globs like `*.md` selected
/// nothing. Testing each suffix for a literal '.' restores the intended
/// behavior, drops the per-call regex compilation, and avoids panicking on
/// non-UTF-8 file names.
fn is_valid_extension(suffixes: Option<&Vec<String>>, path: &Path) -> bool {
    let Some(suffixes) = suffixes else {
        return true;
    };
    if suffixes.is_empty() {
        return true;
    }
    let file_name = path.file_name().and_then(|v| v.to_str());
    let extension = path.extension().and_then(|v| v.to_str());
    suffixes.iter().any(|suffix| {
        if suffix.contains('.') {
            // Whole-file-name suffix.
            file_name == Some(suffix.as_str())
        } else {
            // Bare extension suffix.
            extension == Some(suffix.as_str())
        }
    })
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_glob() {
        // (input, (base_path, extensions, current_only, depth))
        let cases: Vec<(&str, ParseGlobResult)> = vec![
            ("dir", ("dir".into(), None, false, None)),
            ("dir/**", ("dir".into(), None, false, None)),
            ("dir/file.md", ("dir/file.md".into(), None, false, None)),
            ("**/*.md", (".".into(), Some(vec!["md".into()]), false, None)),
            ("/**/*.md", ("/".into(), Some(vec!["md".into()]), false, None)),
            ("dir/**/*.md", ("dir".into(), Some(vec!["md".into()]), false, None)),
            ("dir/**/test.md", ("dir/".into(), Some(vec!["test.md".into()]), false, None)),
            ("dir/*/test.md", ("dir/".into(), Some(vec!["test.md".into()]), false, Some(1))),
            (
                "dir/**/*.{md,txt}",
                ("dir".into(), Some(vec!["md".into(), "txt".into()]), false, None),
            ),
            (
                "C:\\dir\\**\\*.{md,txt}",
                ("C:\\dir".into(), Some(vec!["md".into(), "txt".into()]), false, None),
            ),
            ("*.md", (".".into(), Some(vec!["md".into()]), true, None)),
            ("/*.md", ("/".into(), Some(vec!["md".into()]), true, None)),
            ("dir/*.md", ("dir".into(), Some(vec!["md".into()]), true, None)),
            (
                "dir/*.{md,txt}",
                ("dir".into(), Some(vec!["md".into(), "txt".into()]), true, None),
            ),
            (
                "C:\\dir\\*.{md,txt}",
                ("C:\\dir".into(), Some(vec!["md".into(), "txt".into()]), true, None),
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_glob(input).unwrap(), expected, "input: {input}");
        }
    }
}
|
||||
@@ -0,0 +1,155 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Render REPL prompt
|
||||
///
|
||||
/// The template comprises plain text and `{...}`.
|
||||
///
|
||||
/// The syntax of `{...}`:
|
||||
/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`
|
||||
/// - `{?var <template>}` - Eval `template` when `var` is evaluated as true
|
||||
/// - `{!var <template>}` - Eval `template` when `var` is evaluated as false
|
||||
pub fn render_prompt(template: &str, variables: &HashMap<&str, String>) -> String {
|
||||
let exprs = parse_template(template);
|
||||
eval_exprs(&exprs, variables)
|
||||
}
|
||||
|
||||
/// Split `template` into a flat list of expressions: literal text runs and
/// top-level `{...}` blocks. Braces nested inside a block are kept verbatim
/// in its body, which `parse_block` re-parses recursively.
fn parse_template(template: &str) -> Vec<Expr> {
    let chars: Vec<char> = template.chars().collect();
    let mut exprs = vec![];
    // Characters accumulated for the current text run or block body.
    let mut current = vec![];
    // Stack of unmatched '{' — non-empty means we are inside a block.
    let mut balances = vec![];
    for ch in chars.iter().cloned() {
        if !balances.is_empty() {
            // Inside a block: track nesting, keep inner braces as-is.
            if ch == '}' {
                balances.pop();
                if balances.is_empty() {
                    // Top-level block closed: interpret its accumulated body.
                    if !current.is_empty() {
                        let block = parse_block(&mut current);
                        exprs.push(block)
                    }
                } else {
                    current.push(ch);
                }
            } else if ch == '{' {
                balances.push(ch);
                current.push(ch);
            } else {
                current.push(ch);
            }
        } else if ch == '{' {
            balances.push(ch);
            // Flush any pending literal text before the block starts.
            add_text(&mut exprs, &mut current);
        } else {
            current.push(ch)
        }
    }
    // Flush trailing text; an unterminated block's leftover body is also
    // emitted here as plain text.
    add_text(&mut exprs, &mut current);
    exprs
}
|
||||
|
||||
/// Interpret the body of one `{...}` block (braces already stripped),
/// draining `current` so the caller's buffer is reset.
///
/// - `?name tail` => conditional sub-template rendered when `name` is truthy
/// - `!name tail` => conditional sub-template rendered when `name` is falsy
/// - `name tail` (no `?`/`!` prefix) => not a block; re-emitted as literal
///   text with the braces restored
/// - `name` (no space) => variable substitution
fn parse_block(current: &mut Vec<char>) -> Expr {
    let value: String = current.drain(..).collect();
    match value.split_once(' ') {
        Some((name, tail)) => {
            if let Some(name) = name.strip_prefix('?') {
                let block_exprs = parse_template(tail);
                Expr::Block(BlockType::Yes, name.to_string(), block_exprs)
            } else if let Some(name) = name.strip_prefix('!') {
                let block_exprs = parse_template(tail);
                Expr::Block(BlockType::No, name.to_string(), block_exprs)
            } else {
                // Unrecognized form: emit verbatim, braces included.
                Expr::Text(format!("{{{value}}}"))
            }
        }
        None => Expr::Variable(value),
    }
}
|
||||
|
||||
/// Evaluate a parsed template against `variables`, concatenating the output.
/// Unset variables render as the empty string; conditional blocks recurse.
fn eval_exprs(exprs: &[Expr], variables: &HashMap<&str, String>) -> String {
    let mut output = String::new();
    for part in exprs {
        match part {
            Expr::Text(text) => output.push_str(text),
            Expr::Variable(variable) => {
                // Missing variables substitute as "".
                let value = variables
                    .get(variable.as_str())
                    .cloned()
                    .unwrap_or_default();
                output.push_str(&value);
            }
            Expr::Block(typ, variable, block_exprs) => {
                let value = variables
                    .get(variable.as_str())
                    .cloned()
                    .unwrap_or_default();
                match typ {
                    // `{?var ...}`: render only when the value is truthy.
                    BlockType::Yes => {
                        if truly(&value) {
                            let block_output = eval_exprs(block_exprs, variables);
                            output.push_str(&block_output)
                        }
                    }
                    // `{!var ...}`: render only when the value is falsy.
                    BlockType::No => {
                        if !truly(&value) {
                            let block_output = eval_exprs(block_exprs, variables);
                            output.push_str(&block_output)
                        }
                    }
                }
            }
        }
    }
    output
}
|
||||
|
||||
fn add_text(exprs: &mut Vec<Expr>, current: &mut Vec<char>) {
|
||||
if current.is_empty() {
|
||||
return;
|
||||
}
|
||||
let value: String = current.drain(..).collect();
|
||||
exprs.push(Expr::Text(value));
|
||||
}
|
||||
|
||||
/// Truthiness rule for prompt variables: empty, "0" and "false" are falsy.
fn truly(value: &str) -> bool {
    !matches!(value, "" | "0" | "false")
}
|
||||
|
||||
/// One node of a parsed prompt template.
#[derive(Debug)]
enum Expr {
    /// Literal text emitted as-is.
    Text(String),
    /// `{name}`: substitute the variable's value (empty when unset).
    Variable(String),
    /// `{?name ...}` / `{!name ...}`: a conditional sub-template.
    Block(BlockType, String, Vec<Expr>),
}
|
||||
|
||||
/// Polarity of a conditional template block.
#[derive(Debug)]
enum BlockType {
    /// `{?var ...}` — render when the variable is truthy.
    Yes,
    /// `{!var ...}` — render when the variable is falsy.
    No,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Renders `$template` with the given (key, value) pairs and compares
    // against `$expect`.
    macro_rules! assert_render {
        ($template:expr, [$(($key:literal, $value:literal),)*], $expect:literal) => {
            let data = HashMap::from([
                $(($key, $value.into()),)*
            ]);
            assert_eq!(render_prompt($template, &data), $expect);
        };
    }

    #[test]
    fn test_render() {
        // Exercises variable substitution plus both conditional polarities.
        let prompt = "{?session {session}{?role /}}{role}{?session )}{!session >}";
        assert_render!(prompt, [], ">");
        assert_render!(prompt, [("role", "coder"),], "coder>");
        assert_render!(prompt, [("session", "temp"),], "temp)");
        assert_render!(
            prompt,
            [("session", "temp"), ("role", "coder"),],
            "temp/coder)"
        );
    }
}
|
||||
@@ -0,0 +1,464 @@
|
||||
use super::*;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use fancy_regex::Regex;
|
||||
use futures_util::{stream, StreamExt};
|
||||
use http::header::CONTENT_TYPE;
|
||||
use reqwest::Url;
|
||||
use scraper::{Html, Selector};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::sync::LazyLock;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
/// Loader key for plain URL fetching.
pub const URL_LOADER: &str = "url";
/// Loader key for recursive website crawling.
pub const RECURSIVE_URL_LOADER: &str = "recursive_url";

/// Pseudo-extension marking media (image/video/audio) URLs.
pub const MEDIA_URL_EXTENSION: &str = "media_url";
/// Fallback extension when none can be determined.
pub const DEFAULT_EXTENSION: &str = "txt";

// Maximum number of pages fetched concurrently while crawling.
const MAX_CRAWLS: usize = 5;
// When true, the first failed page aborts the whole crawl.
const BREAK_ON_ERROR: bool = false;
// NOTE(review): presumably chosen because some hosts block unknown
// user agents while accepting curl's — confirm.
const USER_AGENT: &str = "curl/8.6.0";

// Shared HTTP client with a 16s timeout; construction can fail, so the
// error is stored and re-raised at each use site.
static CLIENT: LazyLock<Result<reqwest::Client>> = LazyLock::new(|| {
    let builder = reqwest::ClientBuilder::new().timeout(Duration::from_secs(16));
    let client = builder.build()?;
    Ok(client)
});

// Per-site crawl presets, matched against the start URL in order.
static PRESET: LazyLock<Vec<(Regex, CrawlOptions)>> = LazyLock::new(|| {
    vec![
        (
            // GitHub repo tree pages: skip changelog/license noise.
            Regex::new(r"github.com/([^/]+)/([^/]+)/tree/([^/]+)").unwrap(),
            CrawlOptions {
                exclude: vec!["changelog".into(), "changes".into(), "license".into()],
                ..Default::default()
            },
        ),
        (
            // GitHub wiki pages: extract only the wiki body element.
            Regex::new(r"github.com/([^/]+)/([^/]+)/wiki").unwrap(),
            CrawlOptions {
                exclude: vec!["_history".into()],
                extract: Some("#wiki-body".into()),
                ..Default::default()
            },
        ),
    ]
});

// Matches a trailing `.ext` file extension.
static EXTENSION_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\.[^.]+$").unwrap());
// Matches GitHub `/tree/<branch>` repo URLs handled by `crawl_gh_tree`.
static GITHUB_REPO_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)").unwrap());
|
||||
|
||||
pub async fn fetch(url: &str) -> Result<String> {
|
||||
let client = match *CLIENT {
|
||||
Ok(ref client) => client,
|
||||
Err(ref err) => bail!("{err}"),
|
||||
};
|
||||
let res = client.get(url).send().await?;
|
||||
let output = res.text().await?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Download `path` and convert it to text, returning `(contents, extension)`.
///
/// Resolution order:
/// 1. A configured `url` loader handles the URL directly.
/// 2. Otherwise the URL is fetched; the extension is derived from the
///    Content-Type header (falling back to the URL's own extension).
/// 3. Media responses become `data:` base64 URLs (only when `allow_media`).
/// 4. A loader registered for the extension converts a downloaded temp file.
/// 5. Otherwise the body is returned as text (HTML converted to markdown).
pub async fn fetch_with_loaders(
    loaders: &HashMap<String, String>,
    path: &str,
    allow_media: bool,
) -> Result<(String, String)> {
    // A dedicated URL loader bypasses the built-in fetching entirely.
    if let Some(loader_command) = loaders.get(URL_LOADER) {
        let contents = run_loader_command(path, URL_LOADER, loader_command)?;
        return Ok((contents, DEFAULT_EXTENSION.into()));
    }
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let mut res = client.get(path).send().await?;
    if !res.status().is_success() {
        bail!("Invalid status: {}", res.status());
    }
    // MIME type with any "; charset=..." parameters stripped; when the header
    // is absent, synthesize "_/<url extension>" so the match below still works.
    let content_type = res
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|v| match v.split_once(';') {
            Some((mime, _)) => mime.trim(),
            None => v,
        })
        .map(|v| v.to_string())
        .unwrap_or_else(|| {
            format!(
                "_/{}",
                get_patch_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into())
            )
        });
    let mut is_media = false;
    // Map well-known document MIME types to extensions; anything else falls
    // back to the MIME subtype, with image/video/audio flagged as media.
    let extension = match content_type.as_str() {
        "application/pdf" => "pdf".into(),
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx".into(),
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx".into(),
        "application/vnd.openxmlformats-officedocument.presentationml.presentation" => {
            "pptx".into()
        }
        "application/vnd.oasis.opendocument.text" => "odt".into(),
        "application/vnd.oasis.opendocument.spreadsheet" => "ods".into(),
        "application/vnd.oasis.opendocument.presentation" => "odp".into(),
        "application/rtf" => "rtf".into(),
        "text/javascript" => "js".into(),
        "text/html" => "html".into(),
        _ => content_type
            .rsplit_once('/')
            .map(|(first, last)| {
                if ["image", "video", "audio"].contains(&first) {
                    is_media = true;
                    MEDIA_URL_EXTENSION.into()
                } else {
                    last.to_lowercase()
                }
            })
            .unwrap_or_else(|| DEFAULT_EXTENSION.into()),
    };
    let result = if is_media {
        if !allow_media {
            bail!("Unexpected media type")
        }
        // Inline the media as a data: URL.
        let image_bytes = res.bytes().await?;
        let image_base64 = base64_encode(&image_bytes);
        let contents = format!("data:{content_type};base64,{image_base64}");
        (contents, extension)
    } else {
        match loaders.get(&extension) {
            Some(loader_command) => {
                // Stream the body to a temp file, then let the loader convert it.
                let save_path = temp_file("-download-", &format!(".{extension}"))
                    .display()
                    .to_string();
                let mut save_file = tokio::fs::File::create(&save_path).await?;
                let mut size = 0;
                while let Some(chunk) = res.chunk().await? {
                    size += chunk.len();
                    save_file.write_all(&chunk).await?;
                }
                let contents = if size == 0 {
                    // Empty body: warn instead of invoking the loader.
                    println!("{}", warning_text(&format!("No content at '{path}'")));
                    String::new()
                } else {
                    run_loader_command(&save_path, &extension, loader_command)?
                };
                (contents, DEFAULT_EXTENSION.into())
            }
            None => {
                let contents = res.text().await?;
                if extension == "html" {
                    (html_to_md(&contents), "md".into())
                } else {
                    (contents, extension)
                }
            }
        }
    };
    Ok(result)
}
|
||||
|
||||
pub async fn fetch_models(api_base: &str, api_key: Option<&str>) -> Result<Vec<String>> {
|
||||
let client = match *CLIENT {
|
||||
Ok(ref client) => client,
|
||||
Err(ref err) => bail!("{err}"),
|
||||
};
|
||||
let mut builder = client.get(format!("{}/models", api_base.trim_end_matches('/')));
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.bearer_auth(api_key);
|
||||
}
|
||||
let res_body: Value = builder.send().await?.json().await?;
|
||||
let mut result: Vec<String> = res_body
|
||||
.get("data")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|v| {
|
||||
v.iter()
|
||||
.filter_map(|v| v.get("id").and_then(|v| v.as_str().map(|v| v.to_string())))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
if result.is_empty() {
|
||||
bail!("No valid models")
|
||||
}
|
||||
result.sort_unstable();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Options controlling a website crawl.
#[derive(Debug, Clone, Default)]
pub struct CrawlOptions {
    /// CSS selector to extract page content from; whole page when `None`.
    extract: Option<String>,
    /// Link names/extensions to skip (see `should_exclude_link`).
    exclude: Vec<String>,
    /// Suppress progress and error printing.
    no_log: bool,
}
|
||||
|
||||
impl CrawlOptions {
|
||||
pub fn preset(start_url: &str) -> CrawlOptions {
|
||||
for (re, options) in PRESET.iter() {
|
||||
if let Ok(true) = re.is_match(start_url) {
|
||||
return options.clone();
|
||||
}
|
||||
}
|
||||
CrawlOptions::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Crawl a website breadth-first starting from `start_url`, returning the
/// pages whose extracted text is non-empty.
///
/// GitHub `/tree/<branch>` URLs are special-cased: the file list comes from
/// the GitHub API instead of link discovery. Pages are fetched in batches of
/// `MAX_CRAWLS`, bounded by a semaphore; newly discovered links are appended
/// to the work list after each batch.
pub async fn crawl_website(start_url: &str, options: CrawlOptions) -> Result<Vec<Page>> {
    let start_url = Url::parse(start_url)?;
    let mut paths = vec![start_url.path().to_string()];
    let normalized_start_url = normalize_start_url(&start_url);
    if !options.no_log {
        println!(
            "Start crawling url={start_url} exclude={} extract={}",
            options.exclude.join(","),
            options.extract.as_deref().unwrap_or_default()
        );
    }

    // GitHub repos: seed the work list from the API tree listing.
    if let Ok(true) = GITHUB_REPO_RE.is_match(start_url.as_str()) {
        paths = crawl_gh_tree(&start_url, &options.exclude)
            .await
            .with_context(|| "Failed to craw github repo".to_string())?;
    }

    let semaphore = Arc::new(Semaphore::new(MAX_CRAWLS));
    let mut result_pages = Vec::new();

    // `paths` grows while we iterate; `index` marks the processed prefix.
    let mut index = 0;
    while index < paths.len() {
        let batch = paths[index..std::cmp::min(index + MAX_CRAWLS, paths.len())].to_vec();

        let tasks: Vec<_> = batch
            .iter()
            .map(|path| {
                let options = options.clone();
                let permit = semaphore.clone().acquire_owned(); // acquire a permit for concurrency control
                let normalized_start_url = normalized_start_url.clone();
                let path = path.clone();

                async move {
                    let _permit = permit.await?;
                    let url = normalized_start_url
                        .join(&path)
                        .map_err(|_| anyhow!("Invalid crawl page at {}", path))?;
                    let mut page = crawl_page(&normalized_start_url, &path, options)
                        .await
                        .with_context(|| format!("Failed to crawl {}", url.as_str()))?;
                    // Replace the relative path with the absolute URL.
                    page.0 = url.as_str().to_string();
                    Ok(page)
                }
            })
            .collect();

        let results = stream::iter(tasks)
            .buffer_unordered(MAX_CRAWLS)
            .collect::<Vec<_>>()
            .await;

        let mut new_paths = Vec::new();

        for res in results {
            match res {
                Ok((path, text, links)) => {
                    if !options.no_log {
                        println!("Crawled {path}");
                    }
                    if !text.is_empty() {
                        result_pages.push(Page { path, text });
                    }
                    // Queue links not already known (modulo index.html aliasing).
                    for link in links {
                        if !paths.iter().any(|p| match_link(p, &link)) {
                            new_paths.push(link);
                        }
                    }
                }
                Err(err) => {
                    if BREAK_ON_ERROR {
                        return Err(err);
                    } else if !options.no_log {
                        println!("{}", error_text(&pretty_error(&err)));
                    }
                }
            }
        }
        paths.extend(new_paths);

        index += batch.len();
    }

    Ok(result_pages)
}
|
||||
|
||||
/// One crawled page.
#[derive(Debug, Deserialize)]
pub struct Page {
    /// Absolute URL of the page (the crawler overwrites the relative path).
    pub path: String,
    /// Extracted text content (markdown for HTML pages).
    pub text: String,
}
|
||||
|
||||
async fn crawl_gh_tree(start_url: &Url, exclude: &[String]) -> Result<Vec<String>> {
|
||||
let path_segs: Vec<&str> = start_url.path().split('/').collect();
|
||||
if path_segs.len() < 4 {
|
||||
bail!("Invalid gh tree {}", start_url.as_str());
|
||||
}
|
||||
let client = match *CLIENT {
|
||||
Ok(ref client) => client,
|
||||
Err(ref err) => bail!("{err}"),
|
||||
};
|
||||
let owner = path_segs[1];
|
||||
let repo = path_segs[2];
|
||||
let branch = path_segs[4];
|
||||
let root_path = path_segs[5..].join("/");
|
||||
|
||||
let url = format!("https://api.github.com/repos/{owner}/{repo}/git/ref/heads/{branch}");
|
||||
|
||||
let res_body: Value = client
|
||||
.get(&url)
|
||||
.header("User-Agent", USER_AGENT)
|
||||
.header("Accept", "application/vnd.github+json")
|
||||
.header("X-GitHub-Api-Version", "2022-11-28")
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let sha = res_body["object"]["sha"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Not found branch or tag"))?;
|
||||
|
||||
let url = format!("https://api.github.com/repos/{owner}/{repo}/git/trees/{sha}?recursive=true");
|
||||
|
||||
let res_body: Value = client
|
||||
.get(&url)
|
||||
.header("User-Agent", USER_AGENT)
|
||||
.header("Accept", "application/vnd.github+json")
|
||||
.header("X-GitHub-Api-Version", "2022-11-28")
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
let tree = res_body["tree"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow!("Invalid github repo tree"))?;
|
||||
let paths = tree
|
||||
.iter()
|
||||
.flat_map(|v| {
|
||||
let typ = v["type"].as_str()?;
|
||||
let path = v["path"].as_str()?;
|
||||
if typ == "blob"
|
||||
&& (path.ends_with(".md") || path.ends_with(".MD"))
|
||||
&& path.starts_with(&root_path)
|
||||
&& !should_exclude_link(path, exclude)
|
||||
{
|
||||
Some(format!(
|
||||
"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(paths)
|
||||
}
|
||||
|
||||
/// Fetch one page and return `(path, extracted_text, same-site links)`.
///
/// For GitHub repo crawls the body (raw markdown) is returned as-is with no
/// link discovery. Otherwise `<a href>` links under the start URL are
/// collected, and the body is converted to markdown — optionally restricted
/// to the `options.extract` CSS selector.
async fn crawl_page(
    start_url: &Url,
    path: &str,
    options: CrawlOptions,
) -> Result<(String, String, Vec<String>)> {
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let location = start_url.join(path)?;
    let response = client
        .get(location.as_str())
        .header("User-Agent", USER_AGENT)
        .send()
        .await?;
    let body = response.text().await?;

    // GitHub repo mode: the URL points at a raw markdown file.
    if let Ok(true) = GITHUB_REPO_RE.is_match(start_url.as_str()) {
        return Ok((path.to_string(), body, vec![]));
    }

    let mut links = HashSet::new();
    let document = Html::parse_document(&body);
    let selector = Selector::parse("a").map_err(|err| anyhow!("Invalid link selector, {}", err))?;

    for element in document.select(&selector) {
        if let Some(href) = element.value().attr("href") {
            // Accept absolute hrefs, or resolve relative ones against this page.
            let href = Url::parse(href).ok().or_else(|| location.join(href).ok());
            match href {
                None => continue,
                Some(href) => {
                    // Keep only links under the current location that are not excluded.
                    if href.as_str().starts_with(location.as_str())
                        && !should_exclude_link(href.path(), &options.exclude)
                    {
                        links.insert(href.path().to_string());
                    }
                }
            }
        }
    }

    let text = if let Some(selector) = &options.extract {
        // Extract only the configured element(s), joined by blank lines.
        let selector = Selector::parse(selector)
            .map_err(|err| anyhow!("Invalid extract selector, {}", err))?;
        document
            .select(&selector)
            .map(|v| html_to_md(&v.html()))
            .collect::<Vec<String>>()
            .join("\n\n")
    } else {
        html_to_md(&body)
    };

    Ok((path.to_string(), text, links.into_iter().collect()))
}
|
||||
|
||||
fn should_exclude_link(link: &str, exclude: &[String]) -> bool {
|
||||
if link.contains("#") {
|
||||
return true;
|
||||
}
|
||||
let parts: Vec<&str> = link.trim_end_matches('/').split('/').collect();
|
||||
let name = parts.last().unwrap_or(&"").to_lowercase();
|
||||
|
||||
for exclude_name in exclude {
|
||||
let cond = match EXTENSION_RE.is_match(exclude_name) {
|
||||
Ok(true) => exclude_name.to_lowercase() == name.to_lowercase(),
|
||||
_ => exclude_name.to_lowercase() == EXTENSION_RE.replace(&name, "").to_lowercase(),
|
||||
};
|
||||
if cond {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Strip the query and fragment from `start_url` and trim its path back to
/// the containing directory (everything up to and including the last '/').
fn normalize_start_url(start_url: &Url) -> Url {
    let mut normalized = start_url.clone();
    normalized.set_query(None);
    normalized.set_fragment(None);
    let path = normalized.path();
    let new_path = match path.rfind('/') {
        Some(last_slash_index) => path[..last_slash_index + 1].to_string(),
        None => path.to_string(),
    };
    normalized.set_path(&new_path);
    normalized
}
|
||||
|
||||
/// Whether `path` and `link` refer to the same page, treating a trailing
/// `/index.html` or `/index.htm` on `link` as equivalent to its directory.
fn match_link(path: &str, link: &str) -> bool {
    if path == link {
        return true;
    }
    let trimmed = link
        .trim_end_matches("/index.html")
        .trim_end_matches("/index.htm");
    path == trimmed
}
|
||||
@@ -0,0 +1,217 @@
|
||||
use super::{poll_abort_signal, wait_abort_signal, AbortSignal, IS_STDOUT_TERMINAL};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use crossterm::{cursor, queue, style, terminal};
|
||||
use std::{
|
||||
future::Future,
|
||||
io::{stdout, Write},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, UnboundedReceiver},
|
||||
oneshot,
|
||||
},
|
||||
time::interval,
|
||||
};
|
||||
|
||||
/// Terminal spinner state: the current animation frame index and the
/// message shown next to the spinner (empty = spinner hidden).
#[derive(Debug, Default)]
pub struct SpinnerInner {
    index: usize,
    message: String,
}

impl SpinnerInner {
    // Braille animation frames, cycled one per `step` call.
    const DATA: [&'static str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"];

    /// Draw one animation frame in place. No-op when stdout is not a
    /// terminal or no message is set.
    fn step(&mut self) -> Result<()> {
        if !*IS_STDOUT_TERMINAL || self.message.is_empty() {
            return Ok(());
        }
        let mut writer = stdout();
        let frame = Self::DATA[self.index % Self::DATA.len()];
        // Trailing dots animate every 5 frames, cycling 0..=3 dots.
        let dots = ".".repeat((self.index / 5) % 4);
        let line = format!("{frame}{}{:<3}", self.message, dots);
        queue!(writer, cursor::MoveToColumn(0), style::Print(line),)?;
        // Hide the cursor once, on the very first frame.
        if self.index == 0 {
            queue!(writer, cursor::Hide)?;
        }
        writer.flush()?;
        self.index += 1;
        Ok(())
    }

    /// Replace the spinner message (clearing the current line first). An
    /// empty message leaves the spinner hidden.
    fn set_message(&mut self, message: String) -> Result<()> {
        self.clear_message()?;
        if !message.is_empty() {
            // Leading space separates the message from the spinner frame.
            self.message = format!(" {message}");
        }
        Ok(())
    }

    /// Erase the spinner line and restore the cursor. No-op when stdout is
    /// not a terminal or nothing is displayed.
    fn clear_message(&mut self) -> Result<()> {
        if !*IS_STDOUT_TERMINAL || self.message.is_empty() {
            return Ok(());
        }
        self.message.clear();
        let mut writer = stdout();
        queue!(
            writer,
            cursor::MoveToColumn(0),
            terminal::Clear(terminal::ClearType::FromCursorDown),
            cursor::Show
        )?;
        writer.flush()?;
        Ok(())
    }
}
|
||||
|
||||
/// Cloneable handle that controls a spinner render loop via a channel.
#[derive(Clone)]
pub struct Spinner(mpsc::UnboundedSender<SpinnerEvent>);
|
||||
|
||||
impl Spinner {
    /// Builds a handle plus the event receiver a render loop should consume,
    /// queueing `message` as the initial spinner text.
    pub fn create(message: &str) -> (Self, UnboundedReceiver<SpinnerEvent>) {
        let (tx, spinner_rx) = mpsc::unbounded_channel();
        let spinner = Spinner(tx);
        // Best effort: the receiver is not being polled yet, so the send
        // result is intentionally ignored.
        let _ = spinner.set_message(message.to_string());
        (spinner, spinner_rx)
    }

    /// Asks the render loop to display `message`.
    pub fn set_message(&self, message: String) -> Result<()> {
        self.0.send(SpinnerEvent::SetMessage(message))?;
        // Brief pause gives the render task a chance to process the event
        // before the caller prints anything else.
        // NOTE(review): this is a blocking sleep; if invoked on an async
        // executor thread it stalls that thread for 10ms — confirm intended.
        std::thread::sleep(Duration::from_millis(10));
        Ok(())
    }

    /// Asks the render loop to erase the spinner and shut down; errors from
    /// a closed channel are ignored.
    pub fn stop(&self) {
        let _ = self.0.send(SpinnerEvent::Stop);
        std::thread::sleep(Duration::from_millis(10));
    }
}
|
||||
|
||||
/// Messages a `Spinner` handle sends to its render loop.
pub enum SpinnerEvent {
    /// Replace the displayed text.
    SetMessage(String),
    /// Erase the spinner and terminate the render loop.
    Stop,
}
|
||||
|
||||
pub fn spawn_spinner(message: &str) -> Spinner {
|
||||
let (spinner, mut spinner_rx) = Spinner::create(message);
|
||||
tokio::spawn(async move {
|
||||
let mut spinner = SpinnerInner::default();
|
||||
let mut interval = interval(Duration::from_millis(50));
|
||||
loop {
|
||||
tokio::select! {
|
||||
evt = spinner_rx.recv() => {
|
||||
if let Some(evt) = evt {
|
||||
match evt {
|
||||
SpinnerEvent::SetMessage(message) => {
|
||||
spinner.set_message(message)?;
|
||||
}
|
||||
SpinnerEvent::Stop => {
|
||||
spinner.clear_message()?;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
_ = interval.tick() => {
|
||||
let _ = spinner.step();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok::<(), anyhow::Error>(())
|
||||
});
|
||||
spinner
|
||||
}
|
||||
|
||||
pub async fn abortable_run_with_spinner<F, T>(
|
||||
task: F,
|
||||
message: &str,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<T>
|
||||
where
|
||||
F: Future<Output = Result<T>>,
|
||||
{
|
||||
let (_, spinner_rx) = Spinner::create(message);
|
||||
abortable_run_with_spinner_rx(task, spinner_rx, abort_signal).await
|
||||
}
|
||||
|
||||
/// Like `abortable_run_with_spinner`, but takes an already-created spinner
/// event receiver so the caller can keep the sending handle and update the
/// message while the task runs.
///
/// When stdout is not a terminal, the task is awaited directly with no
/// spinner and no Ctrl-C/abort handling.
pub async fn abortable_run_with_spinner_rx<F, T>(
    task: F,
    spinner_rx: UnboundedReceiver<SpinnerEvent>,
    abort_signal: AbortSignal,
) -> Result<T>
where
    F: Future<Output = Result<T>>,
{
    if *IS_STDOUT_TERMINAL {
        // `done_tx` tells the spinner loop to finish regardless of which
        // select arm completes first.
        let (done_tx, done_rx) = oneshot::channel();
        let run_task = async {
            tokio::select! {
                ret = task => {
                    let _ = done_tx.send(());
                    ret
                }
                _ = tokio::signal::ctrl_c() => {
                    abort_signal.set_ctrlc();
                    let _ = done_tx.send(());
                    // NOTE(review): "Aborted!" here vs "Aborted." below —
                    // confirm whether the differing punctuation is deliberate.
                    bail!("Aborted!")
                },
                _ = wait_abort_signal(&abort_signal) => {
                    let _ = done_tx.send(());
                    bail!("Aborted.");
                },
            }
        };
        // Drive task and spinner concurrently; `join!` waits for both, so
        // the spinner is fully cleared before the result is returned.
        let (task_ret, spinner_ret) = tokio::join!(
            run_task,
            run_abortable_spinner(spinner_rx, done_rx, abort_signal.clone())
        );
        spinner_ret?;
        task_ret
    } else {
        task.await
    }
}
|
||||
|
||||
/// Render loop used by `abortable_run_with_spinner_rx`: roughly every 25ms
/// it checks the done signal, applies at most one pending spinner event,
/// checks the abort signal, and draws a frame — until any of those ends the
/// loop. The spinner is always erased before returning.
async fn run_abortable_spinner(
    mut spinner_rx: UnboundedReceiver<SpinnerEvent>,
    mut done_rx: oneshot::Receiver<()>,
    abort_signal: AbortSignal,
) -> Result<()> {
    let mut spinner = SpinnerInner::default();
    loop {
        if abort_signal.aborted() {
            break;
        }

        // Throttle: one poll/render pass per ~25ms.
        tokio::time::sleep(Duration::from_millis(25)).await;

        // Finish when the task signalled completion or dropped its sender.
        match done_rx.try_recv() {
            Ok(_) | Err(oneshot::error::TryRecvError::Closed) => {
                break;
            }
            _ => {}
        }

        // Non-blocking: an empty or closed event channel is simply ignored.
        match spinner_rx.try_recv() {
            Ok(SpinnerEvent::SetMessage(message)) => {
                spinner.set_message(message)?;
            }
            Ok(SpinnerEvent::Stop) => {
                spinner.clear_message()?;
            }
            Err(_) => {}
        }

        if poll_abort_signal(&abort_signal)? {
            break;
        }

        spinner.step()?;
    }

    // Always leave the terminal clean, whatever ended the loop.
    spinner.clear_message()?;
    Ok(())
}
|
||||
@@ -0,0 +1,32 @@
|
||||
use super::*;
|
||||
use fancy_regex::{Captures, Regex};
|
||||
use std::sync::LazyLock;
|
||||
|
||||
pub static RE_VARIABLE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{(\w+)\}\}").unwrap());
|
||||
pub fn interpolate_variables(text: &mut String) {
|
||||
*text = RE_VARIABLE
|
||||
.replace_all(text, |caps: &Captures<'_>| {
|
||||
let key = &caps[1];
|
||||
match key {
|
||||
"__os__" => env::consts::OS.to_string(),
|
||||
"__os_distro__" => {
|
||||
let info = os_info::get();
|
||||
if env::consts::OS == "linux" {
|
||||
format!("{info} (linux)")
|
||||
} else {
|
||||
info.to_string()
|
||||
}
|
||||
}
|
||||
"__os_family__" => env::consts::FAMILY.to_string(),
|
||||
"__arch__" => env::consts::ARCH.to_string(),
|
||||
"__shell__" => SHELL.name.clone(),
|
||||
"__locale__" => sys_locale::get_locale().unwrap_or_default(),
|
||||
"__now__" => now(),
|
||||
"__cwd__" => env::current_dir()
|
||||
.map(|v| v.display().to_string())
|
||||
.unwrap_or_default(),
|
||||
_ => format!("{{{{{key}}}}}"),
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
}
|
||||
Reference in New Issue
Block a user