Baseline project

This commit is contained in:
2025-10-07 10:45:42 -06:00
parent 88288a98b6
commit 650dbd92e0
54 changed files with 18982 additions and 0 deletions
+219
View File
@@ -0,0 +1,219 @@
use crate::client::{list_models, ModelType};
use crate::config::{list_agents, Config};
use anyhow::{Context, Result};
use clap::ValueHint;
use clap::{crate_authors, crate_description, crate_name, crate_version, Parser};
use clap_complete::ArgValueCompleter;
use clap_complete::CompletionCandidate;
use is_terminal::IsTerminal;
use std::ffi::OsStr;
use std::io::{stdin, Read};
/// Command-line interface, parsed with clap's derive API.
///
/// Several options provide shell completion through the `*_completer`
/// functions defined at the bottom of this file (wired in via
/// `ArgValueCompleter`).
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[command(
    name = crate_name!(),
    author = crate_authors!(),
    version = crate_version!(),
    about = crate_description!(),
    help_template = "\
{before-help}{name} {version}
{author-with-newline}
{about-with-newline}
{usage-heading} {usage}
{all-args}{after-help}
"
)]
pub struct Cli {
    /// Select a LLM model
    #[arg(short, long, add = ArgValueCompleter::new(model_completer))]
    pub model: Option<String>,
    /// Use the system prompt
    #[arg(long)]
    pub prompt: Option<String>,
    /// Select a role
    #[arg(short, long, add = ArgValueCompleter::new(role_completer))]
    pub role: Option<String>,
    /// Start or join a session
    // Double Option: the outer tells whether `--session` was given at all,
    // the inner whether a session name was supplied (bare `--session`
    // presumably selects a default session — behavior decided elsewhere).
    #[arg(short = 's', long, add = ArgValueCompleter::new(session_completer))]
    pub session: Option<Option<String>>,
    /// Ensure the session is empty
    #[arg(long)]
    pub empty_session: bool,
    /// Ensure the new conversation is saved to the session
    #[arg(long)]
    pub save_session: bool,
    /// Start an agent
    #[arg(short = 'a', long, add = ArgValueCompleter::new(agent_completer))]
    pub agent: Option<String>,
    /// Set agent variables
    // Collected as flat NAME VALUE pairs; each use of the flag takes
    // exactly two values.
    #[arg(long, value_names = ["NAME", "VALUE"], num_args = 2)]
    pub agent_variable: Vec<String>,
    /// Start a RAG
    #[arg(long, add = ArgValueCompleter::new(rag_completer))]
    pub rag: Option<String>,
    /// Rebuild the RAG to sync document changes
    #[arg(long)]
    pub rebuild_rag: bool,
    /// Execute a macro
    // `--macro` is a Rust keyword, so the field is named `macro_name`.
    #[arg(long = "macro", value_name = "MACRO", add = ArgValueCompleter::new(macro_completer))]
    pub macro_name: Option<String>,
    /// Serve the LLM API and WebAPP
    // Outer Option: flag present; inner Option: optional bind address.
    #[arg(long, value_name = "PORT|IP|IP:PORT")]
    pub serve: Option<Option<String>>,
    /// Execute commands in natural language
    #[arg(short = 'e', long)]
    pub execute: bool,
    /// Output code only
    #[arg(short = 'c', long)]
    pub code: bool,
    /// Include files, directories, or URLs
    #[arg(short = 'f', long, value_name = "FILE|URL", value_hint = ValueHint::AnyPath)]
    pub file: Vec<String>,
    /// Turn off stream mode
    #[arg(short = 'S', long)]
    pub no_stream: bool,
    /// Display the message without sending it
    #[arg(long)]
    pub dry_run: bool,
    /// Display information
    #[arg(long)]
    pub info: bool,
    /// Build all configured Bash tool scripts
    #[arg(long)]
    pub build_tools: bool,
    /// Sync models updates
    #[arg(long)]
    pub sync_models: bool,
    /// List all available chat models
    #[arg(long)]
    pub list_models: bool,
    /// List all roles
    #[arg(long)]
    pub list_roles: bool,
    /// List all sessions
    #[arg(long)]
    pub list_sessions: bool,
    /// List all agents
    #[arg(long)]
    pub list_agents: bool,
    /// List all RAGs
    #[arg(long)]
    pub list_rags: bool,
    /// List all macros
    #[arg(long)]
    pub list_macros: bool,
    /// Input text
    // All remaining positional arguments; private, exposed (combined with
    // piped stdin) through `Cli::text()`.
    #[arg(trailing_var_arg = true)]
    text: Vec<String>,
    /// Tail logs
    #[arg(long)]
    pub tail_logs: bool,
    /// Disable colored log output
    // Only meaningful together with --tail-logs; clap enforces the pairing.
    #[arg(long, requires = "tail_logs")]
    pub disable_log_colors: bool,
}
impl Cli {
    /// Resolve the effective input text from the positional arguments
    /// and/or piped stdin.
    ///
    /// Returns `Ok(None)` when neither source has input. When a macro is
    /// being executed, the positional arguments are shell-quoted and piped
    /// input is appended after a ` -- ` separator; otherwise the arguments
    /// are space-joined and stdin is appended on a new line.
    pub fn text(&self) -> Result<Option<String>> {
        let mut stdin_text = String::new();
        // Only consume stdin when it is a pipe/redirect, not a terminal.
        if !stdin().is_terminal() {
            let _ = stdin()
                .read_to_string(&mut stdin_text)
                .context("Invalid stdin pipe")?;
        }
        if self.text.is_empty() {
            return if stdin_text.is_empty() {
                Ok(None)
            } else {
                Ok(Some(stdin_text))
            };
        }
        if self.macro_name.is_some() {
            // Macro arguments must survive re-parsing, so quote each one.
            let quoted = self
                .text
                .iter()
                .map(|v| shell_words::quote(v))
                .collect::<Vec<_>>()
                .join(" ");
            Ok(Some(if stdin_text.is_empty() {
                quoted
            } else {
                format!("{quoted} -- {stdin_text}")
            }))
        } else {
            let joined = self.text.join(" ");
            Ok(Some(if stdin_text.is_empty() {
                joined
            } else {
                format!("{joined}\n{stdin_text}")
            }))
        }
    }
}
/// Shell-completion source for `--model`: chat models whose id starts
/// with the text typed so far.
fn model_completer(current: &OsStr) -> Vec<CompletionCandidate> {
    let prefix = current.to_string_lossy();
    // Completion runs outside a normal invocation; a failed config load
    // simply yields no candidates.
    let Ok(config) = Config::init_bare() else {
        return Vec::new();
    };
    list_models(&config, ModelType::Chat)
        .into_iter()
        .filter(|&m| m.id().starts_with(&*prefix))
        .map(|m| CompletionCandidate::new(m.id()))
        .collect()
}
/// Shell-completion source for `--role`: roles matching the typed prefix.
fn role_completer(current: &OsStr) -> Vec<CompletionCandidate> {
    let prefix = current.to_string_lossy();
    let mut candidates = Vec::new();
    for role in Config::list_roles(true) {
        if role.starts_with(&*prefix) {
            candidates.push(CompletionCandidate::new(role));
        }
    }
    candidates
}
/// Shell-completion source for `--agent`: agents matching the typed prefix.
fn agent_completer(current: &OsStr) -> Vec<CompletionCandidate> {
    let prefix = current.to_string_lossy();
    list_agents()
        .into_iter()
        .filter_map(|name| {
            name.starts_with(&*prefix)
                .then(|| CompletionCandidate::new(name))
        })
        .collect()
}
/// Shell-completion source for `--rag`: RAGs matching the typed prefix.
fn rag_completer(current: &OsStr) -> Vec<CompletionCandidate> {
    let prefix = current.to_string_lossy();
    Config::list_rags()
        .into_iter()
        .filter_map(|name| {
            name.starts_with(&*prefix)
                .then(|| CompletionCandidate::new(name))
        })
        .collect()
}
/// Shell-completion source for `--macro`: macros matching the typed prefix.
fn macro_completer(current: &OsStr) -> Vec<CompletionCandidate> {
    let prefix = current.to_string_lossy();
    let mut candidates = Vec::new();
    for name in Config::list_macros() {
        if name.starts_with(&*prefix) {
            candidates.push(CompletionCandidate::new(name));
        }
    }
    candidates
}
/// Shell-completion source for `--session`: existing sessions matching
/// the typed prefix. A failed config load yields no candidates.
fn session_completer(current: &OsStr) -> Vec<CompletionCandidate> {
    let prefix = current.to_string_lossy();
    Config::init_bare()
        .map(|config| {
            config
                .list_sessions()
                .into_iter()
                .filter(|s| s.starts_with(&*prefix))
                .map(CompletionCandidate::new)
                .collect()
        })
        .unwrap_or_default()
}
+32
View File
@@ -0,0 +1,32 @@
use anyhow::{anyhow, Result};
use chrono::Utc;
use indexmap::IndexMap;
use parking_lot::RwLock;
use std::sync::LazyLock;
// Process-wide cache of OAuth-style access tokens, keyed by client name.
// Each entry stores `(token, expires_at)` with `expires_at` as a Unix
// timestamp in seconds. IndexMap keeps insertion order; parking_lot's
// RwLock allows many concurrent readers.
static ACCESS_TOKENS: LazyLock<RwLock<IndexMap<String, (String, i64)>>> =
    LazyLock::new(|| RwLock::new(IndexMap::new()));
/// Return the cached access token for `client_name`, if one is stored.
///
/// This is a plain lookup; expiry is NOT checked here — see
/// `is_valid_access_token`.
pub fn get_access_token(client_name: &str) -> Result<String> {
    let tokens = ACCESS_TOKENS.read();
    match tokens.get(client_name) {
        Some((token, _expires_at)) => Ok(token.clone()),
        None => Err(anyhow!("Invalid access token")),
    }
}
/// Whether `client_name` has a cached token that is non-empty and not yet
/// expired (compared against the current Unix timestamp).
pub fn is_valid_access_token(client_name: &str) -> bool {
    ACCESS_TOKENS
        .read()
        .get(client_name)
        .is_some_and(|(token, expires_at)| {
            !token.is_empty() && Utc::now().timestamp() < *expires_at
        })
}
/// Store (or replace) the access token and its expiry for `client_name`.
// `IndexMap::insert` replaces the value in place for an existing key,
// preserving its position — same effect as the entry API.
pub fn set_access_token(client_name: &str, token: String, expires_at: i64) {
    ACCESS_TOKENS
        .write()
        .insert(client_name.to_string(), (token, expires_at));
}
+82
View File
@@ -0,0 +1,82 @@
use super::openai::*;
use super::*;
use anyhow::Result;
use serde::Deserialize;
/// User-facing configuration for the Azure OpenAI client, deserialized
/// from the clients section of the config file.
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
    // Optional display name overriding the default client name.
    pub name: Option<String>,
    // Resource endpoint, e.g. https://{RESOURCE}.openai.azure.com
    pub api_base: Option<String>,
    pub api_key: Option<String>,
    // Models exposed by this client; defaults to empty when omitted.
    #[serde(default)]
    pub models: Vec<ModelData>,
    // Optional request patch (extra headers/body tweaks).
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
impl AzureOpenAIClient {
    // Accessors generated by `config_get_fn!` (defined elsewhere) —
    // presumably resolve each value from the client config or environment;
    // confirm against the macro definition.
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);

    /// Interactive setup prompts: (config key, label, optional hint).
    pub const PROMPTS: [PromptAction<'static>; 2] = [
        (
            "api_base",
            "API Base",
            Some("e.g. https://{RESOURCE}.openai.azure.com"),
        ),
        ("api_key", "API Key", None),
    ];
}
// Wire the `Client` trait onto AzureOpenAIClient: chat completions and
// embeddings reuse the shared OpenAI request builders/handlers; rerank is
// not supported (noop handlers).
impl_client_trait!(
    AzureOpenAIClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
/// Build the request for the Azure OpenAI chat-completions endpoint.
/// The deployment name in the URL is the model's real (provider) name,
/// and authentication uses the `api-key` header.
fn prepare_chat_completions(
    self_: &AzureOpenAIClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_base = self_.get_api_base()?;
    let api_key = self_.get_api_key()?;
    let deployment = self_.model.real_name();
    let url = format!(
        "{api_base}/openai/deployments/{deployment}/chat/completions?api-version=2024-12-01-preview"
    );
    let body = openai_build_chat_completions_body(data, &self_.model);
    let mut request_data = RequestData::new(url, body);
    request_data.header("api-key", api_key);
    Ok(request_data)
}
/// Build the request for the Azure OpenAI embeddings endpoint.
fn prepare_embeddings(self_: &AzureOpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
    let api_base = self_.get_api_base()?;
    let api_key = self_.get_api_key()?;
    let deployment = self_.model.real_name();
    let url = format!(
        "{api_base}/openai/deployments/{deployment}/embeddings?api-version=2024-10-21"
    );
    let body = openai_build_embeddings_body(data, &self_.model);
    let mut request_data = RequestData::new(url, body);
    request_data.header("api-key", api_key);
    Ok(request_data)
}
+643
View File
@@ -0,0 +1,643 @@
use super::*;
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256, strip_think_tag};
use anyhow::{bail, Context, Result};
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use bytes::BytesMut;
use chrono::{DateTime, Utc};
use futures_util::StreamExt;
use indexmap::IndexMap;
use reqwest::{Client as ReqwestClient, Method, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
/// User-facing configuration for the AWS Bedrock client.
#[derive(Debug, Clone, Deserialize)]
pub struct BedrockConfig {
    // Optional display name overriding the default client name.
    pub name: Option<String>,
    pub access_key_id: Option<String>,
    pub secret_access_key: Option<String>,
    pub region: Option<String>,
    // Only needed for temporary (STS) credentials.
    pub session_token: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    // Optional request patch (extra headers/body tweaks).
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
impl BedrockClient {
    // Accessors generated by `config_get_fn!` (defined elsewhere) —
    // presumably resolve each value from config or environment; confirm
    // against the macro definition.
    config_get_fn!(access_key_id, get_access_key_id);
    config_get_fn!(secret_access_key, get_secret_access_key);
    config_get_fn!(region, get_region);
    config_get_fn!(session_token, get_session_token);

    /// Interactive setup prompts: (config key, label, optional hint).
    pub const PROMPTS: [PromptAction<'static>; 3] = [
        ("access_key_id", "AWS Access Key ID", None),
        ("secret_access_key", "AWS Secret Access Key", None),
        ("region", "AWS Region", None),
    ];

    /// Build a SigV4-signed request for the Bedrock `converse` /
    /// `converse-stream` chat endpoint.
    fn chat_completions_builder(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<RequestBuilder> {
        let access_key_id = self.get_access_key_id()?;
        let secret_access_key = self.get_secret_access_key()?;
        let region = self.get_region()?;
        // Session token is optional — only present for temporary credentials.
        let session_token = self.get_session_token().ok();
        let host = format!("bedrock-runtime.{region}.amazonaws.com");
        let model_name = &self.model.real_name();
        let uri = if data.stream {
            format!("/model/{model_name}/converse-stream")
        } else {
            format!("/model/{model_name}/converse")
        };
        let body = build_chat_completions_body(data, &self.model)?;
        // Apply the configured request patch before signing. The URL field
        // is unused here ("") because `aws_fetch` assembles the endpoint
        // from host + uri itself.
        let mut request_data = RequestData::new("", body);
        self.patch_request_data(&mut request_data);
        let RequestData {
            url: _,
            headers,
            body,
        } = request_data;
        let builder = aws_fetch(
            client,
            &AwsCredentials {
                access_key_id,
                secret_access_key,
                region,
                session_token,
            },
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers,
                body: body.to_string(),
            },
        )?;
        Ok(builder)
    }

    /// Build a SigV4-signed request for the Bedrock embeddings endpoint
    /// (`invoke`, with a `texts`/`input_type` payload).
    fn embeddings_builder(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<RequestBuilder> {
        let access_key_id = self.get_access_key_id()?;
        let secret_access_key = self.get_secret_access_key()?;
        let region = self.get_region()?;
        let session_token = self.get_session_token().ok();
        let host = format!("bedrock-runtime.{region}.amazonaws.com");
        let uri = format!("/model/{}/invoke", self.model.real_name());
        // Asymmetric embeddings: queries and documents are marked
        // differently so the model can embed them accordingly.
        let input_type = match data.query {
            true => "search_query",
            false => "search_document",
        };
        let body = json!({
            "texts": data.texts,
            "input_type": input_type,
        });
        let mut request_data = RequestData::new("", body);
        self.patch_request_data(&mut request_data);
        let RequestData {
            url: _,
            headers,
            body,
        } = request_data;
        let builder = aws_fetch(
            client,
            &AwsCredentials {
                access_key_id,
                secret_access_key,
                region,
                session_token,
            },
            AwsRequest {
                method: Method::POST,
                host,
                service: "bedrock".into(),
                uri,
                querystring: "".into(),
                headers,
                body: body.to_string(),
            },
        )?;
        Ok(builder)
    }
}
#[async_trait::async_trait]
impl Client for BedrockClient {
    // Shared boilerplate accessors common to all clients (macro defined
    // elsewhere in this module).
    client_common_fns!();

    /// Non-streaming chat completion via the `converse` endpoint.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        let builder = self.chat_completions_builder(client, data)?;
        chat_completions(builder).await
    }

    /// Streaming chat completion via the `converse-stream` endpoint.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        let builder = self.chat_completions_builder(client, data)?;
        chat_completions_streaming(builder, handler).await
    }

    /// Embeddings via the `invoke` endpoint.
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<EmbeddingsOutput> {
        let builder = self.embeddings_builder(client, data)?;
        embeddings(builder).await
    }
}
/// Send a non-streaming Bedrock `converse` request and parse the reply.
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
    let response = builder.send().await?;
    // Status must be captured before `json()` consumes the response.
    let status = response.status();
    let data: Value = response.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    extract_chat_completions(&data)
}
/// Stream a Bedrock `converse-stream` response.
///
/// The body is an AWS event stream: raw bytes are buffered and decoded
/// frame-by-frame with `MessageFrameDecoder`. Each complete frame carries
/// a JSON payload whose smithy event type selects the handling: text and
/// reasoning deltas are forwarded to `handler` (reasoning wrapped in
/// `<think>` tags), and tool-use input fragments are accumulated until a
/// complete `ToolCall` can be emitted.
async fn chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if !status.is_success() {
        let data: Value = res.json().await?;
        catch_error(&data, status.as_u16())?;
        bail!("Invalid response data: {data}");
    }
    // State of the tool call currently being assembled from fragments.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    // 0 = outside a reasoning block, 1 = inside one (<think> emitted).
    let mut reasoning_state = 0;
    let mut stream = res.bytes_stream();
    let mut buffer = BytesMut::new();
    let mut decoder = MessageFrameDecoder::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        buffer.extend_from_slice(&chunk);
        // One network chunk may contain zero or more complete frames.
        while let DecodedFrame::Complete(message) = decoder.decode_frame(&mut buffer)? {
            let response_headers = parse_response_headers(&message)?;
            let message_type = response_headers.message_type.as_str();
            let smithy_type = response_headers.smithy_type.as_str();
            match (message_type, smithy_type) {
                ("event", _) => {
                    let data: Value = serde_json::from_slice(message.payload())?;
                    debug!("stream-data: {smithy_type} {data}");
                    match smithy_type {
                        "contentBlockStart" => {
                            if let Some(tool_use) = data["start"]["toolUse"].as_object() {
                                if let (Some(id), Some(name)) = (
                                    json_str_from_map(tool_use, "toolUseId"),
                                    json_str_from_map(tool_use, "name"),
                                ) {
                                    // Flush any previous, still-pending tool
                                    // call before starting a new one.
                                    // NOTE(review): `contentBlockStop` below
                                    // also emits the pending call but never
                                    // clears `function_name`, so a stop
                                    // followed by another tool-use start
                                    // appears to emit the same call twice —
                                    // verify against the handler's behavior.
                                    if !function_name.is_empty() {
                                        // Parameterless tools stream no
                                        // input; treat as empty object.
                                        if function_arguments.is_empty() {
                                            function_arguments = String::from("{}");
                                        }
                                        let arguments: Value =
                                            function_arguments.parse().with_context(|| {
                                                format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                            })?;
                                        handler.tool_call(ToolCall::new(
                                            function_name.clone(),
                                            arguments,
                                            Some(function_id.clone()),
                                        ))?;
                                    }
                                    function_arguments.clear();
                                    function_name = name.into();
                                    function_id = id.into();
                                }
                            }
                        }
                        "contentBlockDelta" => {
                            if let Some(text) = data["delta"]["text"].as_str() {
                                handler.text(text)?;
                            } else if let Some(text) =
                                data["delta"]["reasoningContent"]["text"].as_str()
                            {
                                // Open the <think> wrapper on the first
                                // reasoning delta.
                                if reasoning_state == 0 {
                                    handler.text("<think>\n")?;
                                    reasoning_state = 1;
                                }
                                handler.text(text)?;
                            } else if let Some(input) = data["delta"]["toolUse"]["input"].as_str() {
                                function_arguments.push_str(input);
                            }
                        }
                        "contentBlockStop" => {
                            // Close an open reasoning block.
                            if reasoning_state == 1 {
                                handler.text("\n</think>\n\n")?;
                                reasoning_state = 0;
                            }
                            // Emit the assembled tool call, if any.
                            if !function_name.is_empty() {
                                if function_arguments.is_empty() {
                                    function_arguments = String::from("{}");
                                }
                                let arguments: Value = function_arguments.parse().with_context(|| {
                                    format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                })?;
                                handler.tool_call(ToolCall::new(
                                    function_name.clone(),
                                    arguments,
                                    Some(function_id.clone()),
                                ))?;
                            }
                        }
                        _ => {}
                    }
                }
                ("exception", _) => {
                    // Exception payloads arrive base64-encoded.
                    let payload = base64_decode(message.payload())?;
                    let data = String::from_utf8_lossy(&payload);
                    bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
                }
                _ => {
                    bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
                }
            }
        }
    }
    Ok(())
}
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
Ok(res_body.embeddings)
}
/// Minimal shape of the Bedrock embeddings response — only the vectors
/// are extracted; other fields are ignored by serde.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    embeddings: Vec<Vec<f32>>,
}
/// Build the JSON body for a Bedrock `converse` request from generic chat
/// data, translating each message into Bedrock's content-block format.
///
/// # Errors
/// Fails when any message references a network image (non-`data:` URL),
/// which cannot be embedded inline.
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream: _,
    } = data;
    // Bedrock carries the system prompt in a top-level `system` field,
    // not as an ordinary message.
    let system_message = extract_system_message(&mut messages);
    let mut network_image_urls = vec![];
    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                // Historical assistant turns (not the last message): strip
                // any <think> block so prior reasoning is not re-sent.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": [ { "text": strip_think_tag(&text) } ] })]
                }
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "content": [
                        {
                            "text": text,
                        }
                    ],
                })],
                // Multi-part content: text parts pass through; data-URL
                // images become inline base64 blocks, network URLs are
                // collected and rejected after the loop.
                MessageContent::Array(list) => {
                    let content: Vec<_> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => {
                                json!({"text": text})
                            }
                            MessageContentPart::ImageUrl {
                                image_url: ImageUrl { url },
                            } => {
                                if let Some((mime_type, data)) = url
                                    .strip_prefix("data:")
                                    .and_then(|v| v.split_once(";base64,"))
                                {
                                    json!({
                                        "image": {
                                            // Bedrock wants the bare format
                                            // (e.g. "png"), not a MIME type.
                                            "format": mime_type.replace("image/", ""),
                                            "source": {
                                                "bytes": data,
                                            }
                                        }
                                    })
                                } else {
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            }
                        })
                        .collect();
                    vec![json!({
                        "role": role,
                        "content": content,
                    })]
                }
                // Tool-call round trip: the assistant turn replays the
                // toolUse blocks, and a user turn supplies the toolResult
                // blocks — the pairing Bedrock expects.
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    let mut assistant_parts = vec![];
                    let mut user_parts = vec![];
                    if !text.is_empty() {
                        assistant_parts.push(json!({
                            "text": text,
                        }))
                    }
                    for tool_result in tool_results {
                        assistant_parts.push(json!({
                            "toolUse": {
                                "toolUseId": tool_result.call.id,
                                "name": tool_result.call.name,
                                "input": tool_result.call.arguments,
                            }
                        }));
                        user_parts.push(json!({
                            "toolResult": {
                                "toolUseId": tool_result.call.id,
                                "content": [
                                    {
                                        "json": tool_result.output,
                                    }
                                ]
                            }
                        }));
                    }
                    vec![
                        json!({
                            "role": "assistant",
                            "content": assistant_parts,
                        }),
                        json!({
                            "role": "user",
                            "content": user_parts,
                        }),
                    ]
                }
            }
        })
        .collect();
    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }
    let mut body = json!({
        "inferenceConfig": {},
        "messages": messages,
    });
    if let Some(v) = system_message {
        body["system"] = json!([
            {
                "text": v,
            }
        ])
    }
    // Sampling parameters live under `inferenceConfig`.
    if let Some(v) = model.max_tokens_param() {
        body["inferenceConfig"]["maxTokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["inferenceConfig"]["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["inferenceConfig"]["topP"] = v.into();
    }
    // Function declarations map to Bedrock tool specs.
    if let Some(functions) = functions {
        let tools: Vec<_> = functions
            .iter()
            .map(|v| {
                json!({
                    "toolSpec": {
                        "name": v.name,
                        "description": v.description,
                        "inputSchema": {
                            "json": v.parameters,
                        },
                    }
                })
            })
            .collect();
        body["toolConfig"] = json!({
            "tools": tools,
        })
    }
    Ok(body)
}
/// Convert a non-streaming Bedrock `converse` response into
/// `ChatCompletionsOutput`: text blocks are concatenated (separated by
/// blank lines), reasoning is prefixed wrapped in `<think>` tags, and
/// toolUse blocks become tool calls.
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
    let mut text = String::new();
    let mut reasoning = None;
    let mut tool_calls = vec![];
    let content = data["output"]["message"]["content"].as_array();
    for item in content.into_iter().flatten() {
        if let Some(chunk) = item["text"].as_str() {
            if !text.is_empty() {
                text.push_str("\n\n");
            }
            text.push_str(chunk);
            continue;
        }
        if let Some(reasoning_text) = item["reasoningContent"]["reasoningText"].as_object() {
            if let Some(v) = json_str_from_map(reasoning_text, "text") {
                // A later reasoning block replaces an earlier one.
                reasoning = Some(v.to_string());
            }
            continue;
        }
        if let Some(tool_use) = item["toolUse"].as_object() {
            let id = json_str_from_map(tool_use, "toolUseId");
            let name = json_str_from_map(tool_use, "name");
            let input = tool_use.get("input");
            if let (Some(id), Some(name), Some(input)) = (id, name, input) {
                tool_calls.push(ToolCall::new(
                    name.to_string(),
                    input.clone(),
                    Some(id.to_string()),
                ));
            }
        }
    }
    if let Some(reasoning) = reasoning {
        text = format!("<think>\n{reasoning}\n</think>\n\n{text}");
    }
    if text.is_empty() && tool_calls.is_empty() {
        bail!("Invalid response data: {data}");
    }
    Ok(ChatCompletionsOutput {
        text,
        tool_calls,
        id: None,
        input_tokens: data["usage"]["inputTokens"].as_u64(),
        output_tokens: data["usage"]["outputTokens"].as_u64(),
    })
}
/// AWS credentials used for SigV4 request signing.
#[derive(Debug)]
struct AwsCredentials {
    access_key_id: String,
    secret_access_key: String,
    region: String,
    // Present only for temporary (STS) credentials; sent as the
    // `x-amz-security-token` header when set.
    session_token: Option<String>,
}
/// The pieces of an HTTP request that participate in SigV4 signing.
#[derive(Debug)]
struct AwsRequest {
    method: Method,
    host: String,
    // AWS service name used in the credential scope (e.g. "bedrock").
    service: String,
    // Request path; combined with `host` to form the endpoint URL.
    uri: String,
    querystring: String,
    headers: IndexMap<String, String>,
    body: String,
}
/// Sign `request` with AWS Signature Version 4 and return a ready-to-send
/// `RequestBuilder`.
///
/// Builds the canonical request and string-to-sign per the SigV4
/// specification, derives the signing key via `gen_signing_key`, and
/// attaches the `authorization` header (which itself is not signed).
fn aws_fetch(
    client: &ReqwestClient,
    credentials: &AwsCredentials,
    request: AwsRequest,
) -> Result<RequestBuilder> {
    let AwsRequest {
        method,
        host,
        service,
        uri,
        querystring,
        mut headers,
        body,
    } = request;
    let region = &credentials.region;
    let endpoint = format!("https://{host}{uri}");
    let now: DateTime<Utc> = Utc::now();
    let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
    // Credential-scope date: the YYYYMMDD prefix of the timestamp.
    let date_stamp = amz_date[0..8].to_string();
    headers.insert("host".into(), host.clone());
    headers.insert("x-amz-date".into(), amz_date.clone());
    if let Some(token) = credentials.session_token.clone() {
        headers.insert("x-amz-security-token".into(), token);
    }
    // SigV4 requires canonical and signed headers to be listed in
    // ascending byte order of the header name. The defaults above happen
    // to be sorted already, but headers added via a request patch retain
    // insertion order and could break the signature — sort explicitly.
    // NOTE(review): header names are assumed to be lowercase here;
    // mixed-case patched headers would still produce an invalid signature.
    headers.sort_keys();
    let canonical_headers = headers
        .iter()
        .map(|(key, value)| format!("{key}:{value}\n"))
        .collect::<Vec<_>>()
        .join("");
    let signed_headers = headers
        .iter()
        .map(|(key, _)| key.as_str())
        .collect::<Vec<_>>()
        .join(";");
    let payload_hash = sha256(&body);
    let canonical_request = format!(
        "{}\n{}\n{}\n{}\n{}\n{}",
        method,
        encode_uri(&uri),
        querystring,
        canonical_headers,
        signed_headers,
        payload_hash
    );
    let algorithm = "AWS4-HMAC-SHA256";
    let credential_scope = format!("{date_stamp}/{region}/{service}/aws4_request");
    let string_to_sign = format!(
        "{}\n{}\n{}\n{}",
        algorithm,
        amz_date,
        credential_scope,
        sha256(&canonical_request)
    );
    let signing_key = gen_signing_key(
        &credentials.secret_access_key,
        &date_stamp,
        region,
        &service,
    );
    let signature = hmac_sha256(&signing_key, &string_to_sign);
    let signature = hex_encode(&signature);
    let authorization_header = format!(
        "{} Credential={}/{}, SignedHeaders={}, Signature={}",
        algorithm, credentials.access_key_id, credential_scope, signed_headers, signature
    );
    // Added after signing: `authorization` is never part of the signed set.
    headers.insert("authorization".into(), authorization_header);
    debug!("Request {endpoint} {body}");
    let mut request_builder = client.request(method, endpoint).body(body);
    for (key, value) in &headers {
        request_builder = request_builder.header(key, value);
    }
    Ok(request_builder)
}
fn gen_signing_key(key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
let k_date = hmac_sha256(format!("AWS4{key}").as_bytes(), date_stamp);
let k_region = hmac_sha256(&k_date, region);
let k_service = hmac_sha256(&k_region, service);
hmac_sha256(&k_service, "aws4_request")
}
+353
View File
@@ -0,0 +1,353 @@
use super::*;
use crate::utils::strip_think_tag;
use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
const API_BASE: &str = "https://api.anthropic.com/v1";
/// User-facing configuration for the Anthropic Claude client.
#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeConfig {
    // Optional display name overriding the default client name.
    pub name: Option<String>,
    pub api_key: Option<String>,
    // Falls back to the public Anthropic endpoint (API_BASE) when unset.
    pub api_base: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    // Optional request patch (extra headers/body tweaks).
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
impl ClaudeClient {
    // Accessors generated by `config_get_fn!` (defined elsewhere) —
    // presumably resolve the value from config or environment.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts: (config key, label, optional hint).
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
// Wire the `Client` trait onto ClaudeClient: chat completions use the
// Claude-specific handlers below; embeddings and rerank are not supported
// (noop handlers).
impl_client_trait!(
    ClaudeClient,
    (
        prepare_chat_completions,
        claude_chat_completions,
        claude_chat_completions_streaming
    ),
    (noop_prepare_embeddings, noop_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
/// Build the request for the Claude messages endpoint, falling back to the
/// public Anthropic API base when none is configured.
fn prepare_chat_completions(
    self_: &ClaudeClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let api_base = match self_.get_api_base() {
        Ok(v) => v,
        Err(_) => API_BASE.to_string(),
    };
    let url = format!("{}/messages", api_base.trim_end_matches('/'));
    let body = claude_build_chat_completions_body(data, &self_.model)?;
    let mut request_data = RequestData::new(url, body);
    request_data.header("anthropic-version", "2023-06-01");
    request_data.header("x-api-key", api_key);
    Ok(request_data)
}
/// Send a non-streaming Claude messages request and parse the reply.
pub async fn claude_chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<ChatCompletionsOutput> {
    let response = builder.send().await?;
    // Status must be captured before `json()` consumes the response.
    let status = response.status();
    let data: Value = response.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    claude_extract_chat_completions(&data)
}
/// Stream a Claude messages response, forwarding text/thinking deltas to
/// `handler` and assembling tool calls from `content_block_*` events.
///
/// Fixes over the previous revision:
/// - the `content_block_start` flush parsed `function_arguments` even when
///   empty, so a parameterless tool followed by another tool call failed on
///   `"".parse::<Value>()`; empty arguments now become `{}` in both flush
///   sites (matching the `content_block_stop` branch);
/// - `content_block_stop` emitted the pending tool call but never cleared
///   the state, so the next `content_block_start` re-emitted the same call;
///   the state is now cleared after a stop-branch flush.
pub async fn claude_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    // State of the tool call currently being assembled.
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    // 0 = outside a thinking block, 1 = inside one (<think> emitted).
    let mut reasoning_state = 0;
    let handle = |message: SseMessage| -> Result<bool> {
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(typ) = data["type"].as_str() {
            match typ {
                "content_block_start" => {
                    if let (Some("tool_use"), Some(name), Some(id)) = (
                        data["content_block"]["type"].as_str(),
                        data["content_block"]["name"].as_str(),
                        data["content_block"]["id"].as_str(),
                    ) {
                        // Safety net: flush a still-pending tool call before
                        // starting a new one (normally `content_block_stop`
                        // has already emitted and cleared it).
                        if !function_name.is_empty() {
                            // Parameterless tools stream no partial_json;
                            // treat empty arguments as `{}`.
                            let arguments: Value = if function_arguments.is_empty() {
                                json!({})
                            } else {
                                function_arguments.parse().with_context(|| {
                                    format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                                })?
                            };
                            handler.tool_call(ToolCall::new(
                                function_name.clone(),
                                arguments,
                                Some(function_id.clone()),
                            ))?;
                        }
                        function_name = name.into();
                        function_arguments.clear();
                        function_id = id.into();
                    }
                }
                "content_block_delta" => {
                    if let Some(text) = data["delta"]["text"].as_str() {
                        handler.text(text)?;
                    } else if let Some(text) = data["delta"]["thinking"].as_str() {
                        // Open the <think> wrapper on the first thinking delta.
                        if reasoning_state == 0 {
                            handler.text("<think>\n")?;
                            reasoning_state = 1;
                        }
                        handler.text(text)?;
                    } else if let (true, Some(partial_json)) = (
                        !function_name.is_empty(),
                        data["delta"]["partial_json"].as_str(),
                    ) {
                        function_arguments.push_str(partial_json);
                    }
                }
                "content_block_stop" => {
                    // Close an open thinking block.
                    if reasoning_state == 1 {
                        handler.text("\n</think>\n\n")?;
                        reasoning_state = 0;
                    }
                    if !function_name.is_empty() {
                        let arguments: Value = if function_arguments.is_empty() {
                            json!({})
                        } else {
                            function_arguments.parse().with_context(|| {
                                format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                            })?
                        };
                        handler.tool_call(ToolCall::new(
                            function_name.clone(),
                            arguments,
                            Some(function_id.clone()),
                        ))?;
                        // Clear the state so a following `content_block_start`
                        // does not emit this tool call a second time.
                        function_name.clear();
                        function_arguments.clear();
                        function_id.clear();
                    }
                }
                _ => {}
            }
        }
        Ok(false)
    };
    sse_stream(builder, handle).await
}
/// Build the JSON body for a Claude messages request from generic chat
/// data.
///
/// # Errors
/// Fails when any message references a network image (non-`data:` URL),
/// which cannot be embedded inline.
pub fn claude_build_chat_completions_body(
    data: ChatCompletionsData,
    model: &Model,
) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream,
    } = data;
    // Claude takes the system prompt as a top-level `system` field.
    let system_message = extract_system_message(&mut messages);
    let mut network_image_urls = vec![];
    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                // Historical assistant turns (not the last message): strip
                // any <think> block so prior reasoning is not re-sent.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": strip_think_tag(&text) })]
                }
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "content": text,
                })],
                // Multi-part content: text parts pass through; data-URL
                // images become inline base64 blocks, network URLs are
                // collected and rejected after the loop.
                MessageContent::Array(list) => {
                    let content: Vec<_> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => {
                                json!({"type": "text", "text": text})
                            }
                            MessageContentPart::ImageUrl {
                                image_url: ImageUrl { url },
                            } => {
                                if let Some((mime_type, data)) = url
                                    .strip_prefix("data:")
                                    .and_then(|v| v.split_once(";base64,"))
                                {
                                    json!({
                                        "type": "image",
                                        "source": {
                                            "type": "base64",
                                            "media_type": mime_type,
                                            "data": data,
                                        }
                                    })
                                } else {
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            }
                        })
                        .collect();
                    vec![json!({
                        "role": role,
                        "content": content,
                    })]
                }
                // Tool-call round trip: the assistant turn replays tool_use
                // blocks and a user turn supplies tool_result blocks — the
                // pairing Claude expects.
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    let mut assistant_parts = vec![];
                    let mut user_parts = vec![];
                    if !text.is_empty() {
                        assistant_parts.push(json!({
                            "type": "text",
                            "text": text,
                        }))
                    }
                    for tool_result in tool_results {
                        assistant_parts.push(json!({
                            "type": "tool_use",
                            "id": tool_result.call.id,
                            "name": tool_result.call.name,
                            "input": tool_result.call.arguments,
                        }));
                        user_parts.push(json!({
                            "type": "tool_result",
                            "tool_use_id": tool_result.call.id,
                            "content": tool_result.output.to_string(),
                        }));
                    }
                    vec![
                        json!({
                            "role": "assistant",
                            "content": assistant_parts,
                        }),
                        json!({
                            "role": "user",
                            "content": user_parts,
                        }),
                    ]
                }
            }
        })
        .collect();
    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }
    let mut body = json!({
        "model": model.real_name(),
        "messages": messages,
    });
    if let Some(v) = system_message {
        body["system"] = v.into();
    }
    if let Some(v) = model.max_tokens_param() {
        body["max_tokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["top_p"] = v.into();
    }
    if stream {
        body["stream"] = true.into();
    }
    // Function declarations map to Claude tool definitions.
    if let Some(functions) = functions {
        body["tools"] = functions
            .iter()
            .map(|v| {
                json!({
                    "name": v.name,
                    "description": v.description,
                    "input_schema": v.parameters,
                })
            })
            .collect();
    }
    Ok(body)
}
/// Parse a non-streaming Claude messages response into
/// `ChatCompletionsOutput`.
///
/// Text blocks are concatenated (separated by blank lines), a thinking
/// block is prefixed wrapped in `<think>` tags, and tool_use blocks are
/// collected as tool calls. Fails when the response contains neither text
/// nor tool calls.
pub fn claude_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
    let mut text = String::new();
    let mut reasoning = None;
    let mut tool_calls = vec![];
    if let Some(list) = data["content"].as_array() {
        for item in list {
            match item["type"].as_str() {
                Some("thinking") => {
                    if let Some(v) = item["thinking"].as_str() {
                        // A later thinking block replaces an earlier one.
                        reasoning = Some(v.to_string());
                    }
                }
                Some("text") => {
                    if let Some(v) = item["text"].as_str() {
                        if !text.is_empty() {
                            text.push_str("\n\n");
                        }
                        text.push_str(v);
                    }
                }
                Some("tool_use") => {
                    if let (Some(name), Some(input), Some(id)) = (
                        item["name"].as_str(),
                        item.get("input"),
                        item["id"].as_str(),
                    ) {
                        tool_calls.push(ToolCall::new(
                            name.to_string(),
                            input.clone(),
                            Some(id.to_string()),
                        ));
                    }
                }
                _ => {}
            }
        }
    }
    if let Some(reasoning) = reasoning {
        text = format!("<think>\n{reasoning}\n</think>\n\n{text}")
    }
    if text.is_empty() && tool_calls.is_empty() {
        bail!("Invalid response data: {data}");
    }
    let output = ChatCompletionsOutput {
        // `text` is already an owned String — the previous
        // `text.to_string()` was a redundant clone.
        text,
        tool_calls,
        id: data["id"].as_str().map(|v| v.to_string()),
        input_tokens: data["usage"]["input_tokens"].as_u64(),
        output_tokens: data["usage"]["output_tokens"].as_u64(),
    };
    Ok(output)
}
+255
View File
@@ -0,0 +1,255 @@
use super::openai::*;
use super::openai_compatible::*;
use super::*;
use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
const API_BASE: &str = "https://api.cohere.ai/v2";
/// User-facing configuration for the Cohere client.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct CohereConfig {
    // Optional display name overriding the default client name.
    pub name: Option<String>,
    pub api_key: Option<String>,
    // Falls back to the public Cohere v2 endpoint (API_BASE) when unset.
    pub api_base: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    // Optional request patch (extra headers/body tweaks).
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
impl CohereClient {
    // Accessors generated by `config_get_fn!` (defined elsewhere) —
    // presumably resolve the value from config or environment.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);

    /// Interactive setup prompts: (config key, label, optional hint).
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
// Wire the `Client` trait onto CohereClient: chat completions use the
// Cohere-specific handlers below; rerank reuses the generic handler.
impl_client_trait!(
    CohereClient,
    (
        prepare_chat_completions,
        chat_completions,
        chat_completions_streaming
    ),
    (prepare_embeddings, embeddings),
    (prepare_rerank, generic_rerank),
);
/// Build the request for the Cohere v2 chat endpoint.
fn prepare_chat_completions(
    self_: &CohereClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let api_base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());
    let url = format!("{}/chat", api_base.trim_end_matches('/'));
    // Cohere v2 names the nucleus-sampling parameter `p`, not `top_p`;
    // rename it after building the OpenAI-shaped body.
    let mut body = openai_build_chat_completions_body(data, &self_.model);
    if let Some(top_p) = body.as_object_mut().and_then(|obj| obj.remove("top_p")) {
        body["p"] = top_p;
    }
    let mut request_data = RequestData::new(url, body);
    request_data.bearer_auth(api_key);
    Ok(request_data)
}
/// Build the request for the Cohere v2 embed endpoint. Queries and
/// documents are marked with different `input_type` values.
fn prepare_embeddings(self_: &CohereClient, data: &EmbeddingsData) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let api_base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());
    let url = format!("{}/embed", api_base.trim_end_matches('/'));
    let input_type = if data.query {
        "search_query"
    } else {
        "search_document"
    };
    let body = json!({
        "model": self_.model.real_name(),
        "texts": data.texts,
        "input_type": input_type,
        "embedding_types": ["float"],
    });
    let mut request_data = RequestData::new(url, body);
    request_data.bearer_auth(api_key);
    Ok(request_data)
}
/// Builds the request for Cohere's `/rerank` endpoint using the generic
/// rerank body.
fn prepare_rerank(self_: &CohereClient, data: &RerankData) -> Result<RequestData> {
    let key = self_.get_api_key()?;
    let base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());
    let endpoint = format!("{}/rerank", base.trim_end_matches('/'));
    let payload = generic_build_rerank_body(data, &self_.model);
    let mut req = RequestData::new(endpoint, payload);
    req.bearer_auth(key);
    Ok(req)
}
/// Sends a non-streaming Cohere chat request and parses the response.
async fn chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<ChatCompletionsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    // Parse the body before the status check so error details can be reported.
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    extract_chat_completions(&data)
}
/// Streams a Cohere chat completion, forwarding text deltas to `handler`
/// and assembling tool calls from the event stream.
///
/// Tool-call fragments arrive across several SSE event types; the three
/// locals below accumulate one call at a time until its `tool-call-end`.
async fn chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    let handle = |message: SseMessage| -> Result<bool> {
        // Returning Ok(true) tells the SSE driver the stream is finished.
        if message.data == "[DONE]" {
            return Ok(true);
        }
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(typ) = data["type"].as_str() {
            match typ {
                // Plain assistant text.
                "content-delta" => {
                    if let Some(text) = data["delta"]["message"]["content"]["text"].as_str() {
                        handler.text(text)?;
                    }
                }
                // The model's free-text plan before tool calls; surfaced as text.
                "tool-plan-delta" => {
                    if let Some(text) = data["delta"]["message"]["tool_plan"].as_str() {
                        handler.text(text)?;
                    }
                }
                // Begins a new tool call: capture its name and id.
                "tool-call-start" => {
                    if let (Some(function), Some(id)) = (
                        data["delta"]["message"]["tool_calls"]["function"].as_object(),
                        data["delta"]["message"]["tool_calls"]["id"].as_str(),
                    ) {
                        if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
                            function_name = name.to_string();
                        }
                        function_id = id.to_string();
                    }
                }
                // Argument JSON arrives as string fragments; concatenate them.
                "tool-call-delta" => {
                    if let Some(text) =
                        data["delta"]["message"]["tool_calls"]["function"]["arguments"].as_str()
                    {
                        function_arguments.push_str(text);
                    }
                }
                // Finalize the accumulated call, then reset the accumulators.
                "tool-call-end" => {
                    if !function_name.is_empty() {
                        let arguments: Value = function_arguments.parse().with_context(|| {
                            format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
                        })?;
                        handler.tool_call(ToolCall::new(
                            function_name.clone(),
                            arguments,
                            Some(function_id.clone()),
                        ))?;
                    }
                    function_name.clear();
                    function_arguments.clear();
                    function_id.clear();
                }
                _ => {}
            }
        }
        Ok(false)
    };
    sse_stream(builder, handle).await
}
/// Sends a Cohere embeddings request and extracts the float vectors.
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    let response = builder.send().await?;
    let code = response.status();
    // Parse the body first so failures can surface provider error details.
    let body: Value = response.json().await?;
    if !code.is_success() {
        catch_error(&body, code.as_u16())?;
    }
    let parsed: EmbeddingsResBody =
        serde_json::from_value(body).context("Invalid embeddings data")?;
    Ok(parsed.embeddings.float)
}
/// Response shape of Cohere's `/embed` endpoint (only the fields we read).
#[derive(Deserialize)]
struct EmbeddingsResBody {
    embeddings: EmbeddingsResBodyEmbeddings,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbeddings {
    // Present because the request asks for `"embedding_types": ["float"]`.
    float: Vec<Vec<f32>>,
}
/// Parses a non-streaming Cohere chat response into `ChatCompletionsOutput`.
///
/// When the model produced only tool calls, the `tool_plan` text (if any)
/// is used as the reply text. Errors if a tool call carries non-JSON
/// arguments, or if the response has neither text nor tool calls.
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
    let mut text = data["message"]["content"][0]["text"]
        .as_str()
        .unwrap_or_default()
        .to_string();
    let mut tool_calls = vec![];
    if let Some(calls) = data["message"]["tool_calls"].as_array() {
        // Tool-call-only responses: surface the model's plan as the text.
        if text.is_empty() {
            // Renamed from `tool_plain`: the field is `tool_plan`.
            if let Some(tool_plan) = data["message"]["tool_plan"].as_str() {
                text = tool_plan.to_string();
            }
        }
        for call in calls {
            if let (Some(name), Some(arguments), Some(id)) = (
                call["function"]["name"].as_str(),
                call["function"]["arguments"].as_str(),
                call["id"].as_str(),
            ) {
                // Arguments arrive as a JSON-encoded string; parse eagerly so
                // a malformed payload fails here with context.
                let arguments: Value = arguments.parse().with_context(|| {
                    format!("Tool call '{name}' has non-JSON arguments '{arguments}'")
                })?;
                tool_calls.push(ToolCall::new(
                    name.to_string(),
                    arguments,
                    Some(id.to_string()),
                ));
            }
        }
    }
    if text.is_empty() && tool_calls.is_empty() {
        bail!("Invalid response data: {data}");
    }
    let output = ChatCompletionsOutput {
        text,
        tool_calls,
        id: data["id"].as_str().map(|v| v.to_string()),
        // Cohere reports token usage under `billed_units`.
        input_tokens: data["usage"]["billed_units"]["input_tokens"].as_u64(),
        output_tokens: data["usage"]["billed_units"]["output_tokens"].as_u64(),
    };
    Ok(output)
}
+678
View File
@@ -0,0 +1,678 @@
use super::*;
use crate::{
config::{Config, GlobalConfig, Input},
function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult},
render::render_stream,
utils::*,
};
use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use indexmap::IndexMap;
use inquire::{
list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text,
};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::mpsc::unbounded_channel;
/// Built-in provider/model catalog bundled at compile time.
const MODELS_YAML: &str = include_str!("../../models.yaml");
/// All known provider models: a local override file if present, otherwise
/// the bundled `models.yaml`.
pub static ALL_PROVIDER_MODELS: LazyLock<Vec<ProviderModels>> = LazyLock::new(|| {
    Config::local_models_override()
        .ok()
        .unwrap_or_else(|| serde_yaml::from_str(MODELS_YAML).unwrap())
});
/// Heuristic for classifying a model name as an embedding model.
static EMBEDDING_MODEL_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"((^|/)(bge-|e5-|uae-|gte-|text-)|embed|multilingual|minilm)").unwrap()
});
/// Matches `/` characters that are not already escaped with a backslash.
static ESCAPE_SLASH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?<!\\)/").unwrap());
/// Common interface implemented by every LLM provider client.
///
/// Providers implement the `*_inner` methods (usually via
/// `impl_client_trait!`); the public methods here add dry-run handling,
/// HTTP client construction, abort handling, and error context.
#[async_trait::async_trait]
pub trait Client: Sync + Send {
    fn global_config(&self) -> &GlobalConfig;
    fn extra_config(&self) -> Option<&ExtraConfig>;
    fn patch_config(&self) -> Option<&RequestPatch>;
    fn name(&self) -> &str;
    fn model(&self) -> &Model;
    fn model_mut(&mut self) -> &mut Model;
    /// Builds the reqwest client honoring proxy, user-agent, and
    /// connect-timeout (default 10s) settings.
    fn build_client(&self) -> Result<ReqwestClient> {
        let mut builder = ReqwestClient::builder();
        let extra = self.extra_config();
        let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
        if let Some(proxy) = extra.and_then(|v| v.proxy.as_deref()) {
            builder = set_proxy(builder, proxy)?;
        }
        if let Some(user_agent) = self.global_config().read().user_agent.as_ref() {
            builder = builder.user_agent(user_agent);
        }
        let client = builder
            .connect_timeout(Duration::from_secs(timeout))
            .build()
            .with_context(|| "Failed to build client")?;
        Ok(client)
    }
    /// Non-streaming chat completion. In dry-run mode the input is echoed
    /// back without any network call.
    async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
        if self.global_config().read().dry_run {
            let content = input.echo_messages();
            return Ok(ChatCompletionsOutput::new(&content));
        }
        let client = self.build_client()?;
        let data = input.prepare_completion_data(self.model(), false)?;
        self.chat_completions_inner(&client, data)
            .await
            .with_context(|| "Failed to call chat-completions api")
    }
    /// Streaming chat completion; races the request against the handler's
    /// abort signal and always marks the handler done afterwards.
    async fn chat_completions_streaming(
        &self,
        input: &Input,
        handler: &mut SseHandler,
    ) -> Result<()> {
        let abort_signal = handler.abort();
        let input = input.clone();
        tokio::select! {
            ret = async {
                if self.global_config().read().dry_run {
                    let content = input.echo_messages();
                    handler.text(&content)?;
                    return Ok(());
                }
                let client = self.build_client()?;
                let data = input.prepare_completion_data(self.model(), true)?;
                self.chat_completions_streaming_inner(&client, handler, data).await
            } => {
                handler.done();
                ret.with_context(|| "Failed to call chat-completions api")
            }
            _ = wait_abort_signal(&abort_signal) => {
                handler.done();
                Ok(())
            },
        }
    }
    async fn embeddings(&self, data: &EmbeddingsData) -> Result<Vec<Vec<f32>>> {
        let client = self.build_client()?;
        self.embeddings_inner(&client, data)
            .await
            .context("Failed to call embeddings api")
    }
    async fn rerank(&self, data: &RerankData) -> Result<RerankOutput> {
        let client = self.build_client()?;
        self.rerank_inner(&client, data)
            .await
            .context("Failed to call rerank api")
    }
    /// Provider-specific implementation hooks.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput>;
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()>;
    // Embeddings/rerank default to "unsupported".
    async fn embeddings_inner(
        &self,
        _client: &ReqwestClient,
        _data: &EmbeddingsData,
    ) -> Result<EmbeddingsOutput> {
        bail!("The client doesn't support embeddings api")
    }
    async fn rerank_inner(
        &self,
        _client: &ReqwestClient,
        _data: &RerankData,
    ) -> Result<RerankOutput> {
        bail!("The client doesn't support rerank api")
    }
    /// Applies configured patches to `request_data`, then converts it into a
    /// reqwest builder.
    fn request_builder(
        &self,
        client: &reqwest::Client,
        mut request_data: RequestData,
    ) -> RequestBuilder {
        self.patch_request_data(&mut request_data);
        request_data.into_builder(client)
    }
    /// Applies model-level and config/env-level request patches.
    ///
    /// The env var `PATCH_<CLIENT>_<API>` (JSON) takes precedence over the
    /// config's `patch` section; patch keys are regexes matched against the
    /// full model name (unescaped `/` is escaped first), and only the first
    /// matching patch is applied.
    fn patch_request_data(&self, request_data: &mut RequestData) {
        let model_type = self.model().model_type();
        if let Some(patch) = self.model().patch() {
            request_data.apply_patch(patch.clone());
        }
        let patch_map = std::env::var(get_env_name(&format!(
            "patch_{}_{}",
            self.model().client_name(),
            model_type.api_name(),
        )))
        .ok()
        .and_then(|v| serde_json::from_str(&v).ok())
        .or_else(|| {
            self.patch_config()
                .and_then(|v| model_type.extract_patch(v))
                .cloned()
        });
        let patch_map = match patch_map {
            Some(v) => v,
            _ => return,
        };
        for (key, patch) in patch_map {
            let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
            if let Ok(regex) = Regex::new(&format!("^({key})$")) {
                if let Ok(true) = regex.is_match(self.model().name()) {
                    request_data.apply_patch(patch);
                    return;
                }
            }
        }
    }
}
impl Default for ClientConfig {
    // OpenAI is the fallback client type when none is configured.
    fn default() -> Self {
        Self::OpenAIConfig(OpenAIConfig::default())
    }
}
/// HTTP-level options shared by all clients.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtraConfig {
    // Proxy spec passed to `set_proxy`.
    pub proxy: Option<String>,
    // Connection timeout in seconds (defaults to 10 in `build_client`).
    pub connect_timeout: Option<u64>,
}
/// Per-API request patches; see `Client::patch_request_data` for how the
/// keys are matched against model names.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RequestPatch {
    pub chat_completions: Option<ApiPatch>,
    pub embeddings: Option<ApiPatch>,
    pub rerank: Option<ApiPatch>,
}
/// Maps a model-name pattern to a patch value (url/body/headers).
pub type ApiPatch = IndexMap<String, Value>;
/// An HTTP POST request under construction: target URL, headers, JSON body.
pub struct RequestData {
    pub url: String,
    pub headers: IndexMap<String, String>,
    pub body: Value,
}
impl RequestData {
    /// Creates request data for `url` with an empty header map.
    pub fn new<T>(url: T, body: Value) -> Self
    where
        T: std::fmt::Display,
    {
        let url = url.to_string();
        Self {
            url,
            headers: Default::default(),
            body,
        }
    }
    /// Sets the `authorization: Bearer …` header.
    pub fn bearer_auth<T>(&mut self, auth: T)
    where
        T: std::fmt::Display,
    {
        let value = format!("Bearer {auth}");
        self.headers.insert("authorization".into(), value);
    }
    /// Sets (or replaces) an arbitrary header.
    pub fn header<K, V>(&mut self, key: K, value: V)
    where
        K: std::fmt::Display,
        V: std::fmt::Display,
    {
        self.headers.insert(key.to_string(), value.to_string());
    }
    /// Consumes the request data and produces a reqwest POST builder.
    pub fn into_builder(self, client: &ReqwestClient) -> RequestBuilder {
        let RequestData { url, headers, body } = self;
        debug!("Request {url} {body}");
        headers
            .into_iter()
            .fold(client.post(url), |builder, (key, value)| {
                builder.header(key, value)
            })
            .json(&body)
    }
    /// Applies a patch object: optional `url` replacement, `body` merge,
    /// and `headers` overrides (a null value removes the header).
    pub fn apply_patch(&mut self, patch: Value) {
        if let Some(patch_url) = patch["url"].as_str() {
            self.url = patch_url.to_string();
        }
        if let Some(patch_body) = patch.get("body") {
            json_patch::merge(&mut self.body, patch_body);
        }
        if let Some(patch_headers) = patch["headers"].as_object() {
            for (key, value) in patch_headers {
                match value.as_str() {
                    Some(v) => self.header(key, v),
                    None => {
                        if value.is_null() {
                            self.headers.swap_remove(key);
                        }
                    }
                }
            }
        }
    }
}
/// Inputs for a chat-completions call.
#[derive(Debug)]
pub struct ChatCompletionsData {
    pub messages: Vec<Message>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    // Tool/function declarations offered to the model, if any.
    pub functions: Option<Vec<FunctionDeclaration>>,
    // Whether the caller intends to stream the response.
    pub stream: bool,
}
/// Result of a chat-completions call.
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionsOutput {
    pub text: String,
    pub tool_calls: Vec<ToolCall>,
    // Provider-assigned completion id, when reported.
    pub id: Option<String>,
    pub input_tokens: Option<u64>,
    pub output_tokens: Option<u64>,
}
impl ChatCompletionsOutput {
    /// Builds an output containing only `text`; every other field keeps its
    /// default value.
    pub fn new(text: &str) -> Self {
        let mut output = Self::default();
        output.text = text.to_string();
        output
    }
}
/// Inputs for an embeddings call.
#[derive(Debug)]
pub struct EmbeddingsData {
    pub texts: Vec<String>,
    // true: embedding a search query; false: embedding a document.
    pub query: bool,
}
impl EmbeddingsData {
    pub fn new(texts: Vec<String>, query: bool) -> Self {
        Self { texts, query }
    }
}
/// One embedding vector per input text.
pub type EmbeddingsOutput = Vec<Vec<f32>>;
/// Inputs for a rerank call.
#[derive(Debug)]
pub struct RerankData {
    pub query: String,
    pub documents: Vec<String>,
    // Number of top-scoring documents to return.
    pub top_n: usize,
}
impl RerankData {
    pub fn new(query: String, documents: Vec<String>, top_n: usize) -> Self {
        Self {
            query,
            documents,
            top_n,
        }
    }
}
pub type RerankOutput = Vec<RerankResult>;
/// A single rerank hit: index into the input documents plus its score.
#[derive(Debug, Deserialize)]
pub struct RerankResult {
    pub index: usize,
    pub relevance_score: f64,
}
/// (config key, human-readable label, optional help message) used when
/// interactively creating a client config.
pub type PromptAction<'a> = (&'a str, &'a str, Option<&'a str>);
/// Interactively builds a config for `client`, prompting for each entry in
/// `prompts` and then for the models to register.
///
/// Returns `(default_model_id, clients_json_array)`.
pub async fn create_config(
    prompts: &[PromptAction<'static>],
    client: &str,
) -> Result<(String, Value)> {
    let mut config = json!({
        "type": client,
    });
    for (key, desc, help_message) in prompts {
        // A value is only required when the matching env var (e.g.
        // OPENAI_API_KEY) is not already set; empty answers are skipped.
        let env_name = format!("{client}_{key}").to_ascii_uppercase();
        let required = std::env::var(&env_name).is_err();
        let value = prompt_input_string(desc, required, *help_message)?;
        if !value.is_empty() {
            config[key] = value.into();
        }
    }
    let model = set_client_models_config(&mut config, client).await?;
    let clients = json!(vec![config]);
    Ok((model, clients))
}
/// Interactively builds a config for an OpenAI-compatible provider (or a
/// fully custom one), returning `(default_model_id, clients_json_array)`.
pub async fn create_openai_compatible_client_config(
    client: &str,
) -> Result<Option<(String, Value)>> {
    // Known providers ship a fixed api_base; unknown ones get a placeholder
    // that triggers an interactive prompt below.
    let api_base = OPENAI_COMPATIBLE_PROVIDERS
        .into_iter()
        .find(|(name, _)| client == *name)
        .map(|(_, api_base)| api_base)
        .unwrap_or("http(s)://{API_ADDR}/v1");
    // The generic "openai-compatible" client needs a user-supplied name.
    let name = if client == OpenAICompatibleClient::NAME {
        let value = prompt_input_string("Provider Name", true, None)?;
        value.replace(' ', "-")
    } else {
        client.to_string()
    };
    let mut config = json!({
        "type": OpenAICompatibleClient::NAME,
        "name": &name,
    });
    // A '{' in api_base marks it as a template needing user input.
    let api_base = if api_base.contains('{') {
        prompt_input_string("API Base", true, Some(&format!("e.g. {api_base}")))?
    } else {
        api_base.to_string()
    };
    config["api_base"] = api_base.into();
    let api_key = prompt_input_string("API Key", false, None)?;
    if !api_key.is_empty() {
        config["api_key"] = api_key.into();
    }
    let model = set_client_models_config(&mut config, &name).await?;
    let clients = json!(vec![config]);
    Ok(Some((model, clients)))
}
/// Runs a non-streaming chat completion behind a spinner, optionally
/// extracting a code block from the reply and printing it as markdown.
///
/// Returns the reply text plus the results of any tool calls it requested.
pub async fn call_chat_completions(
    input: &Input,
    print: bool,
    extract_code: bool,
    client: &dyn Client,
    abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
    let ret = abortable_run_with_spinner(
        client.chat_completions(input.clone()),
        "Generating",
        abort_signal,
    )
    .await;
    match ret {
        Ok(ret) => {
            let ChatCompletionsOutput {
                mut text,
                tool_calls,
                ..
            } = ret;
            if !text.is_empty() {
                if extract_code {
                    // Drop any think-tag block before looking for the code fence.
                    text = extract_code_block(&strip_think_tag(&text)).to_string();
                }
                if print {
                    client.global_config().read().print_markdown(&text)?;
                }
            }
            Ok((
                text,
                eval_tool_calls(client.global_config(), tool_calls).await?,
            ))
        }
        Err(err) => Err(err),
    }
}
/// Runs a streaming chat completion, rendering SSE events as they arrive.
///
/// The send and render halves run concurrently and are joined; an abort
/// raises an error even when the request itself succeeded.
pub async fn call_chat_completions_streaming(
    input: &Input,
    client: &dyn Client,
    abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
    let (tx, rx) = unbounded_channel();
    let mut handler = SseHandler::new(tx, abort_signal.clone());
    let (send_ret, render_ret) = tokio::join!(
        client.chat_completions_streaming(input, &mut handler),
        render_stream(rx, client.global_config(), abort_signal.clone()),
    );
    if handler.abort().aborted() {
        bail!("Aborted.");
    }
    render_ret?;
    let (text, tool_calls) = handler.take();
    match send_ret {
        Ok(_) => {
            // Keep the shell prompt on its own line after streamed output.
            if !text.is_empty() && !text.ends_with('\n') {
                println!();
            }
            Ok((
                text,
                eval_tool_calls(client.global_config(), tool_calls).await?,
            ))
        }
        Err(err) => {
            if !text.is_empty() {
                println!();
            }
            Err(err)
        }
    }
}
// Default no-op handlers for clients that lack embeddings/rerank support;
// each simply fails with an explanatory error.
pub fn noop_prepare_embeddings<T>(_client: &T, _data: &EmbeddingsData) -> Result<RequestData> {
    bail!("The client doesn't support embeddings api")
}
pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    bail!("The client doesn't support embeddings api")
}
pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
    bail!("The client doesn't support rerank api")
}
pub async fn noop_rerank(_builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
    bail!("The client doesn't support rerank api")
}
/// Converts a non-2xx response body into an error, recognizing the error
/// shapes used by several providers.
///
/// Branch order matters: the most structured shapes are tried first; the
/// raw body is reported when nothing matches.
pub fn catch_error(data: &Value, status: u16) -> Result<()> {
    if (200..300).contains(&status) {
        return Ok(());
    }
    debug!("Invalid response, status: {status}, data: {data}");
    // { "error": { "type"|"code", "message" } }
    if let Some(error) = data["error"].as_object() {
        if let (Some(typ), Some(message)) = (
            json_str_from_map(error, "type"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (type: {typ})");
        } else if let (Some(typ), Some(message)) = (
            json_str_from_map(error, "code"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (code: {typ})");
        }
    // { "errors": [ { "code": <number>, "message" } ] }
    } else if let Some(error) = data["errors"][0].as_object() {
        if let (Some(code), Some(message)) = (
            error.get("code").and_then(|v| v.as_u64()),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (status: {code})")
        }
    // [ { "error": { "status", "message" } } ]
    } else if let Some(error) = data[0]["error"].as_object() {
        if let (Some(status), Some(message)) = (
            json_str_from_map(error, "status"),
            json_str_from_map(error, "message"),
        ) {
            bail!("{message} (status: {status})")
        }
    // { "detail", "status" }
    } else if let (Some(detail), Some(status)) = (data["detail"].as_str(), data["status"].as_i64())
    {
        bail!("{detail} (status: {status})");
    // { "error": "<string>" } or { "message": "<string>" }
    } else if let Some(error) = data["error"].as_str() {
        bail!("{error}");
    } else if let Some(message) = data["message"].as_str() {
        bail!("{message}");
    }
    bail!("Invalid response data: {data} (status: {status})");
}
/// Looks up `field_name` in a JSON object map, returning it as a string
/// slice when the field exists and is a string.
pub fn json_str_from_map<'a>(
    map: &'a serde_json::Map<String, Value>,
    field_name: &str,
) -> Option<&'a str> {
    match map.get(field_name) {
        Some(value) => value.as_str(),
        None => None,
    }
}
/// Fills `client_config["models"]` and returns the chosen default model id
/// (`"<client>:<model>"`).
///
/// Resolution order: the built-in provider catalog; otherwise models
/// fetched from an OpenAI-compatible endpoint; otherwise a comma-separated
/// list typed by the user.
async fn set_client_models_config(client_config: &mut Value, client: &str) -> Result<String> {
    // Known provider: pick a default from its catalog of chat models.
    if let Some(provider) = ALL_PROVIDER_MODELS.iter().find(|v| v.provider == client) {
        let models: Vec<String> = provider
            .models
            .iter()
            .filter(|v| v.model_type == "chat")
            .map(|v| v.name.clone())
            .collect();
        let model_name = select_model(models)?;
        return Ok(format!("{client}:{model_name}"));
    }
    let mut model_names = vec![];
    // OpenAI-compatible client with an api_base: try to fetch the model list
    // (api_key from config or the <CLIENT>_API_KEY env var).
    if let (Some(true), Some(api_base), api_key) = (
        client_config["type"]
            .as_str()
            .map(|v| v == OpenAICompatibleClient::NAME),
        client_config["api_base"].as_str(),
        client_config["api_key"]
            .as_str()
            .map(|v| v.to_string())
            .or_else(|| {
                let env_name = format!("{client}_api_key").to_ascii_uppercase();
                std::env::var(&env_name).ok()
            }),
    ) {
        match abortable_run_with_spinner(
            fetch_models(api_base, api_key.as_deref()),
            "Fetching models",
            create_abort_signal(),
        )
        .await
        {
            Ok(fetched_models) => {
                model_names = MultiSelect::new("LLMs to include (required):", fetched_models)
                    .with_validator(|list: &[ListOption<&String>]| {
                        if list.is_empty() {
                            Ok(Validation::Invalid(
                                "At least one item must be selected".into(),
                            ))
                        } else {
                            Ok(Validation::Valid)
                        }
                    })
                    .prompt()?;
            }
            Err(err) => {
                // Fetch failure is non-fatal; fall through to manual entry.
                eprintln!("✗ Fetch models failed: {err}");
            }
        }
    }
    // Manual entry fallback: comma-separated model names.
    if model_names.is_empty() {
        model_names = prompt_input_string(
            "LLMs to add",
            true,
            Some("Separated by commas, e.g. llama3.3,qwen2.5"),
        )?
        .split(',')
        .filter_map(|v| {
            let v = v.trim();
            if v.is_empty() {
                None
            } else {
                Some(v.to_string())
            }
        })
        .collect::<Vec<_>>();
    }
    if model_names.is_empty() {
        bail!("No models");
    }
    // Infer each model's type/capabilities from its name.
    let models: Vec<Value> = model_names
        .iter()
        .map(|v| {
            let l = v.to_lowercase();
            if l.contains("rank") {
                json!({
                    "name": v,
                    "type": "reranker",
                })
            } else if let Ok(true) = EMBEDDING_MODEL_RE.is_match(&l) {
                json!({
                    "name": v,
                    "type": "embedding",
                    "default_chunk_size": 1000,
                    "max_batch_size": 100
                })
            } else if v.contains("vision") {
                json!({
                    "name": v,
                    "supports_vision": true
                })
            } else {
                json!({
                    "name": v,
                })
            }
        })
        .collect();
    client_config["models"] = models.into();
    let model_name = select_model(model_names)?;
    Ok(format!("{client}:{model_name}"))
}
/// Picks the default model: errors on an empty list, returns a sole entry
/// directly, and prompts the user to choose otherwise.
fn select_model(mut model_names: Vec<String>) -> Result<String> {
    match model_names.len() {
        0 => bail!("No models"),
        1 => Ok(model_names.remove(0)),
        _ => {
            let chosen = Select::new("Default Model (required):", model_names).prompt()?;
            Ok(chosen)
        }
    }
}
/// Prompts for a single line of text, labeling it "(required)" or
/// "(optional)" and enforcing non-empty input when required.
fn prompt_input_string(desc: &str, required: bool, help_message: Option<&str>) -> Result<String> {
    let label = match required {
        true => format!("{desc} (required):"),
        false => format!("{desc} (optional):"),
    };
    let mut prompt = Text::new(&label);
    if required {
        prompt = prompt.with_validator(required!("This field is required"));
    }
    if let Some(msg) = help_message {
        prompt = prompt.with_help_message(msg);
    }
    Ok(prompt.prompt()?)
}
+136
View File
@@ -0,0 +1,136 @@
use super::vertexai::*;
use super::*;
use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
/// Default Gemini (Generative Language) API endpoint, used when `api_base`
/// is not configured.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
/// User configuration for the Gemini client, deserialized from the
/// `clients` section of the config file.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiConfig {
    // Optional name overriding the default client name.
    pub name: Option<String>,
    // API key; may also come from the `<CLIENT_NAME>_API_KEY` env var.
    pub api_key: Option<String>,
    // Custom API base URL; falls back to `API_BASE` when unset.
    pub api_base: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}
impl GeminiClient {
    // Getters generated by `config_get_fn!`: env var first, then config.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);
    /// Interactive prompts used when creating this client's config.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
// Wire GeminiClient into the `Client` trait: chat reuses the shared Gemini
// handlers (also used by Vertex AI); rerank is unsupported.
impl_client_trait!(
    GeminiClient,
    (
        prepare_chat_completions,
        gemini_chat_completions,
        gemini_chat_completions_streaming
    ),
    (prepare_embeddings, embeddings),
    (noop_prepare_rerank, noop_rerank),
);
/// Builds the request for Gemini's generateContent RPC.
///
/// Streaming and one-shot generation use different method names on the
/// same model resource; the API key travels in the `x-goog-api-key` header.
fn prepare_chat_completions(
    self_: &GeminiClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let base = match self_.get_api_base() {
        Ok(v) => v,
        Err(_) => API_BASE.to_string(),
    };
    let method = if data.stream {
        "streamGenerateContent"
    } else {
        "generateContent"
    };
    let url = format!(
        "{}/models/{}:{}",
        base.trim_end_matches('/'),
        self_.model.real_name(),
        method
    );
    let body = gemini_build_chat_completions_body(data, &self_.model)?;
    let mut req = RequestData::new(url, body);
    req.header("x-goog-api-key", api_key);
    Ok(req)
}
/// Builds the request for Gemini's `batchEmbedContents` RPC, one embedded
/// request per input text.
///
/// NOTE(review): this endpoint authenticates via the `key` query parameter
/// rather than a header, so the API key appears in the URL.
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());
    let url = format!(
        "{}/models/{}:batchEmbedContents?key={}",
        base.trim_end_matches('/'),
        self_.model.real_name(),
        api_key
    );
    // Each sub-request must repeat the fully-qualified model id.
    let model_id = format!("models/{}", self_.model.real_name());
    let requests = data
        .texts
        .iter()
        .map(|text| {
            json!({
                "model": model_id,
                "content": {
                    "parts": [
                        {
                            "text": text
                        }
                    ]
                },
            })
        })
        .collect::<Vec<_>>();
    let body = json!({
        "requests": requests,
    });
    Ok(RequestData::new(url, body))
}
/// Sends a Gemini batch-embeddings request and collects the vectors.
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    let response = builder.send().await?;
    let code = response.status();
    // Parse the body first so failures can surface provider error details.
    let body: Value = response.json().await?;
    if !code.is_success() {
        catch_error(&body, code.as_u16())?;
    }
    let parsed: EmbeddingsResBody =
        serde_json::from_value(body).context("Invalid embeddings data")?;
    let output = parsed
        .embeddings
        .into_iter()
        .map(|embedding| embedding.values)
        .collect();
    Ok(output)
}
/// Response shape of `batchEmbedContents` (only the fields we read).
#[derive(Deserialize)]
struct EmbeddingsResBody {
    embeddings: Vec<EmbeddingsResBodyEmbedding>,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
    values: Vec<f32>,
}
+245
View File
@@ -0,0 +1,245 @@
/// Registers every provider client in one place.
///
/// For each `(module, name, Config, Client)` tuple this generates: the
/// module import, the `ClientConfig` enum variant, the client struct with
/// its `init`/`list_models`/`name` helpers, and the module-level
/// `init_client`/listing functions shared by all providers.
#[macro_export]
macro_rules! register_client {
    (
        $(($module:ident, $name:literal, $config:ident, $client:ident),)+
    ) => {
        $(
            mod $module;
        )+
        $(
            use self::$module::$config;
        )+
        // Tagged union of all client configs, keyed by `type` in the config
        // file; unrecognized types deserialize to `Unknown`.
        #[derive(Debug, Clone, serde::Deserialize)]
        #[serde(tag = "type")]
        pub enum ClientConfig {
            $(
                #[serde(rename = $name)]
                $config($config),
            )+
            #[serde(other)]
            Unknown,
        }
        $(
            #[derive(Debug)]
            pub struct $client {
                global_config: $crate::config::GlobalConfig,
                config: $config,
                model: $crate::client::Model,
            }
            impl $client {
                pub const NAME: &'static str = $name;
                // Instantiates this client if one of the configured clients
                // matches `model`'s client name.
                pub fn init(global_config: &$crate::config::GlobalConfig, model: &$crate::client::Model) -> Option<Box<dyn Client>> {
                    let config = global_config.read().clients.iter().find_map(|client_config| {
                        if let ClientConfig::$config(c) = client_config {
                            if Self::name(c) == model.client_name() {
                                return Some(c.clone())
                            }
                        }
                        None
                    })?;
                    Some(Box::new(Self {
                        global_config: global_config.clone(),
                        config,
                        model: model.clone(),
                    }))
                }
                // Models from the local config, falling back to the built-in
                // provider catalog when none are configured.
                pub fn list_models(local_config: &$config) -> Vec<Model> {
                    let client_name = Self::name(local_config);
                    if local_config.models.is_empty() {
                        if let Some(v) = $crate::client::ALL_PROVIDER_MODELS.iter().find(|v| {
                            v.provider == $name ||
                            ($name == OpenAICompatibleClient::NAME
                                && local_config.name.as_ref().map(|name| name.starts_with(&v.provider)).unwrap_or_default())
                        }) {
                            return Model::from_config(client_name, &v.models);
                        }
                        vec![]
                    } else {
                        Model::from_config(client_name, &local_config.models)
                    }
                }
                pub fn name(local_config: &$config) -> &str {
                    local_config.name.as_deref().unwrap_or(Self::NAME)
                }
            }
        )+
        // Tries each registered client in turn until one accepts the model.
        pub fn init_client(config: &$crate::config::GlobalConfig, model: Option<$crate::client::Model>) -> anyhow::Result<Box<dyn Client>> {
            let model = model.unwrap_or_else(|| config.read().model.clone());
            None
            $(.or_else(|| $client::init(config, &model)))+
            .ok_or_else(|| {
                anyhow::anyhow!("Invalid model '{}'", model.id())
            })
        }
        pub fn list_client_types() -> Vec<&'static str> {
            let mut client_types: Vec<_> = vec![$($client::NAME,)+];
            client_types.extend($crate::client::OPENAI_COMPATIBLE_PROVIDERS.iter().map(|(name, _)| *name));
            client_types
        }
        pub async fn create_client_config(client: &str) -> anyhow::Result<(String, serde_json::Value)> {
            $(
                if client == $client::NAME && client != $crate::client::OpenAICompatibleClient::NAME {
                    return create_config(&$client::PROMPTS, $client::NAME).await
                }
            )+
            if let Some(ret) = create_openai_compatible_client_config(client).await? {
                return Ok(ret);
            }
            anyhow::bail!("Unknown client '{}'", client)
        }
        // Cached once per process from the first config seen.
        static ALL_CLIENT_NAMES: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();
        pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> {
            let names = ALL_CLIENT_NAMES.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .flat_map(|v| match v {
                        $(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+
                        ClientConfig::Unknown => vec![],
                    })
                    .collect()
            });
            names.iter().collect()
        }
        // Cached once per process from the first config seen.
        static ALL_MODELS: std::sync::OnceLock<Vec<$crate::client::Model>> = std::sync::OnceLock::new();
        pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> {
            let models = ALL_MODELS.get_or_init(|| {
                config
                    .clients
                    .iter()
                    .flat_map(|v| match v {
                        $(ClientConfig::$config(c) => $client::list_models(c),)+
                        ClientConfig::Unknown => vec![],
                    })
                    .collect()
            });
            models.iter().collect()
        }
        pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> {
            list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect()
        }
    };
}
/// Expands to the boilerplate accessor methods every `Client` impl needs,
/// delegating to the generated struct's fields.
#[macro_export]
macro_rules! client_common_fns {
    () => {
        fn global_config(&self) -> &$crate::config::GlobalConfig {
            &self.global_config
        }
        fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
            self.config.extra.as_ref()
        }
        fn patch_config(&self) -> Option<&$crate::client::RequestPatch> {
            self.config.patch.as_ref()
        }
        fn name(&self) -> &str {
            Self::name(&self.config)
        }
        fn model(&self) -> &Model {
            &self.model
        }
        fn model_mut(&mut self) -> &mut Model {
            &mut self.model
        }
    };
}
/// Implements `Client` for a provider from three (prepare, execute) handler
/// pairs: chat completions (plus streaming), embeddings, and rerank.
///
/// Each generated method prepares request data, runs it through the shared
/// patch/builder pipeline, then delegates to the execute handler.
#[macro_export]
macro_rules! impl_client_trait {
    (
        $client:ident,
        ($prepare_chat_completions:path, $chat_completions:path, $chat_completions_streaming:path),
        ($prepare_embeddings:path, $embeddings:path),
        ($prepare_rerank:path, $rerank:path),
    ) => {
        #[async_trait::async_trait]
        impl $crate::client::Client for $crate::client::$client {
            client_common_fns!();
            async fn chat_completions_inner(
                &self,
                client: &reqwest::Client,
                data: $crate::client::ChatCompletionsData,
            ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions(builder, self.model()).await
            }
            async fn chat_completions_streaming_inner(
                &self,
                client: &reqwest::Client,
                handler: &mut $crate::client::SseHandler,
                data: $crate::client::ChatCompletionsData,
            ) -> Result<()> {
                let request_data = $prepare_chat_completions(self, data)?;
                let builder = self.request_builder(client, request_data);
                $chat_completions_streaming(builder, handler, self.model()).await
            }
            async fn embeddings_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::EmbeddingsData,
            ) -> Result<$crate::client::EmbeddingsOutput> {
                let request_data = $prepare_embeddings(self, data)?;
                let builder = self.request_builder(client, request_data);
                $embeddings(builder, self.model()).await
            }
            async fn rerank_inner(
                &self,
                client: &reqwest::Client,
                data: &$crate::client::RerankData,
            ) -> Result<$crate::client::RerankOutput> {
                let request_data = $prepare_rerank(self, data)?;
                let builder = self.request_builder(client, request_data);
                $rerank(builder, self.model()).await
            }
        }
    };
}
/// Generates a getter that resolves a config field, preferring the
/// `<CLIENT_NAME>_<FIELD_NAME>` environment variable over the config value.
///
/// Errors when neither source provides a value.
#[macro_export]
macro_rules! config_get_fn {
    ($field_name:ident, $fn_name:ident) => {
        fn $fn_name(&self) -> anyhow::Result<String> {
            let env_prefix = Self::name(&self.config);
            let env_name =
                format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
            std::env::var(&env_name)
                .ok()
                .or_else(|| self.config.$field_name.clone())
                // Message fix: was the ungrammatical "Miss '<field>'".
                .ok_or_else(|| anyhow::anyhow!("Missing '{}'", stringify!($field_name)))
        }
    };
}
/// Bails with a consistent "Unsupported model" error message.
#[macro_export]
macro_rules! unsupported_model {
    ($name:expr) => {
        anyhow::bail!("Unsupported model '{}'", $name)
    };
}
+235
View File
@@ -0,0 +1,235 @@
use super::Model;
use crate::{function::ToolResult, multiline_text, utils::dimmed_text};
use serde::{Deserialize, Serialize};
/// A single chat message: who said it and what it contains.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
    pub role: MessageRole,
    pub content: MessageContent,
}
impl Default for Message {
    // An empty user text message.
    fn default() -> Self {
        Self {
            role: MessageRole::User,
            content: MessageContent::Text(String::new()),
        }
    }
}
impl Message {
    pub fn new(role: MessageRole, content: MessageContent) -> Self {
        Self { role, content }
    }
    /// Prepends `system` content to this message's content; the system part
    /// always ends up first. Tool-call content is left untouched.
    pub fn merge_system(&mut self, system: MessageContent) {
        match (&mut self.content, system) {
            // text + text -> two-part array, system part first
            (MessageContent::Text(text), MessageContent::Text(system_text)) => {
                self.content = MessageContent::Array(vec![
                    MessageContentPart::Text { text: system_text },
                    MessageContentPart::Text {
                        text: text.to_string(),
                    },
                ])
            }
            // array + text -> insert system part at the front
            (MessageContent::Array(list), MessageContent::Text(system_text)) => {
                list.insert(0, MessageContentPart::Text { text: system_text })
            }
            // text + array -> append own text to the system parts
            (MessageContent::Text(text), MessageContent::Array(mut system_list)) => {
                system_list.push(MessageContentPart::Text {
                    text: text.to_string(),
                });
                self.content = MessageContent::Array(system_list);
            }
            // array + array -> system parts first, own parts after
            (MessageContent::Array(list), MessageContent::Array(mut system_list)) => {
                system_list.append(list);
                self.content = MessageContent::Array(system_list);
            }
            // Combinations involving ToolCalls are not merged.
            _ => {}
        }
    }
}
/// Who authored a message; serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    System,
    Assistant,
    User,
    Tool,
}
#[allow(dead_code)]
impl MessageRole {
    /// True when this is the system role.
    pub fn is_system(&self) -> bool {
        *self == MessageRole::System
    }
    /// True when this is the user role.
    pub fn is_user(&self) -> bool {
        *self == MessageRole::User
    }
    /// True when this is the assistant role.
    pub fn is_assistant(&self) -> bool {
        *self == MessageRole::Assistant
    }
}
/// Message payload: plain text, multi-part (text/image) content, or
/// tool-call results.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
    Text(String),
    Array(Vec<MessageContentPart>),
    // Note: This type is primarily for convenience and does not exist in OpenAI's API.
    ToolCalls(MessageContentToolCalls),
}
impl MessageContent {
    /// Renders the content as the REPL input line that would reproduce it.
    ///
    /// Multi-part content becomes a `.file` command; tool-call content is
    /// rendered as dimmed "Call …" lines, optionally tagged with the agent
    /// name when the function belongs to the agent.
    pub fn render_input(
        &self,
        resolve_url_fn: impl Fn(&str) -> String,
        agent_info: &Option<(String, Vec<String>)>,
    ) -> String {
        match self {
            MessageContent::Text(text) => multiline_text(text),
            MessageContent::Array(list) => {
                let (mut concated_text, mut files) = (String::new(), vec![]);
                for item in list {
                    match item {
                        MessageContentPart::Text { text } => {
                            concated_text = format!("{concated_text} {text}")
                        }
                        MessageContentPart::ImageUrl { image_url } => {
                            files.push(resolve_url_fn(&image_url.url))
                        }
                    }
                }
                if !concated_text.is_empty() {
                    concated_text = format!(" -- {}", multiline_text(&concated_text))
                }
                format!(".file {}{}", files.join(" "), concated_text)
            }
            MessageContent::ToolCalls(MessageContentToolCalls {
                tool_results, text, ..
            }) => {
                let mut lines = vec![];
                if !text.is_empty() {
                    lines.push(text.clone())
                }
                for tool_result in tool_results {
                    let mut parts = vec!["Call".to_string()];
                    if let Some((agent_name, functions)) = agent_info {
                        if functions.contains(&tool_result.call.name) {
                            parts.push(agent_name.clone())
                        }
                    }
                    parts.push(tool_result.call.name.clone());
                    parts.push(tool_result.call.arguments.to_string());
                    lines.push(dimmed_text(&parts.join(" ")));
                }
                lines.join("\n")
            }
        }
    }
    /// Rewrites the primary text via `replace_fn`; for array content the
    /// first text part is rewritten (or one is created if the array is
    /// empty). Tool-call content is untouched.
    pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) {
        match self {
            MessageContent::Text(text) => *text = replace_fn(text),
            MessageContent::Array(list) => {
                if list.is_empty() {
                    list.push(MessageContentPart::Text {
                        text: replace_fn(""),
                    })
                } else if let Some(MessageContentPart::Text { text }) = list.get_mut(0) {
                    *text = replace_fn(text)
                }
            }
            MessageContent::ToolCalls(_) => {}
        }
    }
    /// Flattens the content to plain text; image parts and tool-call
    /// content contribute nothing.
    pub fn to_text(&self) -> String {
        match self {
            MessageContent::Text(text) => text.to_string(),
            MessageContent::Array(list) => {
                let mut parts = vec![];
                for item in list {
                    if let MessageContentPart::Text { text } = item {
                        parts.push(text.clone())
                    }
                }
                parts.join("\n\n")
            }
            MessageContent::ToolCalls(_) => String::new(),
        }
    }
}
/// One part of a multi-part message, tagged by `"type"` on the wire.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessageContentPart {
    /// A plain text fragment.
    Text { text: String },
    /// An image referenced by URL.
    ImageUrl { image_url: ImageUrl },
}
/// Wrapper around an image reference inside a multi-part message.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
    /// Image location string (resolved for display via `render_input`).
    pub url: String,
}
/// Convenience container bundling tool-call results with the assistant text
/// that accompanied them. Does not exist in OpenAI's API.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContentToolCalls {
    /// Results of the tool calls issued by the assistant.
    pub tool_results: Vec<ToolResult>,
    /// Assistant text emitted alongside the calls (cleared on `merge`).
    pub text: String,
    /// True when results were merged across multiple rounds (set by `merge`).
    pub sequence: bool,
}
impl MessageContentToolCalls {
    /// Wrap one batch of tool results with its accompanying assistant text.
    pub fn new(tool_results: Vec<ToolResult>, text: String) -> Self {
        Self {
            sequence: false,
            tool_results,
            text,
        }
    }

    /// Fold in results from a follow-up round of tool calls.
    ///
    /// The accompanying text is discarded and `sequence` is flagged to record
    /// that results span multiple rounds.
    pub fn merge(&mut self, tool_results: Vec<ToolResult>, _text: String) {
        self.sequence = true;
        self.text.clear();
        self.tool_results.extend(tool_results);
    }
}
/// Adjust `messages` to satisfy model-specific quirks before sending.
///
/// Applies the model's system-prompt prefix (merging into an existing system
/// message or inserting a new one), then, for models that reject system
/// messages, folds the system message into the following message instead.
pub fn patch_messages(messages: &mut Vec<Message>, model: &Model) {
    if messages.is_empty() {
        return;
    }
    if let Some(prefix) = model.system_prompt_prefix() {
        if !messages[0].role.is_system() {
            let prefix_message = Message {
                role: MessageRole::System,
                content: MessageContent::Text(prefix.to_string()),
            };
            messages.insert(0, prefix_message);
        } else {
            messages[0].merge_system(MessageContent::Text(prefix.to_string()));
        }
    }
    if model.no_system_message() && messages[0].role.is_system() {
        // Remove the system message and fold its content into the next one.
        let Message {
            content: system, ..
        } = messages.remove(0);
        if let Some(first) = messages.get_mut(0) {
            first.merge_system(system);
        }
    }
}
/// Remove a leading system message, returning its text content.
///
/// Returns `None` when `messages` is empty or does not start with a system
/// message. (Previously this indexed `messages[0]` unchecked and panicked on
/// an empty vector; `patch_messages` guards the same case explicitly.)
pub fn extract_system_message(messages: &mut Vec<Message>) -> Option<String> {
    if messages.first().map_or(false, |m| m.role.is_system()) {
        let system_message = messages.remove(0);
        return Some(system_message.content.to_text());
    }
    None
}
+62
View File
@@ -0,0 +1,62 @@
mod access_token;
mod common;
mod message;
#[macro_use]
mod macros;
mod model;
mod stream;
pub use crate::function::ToolCall;
pub use common::*;
pub use message::*;
pub use model::*;
pub use stream::*;
// Registers every built-in provider client; each entry is
// (module, provider name, config type, client type).
register_client!(
    (openai, "openai", OpenAIConfig, OpenAIClient),
    (
        openai_compatible,
        "openai-compatible",
        OpenAICompatibleConfig,
        OpenAICompatibleClient
    ),
    (gemini, "gemini", GeminiConfig, GeminiClient),
    (claude, "claude", ClaudeConfig, ClaudeClient),
    (cohere, "cohere", CohereConfig, CohereClient),
    (
        azure_openai,
        "azure-openai",
        AzureOpenAIConfig,
        AzureOpenAIClient
    ),
    (vertexai, "vertexai", VertexAIConfig, VertexAIClient),
    (bedrock, "bedrock", BedrockConfig, BedrockClient),
);
/// Built-in OpenAI-compatible providers as (provider name, API base URL)
/// pairs, used as a fallback when no `api_base` is configured.
pub const OPENAI_COMPATIBLE_PROVIDERS: [(&str, &str); 18] = [
    ("ai21", "https://api.ai21.com/studio/v1"),
    (
        // `{ACCOUNT_ID}` is a placeholder the user must substitute.
        "cloudflare",
        "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/v1",
    ),
    ("deepinfra", "https://api.deepinfra.com/v1/openai"),
    ("deepseek", "https://api.deepseek.com"),
    ("ernie", "https://qianfan.baidubce.com/v2"),
    ("github", "https://models.inference.ai.azure.com"),
    ("groq", "https://api.groq.com/openai/v1"),
    ("hunyuan", "https://api.hunyuan.cloud.tencent.com/v1"),
    ("minimax", "https://api.minimax.chat/v1"),
    ("mistral", "https://api.mistral.ai/v1"),
    ("moonshot", "https://api.moonshot.cn/v1"),
    ("openrouter", "https://openrouter.ai/api/v1"),
    ("perplexity", "https://api.perplexity.ai"),
    (
        "qianwen",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
    ),
    ("xai", "https://api.x.ai/v1"),
    ("zhipuai", "https://open.bigmodel.cn/api/paas/v4"),
    // RAG-dedicated
    ("jina", "https://api.jina.ai/v1"),
    ("voyageai", "https://api.voyageai.com/v1"),
];
+407
View File
@@ -0,0 +1,407 @@
use super::{
list_all_models, list_client_names,
message::{Message, MessageContent, MessageContentPart},
ApiPatch, MessageContentToolCalls, RequestPatch,
};
use crate::config::Config;
use crate::utils::{estimate_token_length, strip_think_tag};
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Display;
/// Fixed per-message token overhead used when estimating prompt size.
const PER_MESSAGES_TOKENS: usize = 5;
/// Base token allowance added on top of the estimated message tokens.
const BASIS_TOKENS: usize = 2;
/// A concrete model bound to the client that serves it.
#[derive(Debug, Clone)]
pub struct Model {
    /// Name of the owning client (first half of the `client:model` id).
    client_name: String,
    /// Static metadata for this model.
    data: ModelData,
}
impl Default for Model {
fn default() -> Self {
Model::new("", "")
}
}
impl Model {
    /// Create a model named `name` under client `client_name` with default data.
    pub fn new(client_name: &str, name: &str) -> Self {
        Self {
            client_name: client_name.into(),
            data: ModelData::new(name),
        }
    }
    /// Build `Model` values for every configured `ModelData` entry of one client.
    pub fn from_config(client_name: &str, models: &[ModelData]) -> Vec<Self> {
        models
            .iter()
            .map(|v| Model {
                client_name: client_name.to_string(),
                data: v.clone(),
            })
            .collect()
    }
    /// Resolve `model_id` ("client" or "client:model") against known models.
    ///
    /// Resolution order:
    /// 1. exact id match (must also match `model_type`, otherwise an error);
    /// 2. with a model part: a freshly synthesized model, when the client
    ///    exists and the type permits creation from a bare name;
    /// 3. without a model part: the client's first model of that type.
    ///
    /// # Errors
    /// Fails when nothing matches, or when an exact match has the wrong type.
    pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result<Self> {
        let models = list_all_models(config);
        // An empty model part ("client:") behaves like a bare client id.
        let (client_name, model_name) = match model_id.split_once(':') {
            Some((client_name, model_name)) => {
                if model_name.is_empty() {
                    (client_name, None)
                } else {
                    (client_name, Some(model_name))
                }
            }
            None => (model_id, None),
        };
        match model_name {
            Some(model_name) => {
                if let Some(model) = models.iter().find(|v| v.id() == model_id) {
                    if model.model_type() == model_type {
                        return Ok((*model).clone());
                    } else {
                        bail!("Model '{model_id}' is not a {model_type} model")
                    }
                }
                // Unknown model under a known client: synthesize one on the fly
                // when the model type allows it.
                if list_client_names(config)
                    .into_iter()
                    .any(|v| *v == client_name)
                    && model_type.can_create_from_name()
                {
                    let mut new_model = Self::new(client_name, model_name);
                    new_model.data.model_type = model_type.to_string();
                    return Ok(new_model);
                }
            }
            None => {
                if let Some(found) = models
                    .iter()
                    .find(|v| v.client_name == client_name && v.model_type() == model_type)
                {
                    return Ok((*found).clone());
                }
            }
        };
        bail!("Unknown {model_type} model '{model_id}'")
    }
    /// Stable identifier: `client` when the model name is empty, else `client:model`.
    pub fn id(&self) -> String {
        if self.data.name.is_empty() {
            self.client_name.to_string()
        } else {
            format!("{}:{}", self.client_name, self.data.name)
        }
    }
    /// Name of the owning client.
    pub fn client_name(&self) -> &str {
        &self.client_name
    }
    /// Local model name (may differ from the provider-side name).
    pub fn name(&self) -> &str {
        &self.data.name
    }
    /// Provider-side model name, falling back to the local name.
    pub fn real_name(&self) -> &str {
        self.data.real_name.as_deref().unwrap_or(&self.data.name)
    }
    /// Classify the model from its configured `type` string; anything not
    /// starting with "embed"/"rerank" is treated as a chat model.
    pub fn model_type(&self) -> ModelType {
        if self.data.model_type.starts_with("embed") {
            ModelType::Embedding
        } else if self.data.model_type.starts_with("rerank") {
            ModelType::Reranker
        } else {
            ModelType::Chat
        }
    }
    /// Shared read access to the underlying metadata.
    pub fn data(&self) -> &ModelData {
        &self.data
    }
    /// Mutable access to the underlying metadata.
    pub fn data_mut(&mut self) -> &mut ModelData {
        &mut self.data
    }
    /// One-line human-readable summary of limits, prices, and capabilities
    /// ("-" marks an unknown value; 👁 = vision, ⚒ = function calling).
    pub fn description(&self) -> String {
        match self.model_type() {
            ModelType::Chat => {
                let ModelData {
                    max_input_tokens,
                    max_output_tokens,
                    input_price,
                    output_price,
                    supports_vision,
                    supports_function_calling,
                    ..
                } = &self.data;
                let max_input_tokens = stringify_option_value(max_input_tokens);
                let max_output_tokens = stringify_option_value(max_output_tokens);
                let input_price = stringify_option_value(input_price);
                let output_price = stringify_option_value(output_price);
                let mut capabilities = vec![];
                if *supports_vision {
                    capabilities.push('👁');
                };
                if *supports_function_calling {
                    capabilities.push('⚒');
                };
                let capabilities: String = capabilities
                    .into_iter()
                    .map(|v| format!("{v} "))
                    .collect::<Vec<String>>()
                    .join("");
                format!(
                    "{max_input_tokens:>8} / {max_output_tokens:>8} | {input_price:>6} / {output_price:>6} {capabilities:>6}"
                )
            }
            ModelType::Embedding => {
                let ModelData {
                    input_price,
                    max_tokens_per_chunk,
                    max_batch_size,
                    ..
                } = &self.data;
                let max_tokens = stringify_option_value(max_tokens_per_chunk);
                let max_batch = stringify_option_value(max_batch_size);
                let price = stringify_option_value(input_price);
                format!("max-tokens:{max_tokens};max-batch:{max_batch};price:{price}")
            }
            ModelType::Reranker => String::new(),
        }
    }
    /// Model-specific request patch, if configured.
    pub fn patch(&self) -> Option<&Value> {
        self.data.patch.as_ref()
    }
    /// Maximum number of input tokens accepted, if known.
    pub fn max_input_tokens(&self) -> Option<usize> {
        self.data.max_input_tokens
    }
    /// Maximum number of output tokens, if known.
    pub fn max_output_tokens(&self) -> Option<isize> {
        self.data.max_output_tokens
    }
    /// Whether streaming responses are disabled for this model.
    pub fn no_stream(&self) -> bool {
        self.data.no_stream
    }
    /// Whether this model rejects system messages (see `patch_messages`).
    pub fn no_system_message(&self) -> bool {
        self.data.no_system_message
    }
    /// Prefix prepended to the system prompt, if configured.
    pub fn system_prompt_prefix(&self) -> Option<&str> {
        self.data.system_prompt_prefix.as_deref()
    }
    /// Embedding-only: maximum tokens per chunk, if known.
    pub fn max_tokens_per_chunk(&self) -> Option<usize> {
        self.data.max_tokens_per_chunk
    }
    /// Embedding-only: chunk size to use when none is configured (1000).
    pub fn default_chunk_size(&self) -> usize {
        self.data.default_chunk_size.unwrap_or(1000)
    }
    /// Embedding-only: maximum batch size, if known.
    pub fn max_batch_size(&self) -> Option<usize> {
        self.data.max_batch_size
    }
    /// The value to send as the request's max-tokens parameter, or `None`
    /// when the model does not require it to be set explicitly.
    pub fn max_tokens_param(&self) -> Option<isize> {
        if self.data.require_max_tokens {
            self.data.max_output_tokens
        } else {
            None
        }
    }
    /// Override the max-output-tokens setting; `None` or `Some(0)` clears it.
    pub fn set_max_tokens(
        &mut self,
        max_output_tokens: Option<isize>,
        require_max_tokens: bool,
    ) -> &mut Self {
        match max_output_tokens {
            None | Some(0) => self.data.max_output_tokens = None,
            _ => self.data.max_output_tokens = max_output_tokens,
        }
        self.data.require_max_tokens = require_max_tokens;
        self
    }
    /// Estimate the token count of `messages` (content only, no per-message
    /// overhead). Think tags are stripped from assistant messages except the
    /// final one; image parts count as zero.
    pub fn messages_tokens(&self, messages: &[Message]) -> usize {
        let messages_len = messages.len();
        messages
            .iter()
            .enumerate()
            .map(|(i, v)| match &v.content {
                MessageContent::Text(text) => {
                    if v.role.is_assistant() && i != messages_len - 1 {
                        estimate_token_length(&strip_think_tag(text))
                    } else {
                        estimate_token_length(text)
                    }
                }
                MessageContent::Array(list) => list
                    .iter()
                    .map(|v| match v {
                        MessageContentPart::Text { text } => estimate_token_length(text),
                        MessageContentPart::ImageUrl { .. } => 0,
                    })
                    .sum(),
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results, text, ..
                }) => {
                    // Tool results are measured via their JSON serialization.
                    estimate_token_length(text)
                        + tool_results
                            .iter()
                            .map(|v| {
                                serde_json::to_string(v)
                                    .map(|v| estimate_token_length(&v))
                                    .unwrap_or_default()
                            })
                            .sum::<usize>()
                }
            })
            .sum()
    }
    /// Estimate total prompt tokens: content tokens plus a fixed per-message
    /// overhead (the trailing message only counts when it is a user message).
    pub fn total_tokens(&self, messages: &[Message]) -> usize {
        if messages.is_empty() {
            return 0;
        }
        let num_messages = messages.len();
        let message_tokens = self.messages_tokens(messages);
        if messages[num_messages - 1].role.is_user() {
            num_messages * PER_MESSAGES_TOKENS + message_tokens
        } else {
            (num_messages - 1) * PER_MESSAGES_TOKENS + message_tokens
        }
    }
    /// Fail when the estimated prompt size reaches the model's input limit.
    ///
    /// # Errors
    /// Returns an error when `max_input_tokens` is known and exceeded.
    pub fn guard_max_input_tokens(&self, messages: &[Message]) -> Result<()> {
        let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
        if let Some(max_input_tokens) = self.data.max_input_tokens {
            if total_tokens >= max_input_tokens {
                bail!("Exceed max_input_tokens limit")
            }
        }
        Ok(())
    }
}
/// Static metadata describing a single model, as declared in configuration.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelData {
    /// Model name as used in ids and API requests.
    pub name: String,
    /// Model kind string; "embed…"/"rerank…" prefixes select the non-chat
    /// types (see `Model::model_type`), everything else means chat.
    #[serde(default = "default_model_type", rename = "type")]
    pub model_type: String,
    /// Provider-side name when it differs from `name`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub real_name: Option<String>,
    /// Maximum accepted input tokens, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_input_tokens: Option<usize>,
    /// Input price (used for display in `Model::description`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_price: Option<f64>,
    /// Output price (used for display in `Model::description`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_price: Option<f64>,
    /// Model-specific request patch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub patch: Option<Value>,
    // chat-only properties
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<isize>,
    /// Whether the max-tokens parameter must be sent explicitly.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub require_max_tokens: bool,
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub supports_vision: bool,
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub supports_function_calling: bool,
    // Streaming responses disabled for this model.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    no_stream: bool,
    // Model rejects system messages (handled in `patch_messages`).
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    no_system_message: bool,
    // Prefix prepended to the system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_prompt_prefix: Option<String>,
    // embedding-only properties
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens_per_chunk: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_chunk_size: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_batch_size: Option<usize>,
}
impl ModelData {
    /// Model data with the given name, the default ("chat") type, and all
    /// other fields unset.
    pub fn new(name: &str) -> Self {
        Self {
            model_type: default_model_type(),
            name: name.to_string(),
            ..Self::default()
        }
    }
}
/// The models offered by one provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderModels {
    /// Provider name.
    pub provider: String,
    /// Metadata for each of the provider's models.
    pub models: Vec<ModelData>,
}
/// Fallback `type` value for models that do not declare one.
fn default_model_type() -> String {
    String::from("chat")
}
/// High-level capability category of a model.
///
/// `Eq` is derived alongside `PartialEq` (the enum is fieldless, so equality
/// is total); see clippy's `derive_partial_eq_without_eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    /// Chat-completions model.
    Chat,
    /// Text-embedding model.
    Embedding,
    /// Document-reranking model.
    Reranker,
}
impl Display for ModelType {
    /// Lowercase name used in ids, config, and error messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ModelType::Chat => "chat",
            ModelType::Embedding => "embedding",
            ModelType::Reranker => "reranker",
        };
        f.write_str(name)
    }
}
impl ModelType {
    /// Whether a model of this type may be synthesized from a bare name
    /// (see `Model::retrieve_model`). Embedding models may not.
    pub fn can_create_from_name(self) -> bool {
        match self {
            ModelType::Embedding => false,
            ModelType::Chat | ModelType::Reranker => true,
        }
    }

    /// The API family name used to select request patches and endpoints.
    pub fn api_name(self) -> &'static str {
        match self {
            Self::Chat => "chat_completions",
            Self::Embedding => "embeddings",
            Self::Reranker => "rerank",
        }
    }

    /// The patch slot in `patch` corresponding to this model type.
    pub fn extract_patch(self, patch: &RequestPatch) -> Option<&ApiPatch> {
        let slot = match self {
            Self::Chat => &patch.chat_completions,
            Self::Embedding => &patch.embeddings,
            Self::Reranker => &patch.rerank,
        };
        slot.as_ref()
    }
}
/// Render an optional value for display, substituting "-" when absent.
fn stringify_option_value<T>(value: &Option<T>) -> String
where
    T: Display,
{
    value
        .as_ref()
        .map_or_else(|| "-".to_string(), |v| v.to_string())
}
+408
View File
@@ -0,0 +1,408 @@
use super::*;
use crate::utils::strip_think_tag;
use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
/// Default OpenAI endpoint, used when no `api_base` is configured.
const API_BASE: &str = "https://api.openai.com/v1";
/// User configuration for the native OpenAI client.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct OpenAIConfig {
    /// Optional display name for this client instance.
    pub name: Option<String>,
    /// API key (also obtainable via the `PROMPTS` interactive prompt).
    pub api_key: Option<String>,
    /// Custom API base URL; `API_BASE` is used when unset.
    pub api_base: Option<String>,
    /// Sent as the `OpenAI-Organization` header when present.
    pub organization_id: Option<String>,
    /// Models declared for this client.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Request patches applied per API kind.
    pub patch: Option<RequestPatch>,
    /// Additional client options.
    pub extra: Option<ExtraConfig>,
}
impl OpenAIClient {
    // Generated getters, used below as `get_api_key` / `get_api_base`.
    config_get_fn!(api_key, get_api_key);
    config_get_fn!(api_base, get_api_base);
    /// Interactive prompts for settings this client requires.
    pub const PROMPTS: [PromptAction<'static>; 1] = [("api_key", "API Key", None)];
}
// Wires the client trait for OpenAIClient: chat completions (with a
// streaming variant), embeddings, and a no-op rerank.
impl_client_trait!(
    OpenAIClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (noop_prepare_rerank, noop_rerank),
);
/// Build the request for OpenAI's `/chat/completions` endpoint.
///
/// Falls back to the public `API_BASE` when no base is configured and trims
/// a trailing slash so configured bases ending in `/` stay well-formed.
fn prepare_chat_completions(
    self_: &OpenAIClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());
    let url = format!("{}/chat/completions", base.trim_end_matches('/'));
    let body = openai_build_chat_completions_body(data, &self_.model);
    let mut req = RequestData::new(url, body);
    req.bearer_auth(api_key);
    if let Some(org) = &self_.config.organization_id {
        req.header("OpenAI-Organization", org);
    }
    Ok(req)
}
/// Build the request for OpenAI's `/embeddings` endpoint.
///
/// Mirrors `prepare_chat_completions`: trims a trailing slash from the
/// configured base so an `api_base` ending in `/` no longer produces a
/// malformed `…//embeddings` URL (previously this path skipped the trim).
fn prepare_embeddings(self_: &OpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;
    let api_base = self_
        .get_api_base()
        .unwrap_or_else(|_| API_BASE.to_string());
    let url = format!("{}/embeddings", api_base.trim_end_matches('/'));
    let body = openai_build_embeddings_body(data, &self_.model);
    let mut request_data = RequestData::new(url, body);
    request_data.bearer_auth(api_key);
    if let Some(organization_id) = &self_.config.organization_id {
        request_data.header("OpenAI-Organization", organization_id);
    }
    Ok(request_data)
}
/// Send a non-streaming chat-completions request and extract its output.
///
/// # Errors
/// Propagates transport failures, provider error payloads, and malformed
/// response bodies.
pub async fn openai_chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<ChatCompletionsOutput> {
    let response = builder.send().await?;
    let status = response.status();
    let data: Value = response.json().await?;
    // Surface the provider's structured error on non-2xx statuses.
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    openai_extract_chat_completions(&data)
}
/// Consume an OpenAI SSE chat-completions stream, forwarding text deltas and
/// assembling incremental tool-call fragments into complete `ToolCall`s.
///
/// `reasoning_state` tracks whether a `<think>` block is currently open (1)
/// so reasoning deltas are wrapped in think tags exactly once. Tool-call
/// fragments are keyed by `id/index`; when the key changes, the accumulated
/// call is flushed before the next one starts. (Also fixes the error message
/// grammar here — "has non-JSON arguments" — to match the `[DONE]` branch.)
pub async fn openai_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    let mut call_id = String::new();
    let mut function_name = String::new();
    let mut function_arguments = String::new();
    let mut function_id = String::new();
    let mut reasoning_state = 0;
    let handle = |message: SseMessage| -> Result<bool> {
        if message.data == "[DONE]" {
            // Flush the final pending tool call before ending the stream.
            if !function_name.is_empty() {
                if function_arguments.is_empty() {
                    function_arguments = String::from("{}");
                }
                let arguments: Value = function_arguments.parse().with_context(|| {
                    format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
                })?;
                handler.tool_call(ToolCall::new(
                    function_name.clone(),
                    arguments,
                    normalize_function_id(&function_id),
                ))?;
            }
            return Ok(true);
        }
        let data: Value = serde_json::from_str(&message.data)?;
        debug!("stream-data: {data}");
        if let Some(text) = data["choices"][0]["delta"]["content"]
            .as_str()
            .filter(|v| !v.is_empty())
        {
            // Close an open <think> block before emitting normal content.
            if reasoning_state == 1 {
                handler.text("\n</think>\n\n")?;
                reasoning_state = 0;
            }
            handler.text(text)?;
        } else if let Some(text) = data["choices"][0]["delta"]["reasoning_content"]
            .as_str()
            .or_else(|| data["choices"][0]["delta"]["reasoning"].as_str())
            .filter(|v| !v.is_empty())
        {
            // Open a <think> block on the first reasoning delta.
            if reasoning_state == 0 {
                handler.text("<think>\n")?;
                reasoning_state = 1;
            }
            handler.text(text)?;
        }
        if let (Some(function), index, id) = (
            data["choices"][0]["delta"]["tool_calls"][0]["function"].as_object(),
            data["choices"][0]["delta"]["tool_calls"][0]["index"].as_u64(),
            data["choices"][0]["delta"]["tool_calls"][0]["id"]
                .as_str()
                .filter(|v| !v.is_empty()),
        ) {
            if reasoning_state == 1 {
                handler.text("\n</think>\n\n")?;
                reasoning_state = 0;
            }
            // A changed id/index key marks the start of a new tool call;
            // flush the previously accumulated call first.
            let maybe_call_id = format!("{}/{}", id.unwrap_or_default(), index.unwrap_or_default());
            if maybe_call_id != call_id && maybe_call_id.len() >= call_id.len() {
                if !function_name.is_empty() {
                    if function_arguments.is_empty() {
                        function_arguments = String::from("{}");
                    }
                    let arguments: Value = function_arguments.parse().with_context(|| {
                        format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
                    })?;
                    handler.tool_call(ToolCall::new(
                        function_name.clone(),
                        arguments,
                        normalize_function_id(&function_id),
                    ))?;
                }
                function_name.clear();
                function_arguments.clear();
                function_id.clear();
                call_id = maybe_call_id;
            }
            // Some providers resend the full name, others stream fragments.
            if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
                if name.starts_with(&function_name) {
                    function_name = name.to_string();
                } else {
                    function_name.push_str(name);
                }
            }
            if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) {
                function_arguments.push_str(arguments);
            }
            if let Some(id) = id {
                function_id = id.to_string();
            }
        }
        Ok(false)
    };
    sse_stream(builder, handle).await
}
/// Send an embeddings request and collect the returned vectors.
///
/// # Errors
/// Propagates transport failures, provider error payloads, and bodies that
/// do not match the expected embeddings schema.
pub async fn openai_embeddings(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<EmbeddingsOutput> {
    let response = builder.send().await?;
    let status = response.status();
    let data: Value = response.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let body: EmbeddingsResBody =
        serde_json::from_value(data).context("Invalid embeddings data")?;
    let output = body.data.into_iter().map(|v| v.embedding).collect();
    Ok(output)
}
/// Response body of `/embeddings` (only the fields we read).
#[derive(Deserialize)]
struct EmbeddingsResBody {
    data: Vec<EmbeddingsResBodyEmbedding>,
}
/// One embedding vector in the response.
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
    embedding: Vec<f32>,
}
/// Build the JSON body for an OpenAI-style chat-completions request.
///
/// Tool-call messages are expanded into assistant `tool_calls` entries plus
/// one `tool` result message per call — batched together, or one
/// assistant/tool pair per call when `sequence` is set. Think tags are
/// stripped from every assistant message except the last.
pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
    let ChatCompletionsData {
        messages,
        temperature,
        top_p,
        functions,
        stream,
    } = data;
    let messages_len = messages.len();
    let messages: Vec<Value> = messages
        .into_iter()
        .enumerate()
        .flat_map(|(i, message)| {
            let Message { role, content } = message;
            match content {
                MessageContent::ToolCalls(MessageContentToolCalls {
                    tool_results,
                    text: _,
                    sequence,
                }) => {
                    if !sequence {
                        // Batched form: one assistant message listing every
                        // call, followed by a `tool` message per result.
                        let tool_calls: Vec<_> = tool_results
                            .iter()
                            .map(|tool_result| {
                                json!({
                                    "id": tool_result.call.id,
                                    "type": "function",
                                    "function": {
                                        "name": tool_result.call.name,
                                        "arguments": tool_result.call.arguments.to_string(),
                                    },
                                })
                            })
                            .collect();
                        let mut messages = vec![
                            json!({ "role": MessageRole::Assistant, "tool_calls": tool_calls }),
                        ];
                        for tool_result in tool_results {
                            messages.push(json!({
                                "role": "tool",
                                "content": tool_result.output.to_string(),
                                "tool_call_id": tool_result.call.id,
                            }));
                        }
                        messages
                    } else {
                        // Sequential form: an assistant/tool message pair per call.
                        tool_results.into_iter().flat_map(|tool_result| {
                            vec![
                                json!({
                                    "role": MessageRole::Assistant,
                                    "tool_calls": [
                                        {
                                            "id": tool_result.call.id,
                                            "type": "function",
                                            "function": {
                                                "name": tool_result.call.name,
                                                "arguments": tool_result.call.arguments.to_string(),
                                            },
                                        }
                                    ]
                                }),
                                json!({
                                    "role": "tool",
                                    "content": tool_result.output.to_string(),
                                    "tool_call_id": tool_result.call.id,
                                })
                            ]
                        }).collect()
                    }
                }
                // Strip <think> content from all but the final assistant message.
                MessageContent::Text(text) if role.is_assistant() && i != messages_len - 1 => {
                    vec![json!({ "role": role, "content": strip_think_tag(&text) }
                    )]
                }
                _ => vec![json!({ "role": role, "content": content })],
            }
        })
        .collect();
    let mut body = json!({
        "model": &model.real_name(),
        "messages": messages,
    });
    if let Some(v) = model.max_tokens_param() {
        // A `"max_tokens": null` entry in the model patch is a sentinel
        // meaning "use max_completion_tokens instead".
        if model
            .patch()
            .and_then(|v| v.get("body").and_then(|v| v.get("max_tokens")))
            == Some(&Value::Null)
        {
            body["max_completion_tokens"] = v.into();
        } else {
            body["max_tokens"] = v.into();
        }
    }
    if let Some(v) = temperature {
        body["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["top_p"] = v.into();
    }
    if stream {
        body["stream"] = true.into();
    }
    if let Some(functions) = functions {
        body["tools"] = functions
            .iter()
            .map(|v| {
                json!({
                    "type": "function",
                    "function": v,
                })
            })
            .collect();
    }
    body
}
/// Build the JSON body for an OpenAI-style `/embeddings` request.
pub fn openai_build_embeddings_body(data: &EmbeddingsData, model: &Model) -> Value {
    json!({
        "input": data.texts,
        "model": model.real_name()
    })
}
/// Extract text, optional reasoning, tool calls, and usage figures from a
/// non-streaming chat-completions response.
///
/// Reasoning (from either `reasoning_content` or `reasoning`) is surfaced as
/// a leading `<think>` block, mirroring the streaming path. (Also fixes the
/// error message grammar — "has non-JSON arguments" — for consistency.)
///
/// # Errors
/// Fails when a tool call carries non-JSON arguments, or when the response
/// contains neither text nor tool calls.
pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
    let text = data["choices"][0]["message"]["content"]
        .as_str()
        .unwrap_or_default();
    let reasoning = data["choices"][0]["message"]["reasoning_content"]
        .as_str()
        .or_else(|| data["choices"][0]["message"]["reasoning"].as_str())
        .unwrap_or_default()
        .trim();
    let mut tool_calls = vec![];
    if let Some(calls) = data["choices"][0]["message"]["tool_calls"].as_array() {
        for call in calls {
            // Calls missing a name, arguments, or id are silently skipped.
            if let (Some(name), Some(arguments), Some(id)) = (
                call["function"]["name"].as_str(),
                call["function"]["arguments"].as_str(),
                call["id"].as_str(),
            ) {
                let arguments: Value = arguments.parse().with_context(|| {
                    format!("Tool call '{name}' has non-JSON arguments '{arguments}'")
                })?;
                tool_calls.push(ToolCall::new(
                    name.to_string(),
                    arguments,
                    Some(id.to_string()),
                ));
            }
        }
    };
    if text.is_empty() && tool_calls.is_empty() {
        bail!("Invalid response data: {data}");
    }
    let text = if !reasoning.is_empty() {
        format!("<think>\n{reasoning}\n</think>\n\n{text}")
    } else {
        text.to_string()
    };
    let output = ChatCompletionsOutput {
        text,
        tool_calls,
        id: data["id"].as_str().map(|v| v.to_string()),
        input_tokens: data["usage"]["prompt_tokens"].as_u64(),
        output_tokens: data["usage"]["completion_tokens"].as_u64(),
    };
    Ok(output)
}
/// Convert a possibly-empty tool-call id into `Option<String>`.
fn normalize_function_id(value: &str) -> Option<String> {
    (!value.is_empty()).then(|| value.to_string())
}
+162
View File
@@ -0,0 +1,162 @@
use super::openai::*;
use super::*;
use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
/// User configuration for any OpenAI-compatible provider.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAICompatibleConfig {
    /// Optional display name for this client instance.
    pub name: Option<String>,
    /// API base URL; falls back to `OPENAI_COMPATIBLE_PROVIDERS` when unset.
    pub api_base: Option<String>,
    /// API key; optional, as some compatible servers need no auth.
    pub api_key: Option<String>,
    /// Models declared for this client.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Request patches applied per API kind.
    pub patch: Option<RequestPatch>,
    /// Additional client options.
    pub extra: Option<ExtraConfig>,
}
impl OpenAICompatibleClient {
    // Generated getters, used below as `get_api_base` / `get_api_key`.
    config_get_fn!(api_base, get_api_base);
    config_get_fn!(api_key, get_api_key);
    /// No interactive prompts are required for this client.
    pub const PROMPTS: [PromptAction<'static>; 0] = [];
}
// Wires the client trait for OpenAICompatibleClient: chat completions (with
// a streaming variant), embeddings, and generic reranking.
impl_client_trait!(
    OpenAICompatibleClient,
    (
        prepare_chat_completions,
        openai_chat_completions,
        openai_chat_completions_streaming
    ),
    (prepare_embeddings, openai_embeddings),
    (prepare_rerank, generic_rerank),
);
/// Build a chat-completions request for an OpenAI-compatible server.
///
/// Authentication is optional: the bearer header is only attached when an
/// API key can be resolved.
fn prepare_chat_completions(
    self_: &OpenAICompatibleClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_base = get_api_base_ext(self_)?;
    let body = openai_build_chat_completions_body(data, &self_.model);
    let mut request_data = RequestData::new(format!("{api_base}/chat/completions"), body);
    if let Ok(api_key) = self_.get_api_key() {
        request_data.bearer_auth(api_key);
    }
    Ok(request_data)
}
/// Build an embeddings request for an OpenAI-compatible server.
/// Auth is optional, as in `prepare_chat_completions`.
fn prepare_embeddings(
    self_: &OpenAICompatibleClient,
    data: &EmbeddingsData,
) -> Result<RequestData> {
    let api_base = get_api_base_ext(self_)?;
    let body = openai_build_embeddings_body(data, &self_.model);
    let mut request_data = RequestData::new(format!("{api_base}/embeddings"), body);
    if let Ok(api_key) = self_.get_api_key() {
        request_data.bearer_auth(api_key);
    }
    Ok(request_data)
}
/// Build a rerank request for an OpenAI-compatible server.
fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result<RequestData> {
    let api_base = get_api_base_ext(self_)?;
    // Baidu ERNIE exposes reranking under a non-standard path.
    let path = if self_.name().starts_with("ernie") {
        "rerankers"
    } else {
        "rerank"
    };
    let body = generic_build_rerank_body(data, &self_.model);
    let mut request_data = RequestData::new(format!("{api_base}/{path}"), body);
    if let Ok(api_key) = self_.get_api_key() {
        request_data.bearer_auth(api_key);
    }
    Ok(request_data)
}
/// Resolve the API base URL for an OpenAI-compatible client.
///
/// Falls back to the builtin `OPENAI_COMPATIBLE_PROVIDERS` table keyed by
/// client name; when that also misses, the original lookup error is
/// returned. A trailing slash is always trimmed.
fn get_api_base_ext(self_: &OpenAICompatibleClient) -> Result<String> {
    let api_base = self_.get_api_base().or_else(|err| {
        OPENAI_COMPATIBLE_PROVIDERS
            .into_iter()
            .find(|(name, _)| *name == self_.model.client_name())
            .map(|(_, api_base)| api_base.to_string())
            .ok_or(err)
    })?;
    Ok(api_base.trim_end_matches('/').to_string())
}
/// Send a rerank request and parse the results.
///
/// Some providers return the result list under `"data"` instead of
/// `"results"`; the payload is normalized before deserialization.
pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<RerankOutput> {
    let response = builder.send().await?;
    let status = response.status();
    let mut data: Value = response.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    if data.get("results").is_none() && data.get("data").is_some() {
        // Move "data" to "results" so one schema covers both shapes.
        if let Some(obj) = data.as_object_mut() {
            if let Some(value) = obj.remove("data") {
                obj.insert("results".to_string(), value);
            }
        }
    }
    let body: GenericRerankResBody =
        serde_json::from_value(data).context("Invalid rerank data")?;
    Ok(body.results)
}
/// Rerank response body; `results` may have been normalized from `data`.
#[derive(Deserialize)]
pub struct GenericRerankResBody {
    pub results: RerankOutput,
}
/// Build the JSON body for a rerank request.
pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value {
    let RerankData {
        query,
        documents,
        top_n,
    } = data;
    let mut body = json!({
        "model": model.real_name(),
        "query": query,
        "documents": documents,
    });
    // VoyageAI names the result-count parameter "top_k"; others use "top_n".
    let key = if model.client_name().starts_with("voyageai") {
        "top_k"
    } else {
        "top_n"
    };
    body[key] = (*top_n).into();
    body
}
+296
View File
@@ -0,0 +1,296 @@
use super::{catch_error, ToolCall};
use crate::utils::AbortSignal;
use anyhow::{anyhow, bail, Context, Result};
use futures_util::{Stream, StreamExt};
use reqwest::RequestBuilder;
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
use serde_json::Value;
use tokio::sync::mpsc::UnboundedSender;
/// Collects a model's streamed output while forwarding events to a consumer.
pub struct SseHandler {
    /// Channel on which text/done events are delivered.
    sender: UnboundedSender<SseEvent>,
    /// Cooperative cancellation flag shared with the consumer.
    abort_signal: AbortSignal,
    /// All text received so far.
    buffer: String,
    /// Completed tool calls parsed from the stream.
    tool_calls: Vec<ToolCall>,
}
impl SseHandler {
    /// Create a handler that forwards stream events through `sender`.
    pub fn new(sender: UnboundedSender<SseEvent>, abort_signal: AbortSignal) -> Self {
        Self {
            buffer: String::new(),
            tool_calls: Vec::new(),
            sender,
            abort_signal,
        }
    }

    /// Accumulate a text delta and forward it to the consumer.
    ///
    /// A send failure after an abort is expected (the receiver is gone) and
    /// is swallowed; otherwise it is surfaced as an error.
    pub fn text(&mut self, text: &str) -> Result<()> {
        if text.is_empty() {
            return Ok(());
        }
        self.buffer.push_str(text);
        match self
            .sender
            .send(SseEvent::Text(text.to_string()))
            .with_context(|| "Failed to send SseEvent:Text")
        {
            Err(_) if self.abort_signal.aborted() => Ok(()),
            other => other,
        }
    }

    /// Signal that the stream has finished; warn on unexpected send failure.
    pub fn done(&mut self) {
        if self.sender.send(SseEvent::Done).is_err() && !self.abort_signal.aborted() {
            warn!("Failed to send SseEvent:Done");
        }
    }

    /// Record a completed tool call parsed from the stream.
    pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
        self.tool_calls.push(call);
        Ok(())
    }

    /// A clone of the abort signal shared with the producer.
    pub fn abort(&self) -> AbortSignal {
        self.abort_signal.clone()
    }

    /// Tool calls collected so far.
    pub fn tool_calls(&self) -> &[ToolCall] {
        &self.tool_calls
    }

    /// Consume the handler, yielding buffered text and tool calls.
    pub fn take(self) -> (String, Vec<ToolCall>) {
        (self.buffer, self.tool_calls)
    }
}
/// Event forwarded to the stream consumer.
#[derive(Debug)]
pub enum SseEvent {
    /// A chunk of generated text.
    Text(String),
    /// The stream has finished.
    Done,
}
/// A raw server-sent-event message.
#[derive(Debug)]
pub struct SseMessage {
    /// SSE `event` field (currently unused by handlers).
    #[allow(unused)]
    pub event: String,
    /// SSE `data` payload.
    pub data: String,
}
/// Drive an SSE request, invoking `handle` for each message until the
/// handler returns `Ok(true)` or the stream ends.
///
/// # Errors
/// HTTP-level failures are converted into errors (preferring the provider's
/// structured error payload); `StreamEnded` is normal termination.
pub async fn sse_stream<F>(builder: RequestBuilder, mut handle: F) -> Result<()>
where
    F: FnMut(SseMessage) -> Result<bool>,
{
    let mut es = builder.eventsource()?;
    while let Some(event) = es.next().await {
        match event {
            Ok(Event::Open) => {}
            Ok(Event::Message(message)) => {
                let message = SseMessage {
                    event: message.event,
                    data: message.data,
                };
                // `true` from the handler means "done"; stop reading.
                if handle(message)? {
                    break;
                }
            }
            Err(err) => {
                match err {
                    // Normal end of stream: fall through and close.
                    EventSourceError::StreamEnded => {}
                    EventSourceError::InvalidStatusCode(status, res) => {
                        let text = res.text().await?;
                        // Prefer the provider's structured error if the body is JSON.
                        let data: Value = match text.parse() {
                            Ok(data) => data,
                            Err(_) => {
                                bail!(
                                    "Invalid response data: {text} (status: {})",
                                    status.as_u16()
                                );
                            }
                        };
                        catch_error(&data, status.as_u16())?;
                    }
                    EventSourceError::InvalidContentType(header_value, res) => {
                        let text = res.text().await?;
                        bail!(
                            "Invalid response event-stream. content-type: {}, data: {text}",
                            header_value.to_str().unwrap_or_default()
                        );
                    }
                    _ => {
                        bail!("{}", err);
                    }
                }
                // Any error (including StreamEnded) terminates the event source.
                es.close();
            }
        }
    }
    Ok(())
}
/// Incrementally parse a byte stream of concatenated / ndjson JSON objects,
/// invoking `handle` with each complete top-level object.
///
/// Bytes are buffered across chunks until they form valid UTF-8, so
/// multi-byte characters split across chunk boundaries are handled.
pub async fn json_stream<S, F, E>(mut stream: S, mut handle: F) -> Result<()>
where
    S: Stream<Item = Result<bytes::Bytes, E>> + Unpin,
    F: FnMut(&str) -> Result<()>,
    E: std::error::Error,
{
    let mut parser = JsonStreamParser::default();
    let mut unparsed_bytes = vec![];
    while let Some(chunk_bytes) = stream.next().await {
        let chunk_bytes =
            chunk_bytes.map_err(|err| anyhow!("Failed to read json stream, {err}"))?;
        unparsed_bytes.extend(chunk_bytes);
        match std::str::from_utf8(&unparsed_bytes) {
            Ok(text) => {
                parser.process(text, &mut handle)?;
                unparsed_bytes.clear();
            }
            Err(_) => {
                // Incomplete UTF-8 sequence; wait for more bytes.
                continue;
            }
        }
    }
    // Flush the remainder; still-invalid UTF-8 is now a hard error.
    if !unparsed_bytes.is_empty() {
        let text = std::str::from_utf8(&unparsed_bytes)?;
        parser.process(text, &mut handle)?;
    }
    Ok(())
}
/// Streaming scanner that locates complete top-level JSON objects in text.
#[derive(Debug, Default)]
struct JsonStreamParser {
    /// All characters seen so far.
    buffer: Vec<char>,
    /// Index of the first not-yet-scanned character in `buffer`.
    cursor: usize,
    /// Start index of the object currently being tracked, if any.
    start: Option<usize>,
    /// Stack of currently open '{' / '[' delimiters.
    balances: Vec<char>,
    /// True while inside a string literal.
    quoting: bool,
    /// True when the previous character was an unconsumed backslash escape.
    escape: bool,
}
impl JsonStreamParser {
    /// Scan newly appended `text`, calling `handle` with each complete
    /// top-level `{...}` object found. State persists across calls, so
    /// objects may span multiple `process` invocations.
    fn process<F>(&mut self, text: &str, handle: &mut F) -> Result<()>
    where
        F: FnMut(&str) -> Result<()>,
    {
        self.buffer.extend(text.chars());
        for i in self.cursor..self.buffer.len() {
            let ch = self.buffer[i];
            // Inside a string literal: only track escapes and the closing quote.
            if self.quoting {
                if ch == '\\' {
                    self.escape = !self.escape;
                } else {
                    if !self.escape && ch == '"' {
                        self.quoting = false;
                    }
                    self.escape = false;
                }
                continue;
            }
            match ch {
                '"' => {
                    self.quoting = true;
                    self.escape = false;
                }
                '{' => {
                    // First unmatched '{' marks the start of a new object.
                    if self.balances.is_empty() {
                        self.start = Some(i);
                    }
                    self.balances.push(ch);
                }
                '[' => {
                    // Arrays are only tracked inside an object; a top-level
                    // '[' (array-wrapped ndjson input) is ignored.
                    if self.start.is_some() {
                        self.balances.push(ch);
                    }
                }
                '}' => {
                    self.balances.pop();
                    // All delimiters closed: the tracked object is complete.
                    if self.balances.is_empty() {
                        if let Some(start) = self.start.take() {
                            let value: String = self.buffer[start..=i].iter().collect();
                            handle(&value)?;
                        }
                    }
                }
                ']' => {
                    self.balances.pop();
                }
                _ => {}
            }
        }
        self.cursor = self.buffer.len();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::stream;
    use rand::Rng;
    // Randomly split `text` into three byte chunks to exercise arbitrary
    // chunk boundaries (including mid-token splits).
    fn split_chunks(text: &str) -> Vec<Vec<u8>> {
        let mut rng = rand::rng();
        let len = text.len();
        let cut1 = rng.random_range(1..len - 1);
        let cut2 = rng.random_range(cut1 + 1..len);
        let chunk1 = text.as_bytes()[..cut1].to_vec();
        let chunk2 = text.as_bytes()[cut1..cut2].to_vec();
        let chunk3 = text.as_bytes()[cut2..].to_vec();
        vec![chunk1, chunk2, chunk3]
    }
    // Feed `$input` through `json_stream` in random chunks and assert the
    // emitted objects equal `$output` (one object per line).
    macro_rules! assert_json_stream {
        ($input:expr, $output:expr) => {
            let chunks: Vec<_> = split_chunks($input)
                .into_iter()
                .map(|chunk| Ok::<_, std::convert::Infallible>(Bytes::from(chunk)))
                .collect();
            let stream = stream::iter(chunks);
            let mut output = vec![];
            let ret = json_stream(stream, |data| {
                output.push(data.to_string());
                Ok(())
            })
            .await;
            assert!(ret.is_ok());
            assert_eq!($output.replace("\r\n", "\n"), output.join("\n"))
        };
    }
    #[tokio::test]
    async fn test_json_stream_ndjson() {
        let data = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
        assert_json_stream!(data, data);
    }
    #[tokio::test]
    async fn test_json_stream_array() {
        // A top-level '[' wrapper is ignored; objects are still extracted.
        let input = r#"[
{"key": "value"},
{"key": "value2"},
{"key": "value3"},"#;
        let output = r#"{"key": "value"}
{"key": "value2"}
{"key": "value3"}"#;
        assert_json_stream!(input, output);
    }
}
+537
View File
@@ -0,0 +1,537 @@
use super::access_token::*;
use super::claude::*;
use super::openai::*;
use super::*;
use anyhow::{anyhow, bail, Context, Result};
use chrono::{Duration, Utc};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{path::PathBuf, str::FromStr};
/// User-facing configuration for the Vertex AI client, deserialized from the
/// application config.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct VertexAIConfig {
    /// Optional client name override.
    pub name: Option<String>,
    /// GCP project id; prompted for when absent (see `PROMPTS`).
    pub project_id: Option<String>,
    /// GCP location/region (e.g. `us-central1`, or the special `global`).
    pub location: Option<String>,
    /// Path to an `application_default_credentials.json`; falls back to the
    /// gcloud default location when unset.
    pub adc_file: Option<String>,
    /// Models exposed by this client.
    #[serde(default)]
    pub models: Vec<ModelData>,
    /// Optional request patching rules.
    pub patch: Option<RequestPatch>,
    /// Extra client settings shared across providers.
    pub extra: Option<ExtraConfig>,
}
impl VertexAIClient {
    // Macro-generated accessors that resolve these values from config
    // (falling back per the macro's lookup rules).
    config_get_fn!(project_id, get_project_id);
    config_get_fn!(location, get_location);
    /// Values to interactively prompt for when not configured.
    pub const PROMPTS: [PromptAction<'static>; 2] = [
        ("project_id", "Project ID", None),
        ("location", "Location", None),
    ];
}
#[async_trait::async_trait]
impl Client for VertexAIClient {
    // Shared client boilerplate (config/model plumbing) from the project macro.
    client_common_fns!();
    /// Non-streaming chat completion: ensure a fresh gcloud access token, then
    /// dispatch to the publisher-specific request/response handling based on
    /// the model-name prefix.
    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => gemini_chat_completions(builder, model).await,
            ModelCategory::Claude => claude_chat_completions(builder, model).await,
            ModelCategory::Mistral => openai_chat_completions(builder, model).await,
        }
    }
    /// Streaming chat completion; same token refresh and dispatch as above,
    /// but events are forwarded to `handler`.
    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let model = self.model();
        let model_category = ModelCategory::from_str(model.real_name())?;
        let request_data = prepare_chat_completions(self, data, &model_category)?;
        let builder = self.request_builder(client, request_data);
        match model_category {
            ModelCategory::Gemini => {
                gemini_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Claude => {
                claude_chat_completions_streaming(builder, handler, model).await
            }
            ModelCategory::Mistral => {
                openai_chat_completions_streaming(builder, handler, model).await
            }
        }
    }
    /// Embeddings via the Vertex AI `:predict` endpoint (Google publisher only).
    async fn embeddings_inner(
        &self,
        client: &ReqwestClient,
        data: &EmbeddingsData,
    ) -> Result<Vec<Vec<f32>>> {
        prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
        let request_data = prepare_embeddings(self, data)?;
        let builder = self.request_builder(client, request_data);
        embeddings(builder, self.model()).await
    }
}
/// Build the URL and JSON body for a Vertex AI chat-completions request,
/// specialized per publisher (Gemini / Anthropic / Mistral).
fn prepare_chat_completions(
    self_: &VertexAIClient,
    data: ChatCompletionsData,
    model_category: &ModelCategory,
) -> Result<RequestData> {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    let access_token = get_access_token(self_.name())?;
    // The "global" location uses a host without the region prefix.
    let base_url = if location == "global" {
        format!("https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers")
    } else {
        format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers")
    };
    let model_name = self_.model.real_name();
    let url = match model_category {
        ModelCategory::Gemini => {
            let func = match data.stream {
                true => "streamGenerateContent",
                false => "generateContent",
            };
            format!("{base_url}/google/models/{model_name}:{func}")
        }
        ModelCategory::Claude => {
            // NOTE: the Anthropic endpoint here is `streamRawPredict` for both
            // streaming and non-streaming requests.
            format!("{base_url}/anthropic/models/{model_name}:streamRawPredict")
        }
        ModelCategory::Mistral => {
            let func = match data.stream {
                true => "streamRawPredict",
                false => "rawPredict",
            };
            format!("{base_url}/mistralai/models/{model_name}:{func}")
        }
    };
    let body = match model_category {
        ModelCategory::Gemini => gemini_build_chat_completions_body(data, &self_.model)?,
        ModelCategory::Claude => {
            let mut body = claude_build_chat_completions_body(data, &self_.model)?;
            // On Vertex the model is addressed via the URL, not the body,
            // and an `anthropic_version` marker is added instead.
            if let Some(body_obj) = body.as_object_mut() {
                body_obj.remove("model");
            }
            body["anthropic_version"] = "vertex-2023-10-16".into();
            body
        }
        ModelCategory::Mistral => {
            let mut body = openai_build_chat_completions_body(data, &self_.model);
            // Mistral expects the model name without the `@version` suffix.
            if let Some(body_obj) = body.as_object_mut() {
                body_obj["model"] = strip_model_version(self_.model.real_name()).into();
            }
            body
        }
    };
    let mut request_data = RequestData::new(url, body);
    request_data.bearer_auth(access_token);
    Ok(request_data)
}
/// Build the request for the Vertex AI `:predict` embeddings endpoint.
///
/// Resolves project/location from config, attaches the cached gcloud access
/// token as a bearer credential, and wraps each input text as one entry of
/// the `instances` array.
fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<RequestData> {
    let project_id = self_.get_project_id()?;
    let location = self_.get_location()?;
    let access_token = get_access_token(self_.name())?;
    // The "global" location uses a host without the region prefix.
    let host = if location == "global" {
        "aiplatform.googleapis.com".to_string()
    } else {
        format!("{location}-aiplatform.googleapis.com")
    };
    let model_name = self_.model.real_name();
    let url = format!(
        "https://{host}/v1/projects/{project_id}/locations/{location}/publishers/google/models/{model_name}:predict"
    );
    let instances = data
        .texts
        .iter()
        .map(|text| json!({ "content": text }))
        .collect::<Vec<_>>();
    let body = json!({
        "instances": instances,
    });
    let mut request_data = RequestData::new(url, body);
    request_data.bearer_auth(access_token);
    Ok(request_data)
}
/// Send a non-streaming Gemini `generateContent` request and extract the
/// completion text / tool calls from the response JSON.
pub async fn gemini_chat_completions(
    builder: RequestBuilder,
    _model: &Model,
) -> Result<ChatCompletionsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    // Error bodies are still JSON; surface them through the shared handler.
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    debug!("non-stream-data: {data}");
    gemini_extract_chat_completions_text(&data)
}
/// Stream a Gemini `streamGenerateContent` response, forwarding text deltas
/// and function calls to `handler` as each JSON object arrives.
pub async fn gemini_chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    _model: &Model,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    if !status.is_success() {
        let data: Value = res.json().await?;
        catch_error(&data, status.as_u16())?;
    } else {
        // Called once per complete JSON object extracted from the byte stream.
        let handle = |value: &str| -> Result<()> {
            let data: Value = serde_json::from_str(value)?;
            debug!("stream-data: {data}");
            if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
                for (i, part) in parts.iter().enumerate() {
                    if let Some(text) = part["text"].as_str() {
                        // Separate multiple text parts with a blank line.
                        if i > 0 {
                            handler.text("\n\n")?;
                        }
                        handler.text(text)?;
                    } else if let (Some(name), Some(args)) = (
                        part["functionCall"]["name"].as_str(),
                        part["functionCall"]["args"].as_object(),
                    ) {
                        handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?;
                    }
                }
            } else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
                .as_str()
                .or_else(|| data["candidates"][0]["finishReason"].as_str())
            {
                // No content parts + SAFETY block/finish reason => refused.
                bail!("Blocked due to safety")
            }
            Ok(())
        };
        json_stream(res.bytes_stream(), handle).await?;
    }
    Ok(())
}
/// Execute an embeddings `:predict` request and unpack the prediction
/// vectors from the response body.
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let res_body: EmbeddingsResBody =
        serde_json::from_value(data).context("Invalid embeddings data")?;
    // One vector per input text, in request order.
    let output = res_body
        .predictions
        .into_iter()
        .map(|v| v.embeddings.values)
        .collect();
    Ok(output)
}
/// Response body of the Vertex AI `:predict` embeddings endpoint.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    predictions: Vec<EmbeddingsResBodyPrediction>,
}
/// One prediction entry wrapping the embedding payload.
#[derive(Deserialize)]
struct EmbeddingsResBodyPrediction {
    embeddings: EmbeddingsResBodyPredictionEmbeddings,
}
/// The embedding vector itself.
#[derive(Deserialize)]
struct EmbeddingsResBodyPredictionEmbeddings {
    values: Vec<f32>,
}
/// Pull the answer text, tool calls, and token usage out of a non-streaming
/// Gemini response. Errors when the response carries neither (distinguishing
/// safety blocks from malformed data).
fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
    let mut text_parts = vec![];
    let mut tool_calls = vec![];
    if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
        for part in parts {
            if let Some(text) = part["text"].as_str() {
                text_parts.push(text);
            }
            if let (Some(name), Some(args)) = (
                part["functionCall"]["name"].as_str(),
                part["functionCall"]["args"].as_object(),
            ) {
                tool_calls.push(ToolCall::new(name.to_string(), json!(args), None));
            }
        }
    }
    let text = text_parts.join("\n\n");
    if text.is_empty() && tool_calls.is_empty() {
        // Empty output: either a safety refusal or an unexpected shape.
        if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
            .as_str()
            .or_else(|| data["candidates"][0]["finishReason"].as_str())
        {
            bail!("Blocked due to safety")
        } else {
            bail!("Invalid response data: {data}");
        }
    }
    let output = ChatCompletionsOutput {
        text,
        tool_calls,
        id: None,
        input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
        output_tokens: data["usageMetadata"]["candidatesTokenCount"].as_u64(),
    };
    Ok(output)
}
/// Translate the provider-agnostic `ChatCompletionsData` into the Gemini
/// request body: `contents` turns, optional `systemInstruction`,
/// `generationConfig` sampling knobs, and `tools` declarations.
///
/// Errors if any image is referenced by network URL (only base64 data URLs
/// are supported here).
pub fn gemini_build_chat_completions_body(
    data: ChatCompletionsData,
    model: &Model,
) -> Result<Value> {
    let ChatCompletionsData {
        mut messages,
        temperature,
        top_p,
        functions,
        stream: _,
    } = data;
    // The system prompt is carried separately as `systemInstruction`.
    let system_message = extract_system_message(&mut messages);
    let mut network_image_urls = vec![];
    let contents: Vec<Value> = messages
        .into_iter()
        .flat_map(|message| {
            let Message { role, content } = message;
            // Gemini only knows "user" and "model" roles.
            let role = match role {
                MessageRole::User => "user",
                _ => "model",
            };
            match content {
                MessageContent::Text(text) => vec![json!({
                    "role": role,
                    "parts": [{ "text": text }]
                })],
                MessageContent::Array(list) => {
                    let parts: Vec<Value> = list
                        .into_iter()
                        .map(|item| match item {
                            MessageContentPart::Text { text } => json!({"text": text}),
                            MessageContentPart::ImageUrl { image_url: ImageUrl { url } } => {
                                // Only `data:<mime>;base64,<data>` URLs become inline_data;
                                // anything else is collected and rejected below.
                                if let Some((mime_type, data)) = url.strip_prefix("data:").and_then(|v| v.split_once(";base64,")) {
                                    json!({ "inline_data": { "mime_type": mime_type, "data": data } })
                                } else {
                                    network_image_urls.push(url.clone());
                                    json!({ "url": url })
                                }
                            },
                        })
                        .collect();
                    vec![json!({ "role": role, "parts": parts })]
                },
                MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. }) => {
                    // A tool exchange expands into two turns: the model's
                    // functionCall parts, then the functionResponse parts.
                    let model_parts: Vec<Value> = tool_results.iter().map(|tool_result| {
                        json!({
                            "functionCall": {
                                "name": tool_result.call.name,
                                "args": tool_result.call.arguments,
                            }
                        })
                    }).collect();
                    let function_parts: Vec<Value> = tool_results.into_iter().map(|tool_result| {
                        json!({
                            "functionResponse": {
                                "name": tool_result.call.name,
                                "response": {
                                    "name": tool_result.call.name,
                                    "content": tool_result.output,
                                }
                            }
                        })
                    }).collect();
                    vec![
                        json!({ "role": "model", "parts": model_parts }),
                        json!({ "role": "function", "parts": function_parts }),
                    ]
                }
            }
        })
        .collect();
    if !network_image_urls.is_empty() {
        bail!(
            "The model does not support network images: {:?}",
            network_image_urls
        );
    }
    let mut body = json!({ "contents": contents, "generationConfig": {} });
    if let Some(v) = system_message {
        body["systemInstruction"] = json!({ "parts": [{"text": v }] });
    }
    if let Some(v) = model.max_tokens_param() {
        body["generationConfig"]["maxOutputTokens"] = v.into();
    }
    if let Some(v) = temperature {
        body["generationConfig"]["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        body["generationConfig"]["topP"] = v.into();
    }
    if let Some(functions) = functions {
        // Gemini doesn't support functions with parameters that have empty properties, so we need to patch it.
        let function_declarations: Vec<_> = functions
            .into_iter()
            .map(|function| {
                if function.parameters.is_empty_properties() {
                    json!({
                        "name": function.name,
                        "description": function.description,
                    })
                } else {
                    json!(function)
                }
            })
            .collect();
        body["tools"] = json!([{ "functionDeclarations": function_declarations }]);
    }
    Ok(body)
}
/// Publisher family a Vertex AI model belongs to; selects the endpoint path
/// and the request/response wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelCategory {
    Gemini,
    Claude,
    Mistral,
}
impl FromStr for ModelCategory {
    type Err = anyhow::Error;

    /// Infer the publisher family from the model-name prefix
    /// (`gemini-*`, `claude-*`, `mistral-*`/`codestral-*`).
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        const PREFIXES: [(&str, ModelCategory); 4] = [
            ("gemini", ModelCategory::Gemini),
            ("claude", ModelCategory::Claude),
            ("mistral", ModelCategory::Mistral),
            ("codestral", ModelCategory::Mistral),
        ];
        match PREFIXES.iter().find(|(prefix, _)| s.starts_with(prefix)) {
            Some((_, category)) => Ok(*category),
            None => unsupported_model!(s),
        }
    }
}
/// Ensure a valid gcloud access token is cached for `client_name`, fetching a
/// fresh one via the OAuth token endpoint when missing or expired.
pub async fn prepare_gcloud_access_token(
    client: &reqwest::Client,
    client_name: &str,
    adc_file: &Option<String>,
) -> Result<()> {
    if !is_valid_access_token(client_name) {
        let (token, expires_in) = fetch_access_token(client, adc_file)
            .await
            .with_context(|| "Failed to fetch access token")?;
        // Convert the relative `expires_in` (seconds) into an absolute expiry.
        let expires_at = Utc::now()
            + Duration::try_seconds(expires_in)
                .ok_or_else(|| anyhow!("Failed to parse expires_in of access_token"))?;
        set_access_token(client_name, token, expires_at.timestamp())
    }
    Ok(())
}
/// Exchange the ADC refresh token for a short-lived OAuth access token.
/// Returns the token plus its lifetime in seconds.
async fn fetch_access_token(
    client: &reqwest::Client,
    file: &Option<String>,
) -> Result<(String, i64)> {
    let credentials = load_adc(file).await?;
    let value: Value = client
        .post("https://oauth2.googleapis.com/token")
        .json(&credentials)
        .send()
        .await?
        .json()
        .await?;
    if let (Some(access_token), Some(expires_in)) =
        (value["access_token"].as_str(), value["expires_in"].as_i64())
    {
        Ok((access_token.to_string(), expires_in))
    } else if let Some(err_msg) = value["error_description"].as_str() {
        // OAuth error responses carry a human-readable description.
        bail!("{err_msg}")
    } else {
        bail!("Invalid response data: {value}")
    }
}
async fn load_adc(file: &Option<String>) -> Result<Value> {
let adc_file = file
.as_ref()
.map(PathBuf::from)
.or_else(default_adc_file)
.ok_or_else(|| anyhow!("No application_default_credentials.json"))?;
let data = tokio::fs::read_to_string(adc_file).await?;
let data: Value = serde_json::from_str(&data)?;
if let (Some(client_id), Some(client_secret), Some(refresh_token)) = (
data["client_id"].as_str(),
data["client_secret"].as_str(),
data["refresh_token"].as_str(),
) {
Ok(json!({
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}))
} else {
bail!("Invalid application_default_credentials.json")
}
}
/// Default gcloud ADC location on non-Windows systems:
/// `~/.config/gcloud/application_default_credentials.json`.
#[cfg(not(windows))]
fn default_adc_file() -> Option<PathBuf> {
    dirs::home_dir().map(|home| {
        home.join(".config")
            .join("gcloud")
            .join("application_default_credentials.json")
    })
}
/// Default gcloud ADC location on Windows:
/// `%APPDATA%\gcloud\application_default_credentials.json`.
#[cfg(windows)]
fn default_adc_file() -> Option<PathBuf> {
    dirs::config_dir().map(|config| {
        config
            .join("gcloud")
            .join("application_default_credentials.json")
    })
}
/// Drop a trailing `@<version>` suffix from a model name
/// (e.g. `claude-3-5-sonnet@20240620` -> `claude-3-5-sonnet`).
fn strip_model_version(name: &str) -> &str {
    name.split_once('@').map_or(name, |(base, _)| base)
}
+570
View File
@@ -0,0 +1,570 @@
use super::*;
use crate::{
client::Model,
function::{run_llm_function, Functions},
};
use anyhow::{Context, Result};
use inquire::{validator::Validation, Text};
use std::{fs::read_to_string, path::Path};
use serde::{Deserialize, Serialize};
// Name under which an agent's RAG store is created/loaded (see `Agent::init`).
const DEFAULT_AGENT_NAME: &str = "rag";
// Ordered name -> value map of agent variables; insertion order is preserved.
pub type AgentVariables = IndexMap<String, String>;
/// A runnable agent: its config, resolved variables, optional RAG store,
/// functions, and the model it chats with.
#[derive(Debug, Clone)]
pub struct Agent {
    name: String,
    config: AgentConfig,
    /// Variables that apply outside any session.
    shared_variables: AgentVariables,
    /// Session-scoped variable overrides; `None` when no session is active.
    session_variables: Option<AgentVariables>,
    /// Cached output of the `_instructions` function (shared scope).
    shared_dynamic_instructions: Option<String>,
    /// Cached output of the `_instructions` function (session scope).
    session_dynamic_instructions: Option<String>,
    functions: Functions,
    rag: Option<Arc<Rag>>,
    model: Model,
}
impl Agent {
    /// Load an agent by name: read its config, wire up functions and MCP
    /// servers, resolve the model, and (optionally, with user confirmation)
    /// build or load its RAG store from the attached documents.
    pub async fn init(
        config: &GlobalConfig,
        name: &str,
        abort_signal: AbortSignal,
    ) -> Result<Self> {
        let agent_data_dir = Config::agent_data_dir(name);
        let loaders = config.read().document_loaders.clone();
        let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME);
        let config_path = Config::agent_config_file(name);
        let mut agent_config = if config_path.exists() {
            AgentConfig::load(&config_path)?
        } else {
            bail!("Agent config file not found at '{}'", config_path.display())
        };
        let mut functions = Functions::init_agent(name, &agent_config.global_tools)?;
        // Rebuild MCP state for this agent's server list.
        config.write().functions.clear_mcp_meta_functions();
        let mcp_servers =
            (!agent_config.mcp_servers.is_empty()).then(|| agent_config.mcp_servers.join(","));
        let registry = config
            .write()
            .mcp_registry
            .take()
            .expect("MCP registry should be initialized");
        let new_mcp_registry =
            McpRegistry::reinit(registry, mcp_servers, abort_signal.clone()).await?;
        if !new_mcp_registry.is_empty() {
            functions.append_mcp_meta_functions(new_mcp_registry.list_servers());
        }
        config.write().mcp_registry = Some(new_mcp_registry);
        agent_config.replace_tools_placeholder(&functions);
        agent_config.load_envs(&config.read());
        let model = {
            let config = config.read();
            match agent_config.model_id.as_ref() {
                Some(model_id) => Model::retrieve_model(&config, model_id, ModelType::Chat)?,
                None => {
                    // No explicit model: inherit the global sampling defaults
                    // only where the agent config leaves them unset.
                    if agent_config.temperature.is_none() {
                        agent_config.temperature = config.temperature;
                    }
                    if agent_config.top_p.is_none() {
                        agent_config.top_p = config.top_p;
                    }
                    config.current_model().clone()
                }
            }
        };
        let rag = if rag_path.exists() {
            Some(Arc::new(Rag::load(config, DEFAULT_AGENT_NAME, &rag_path)?))
        } else if !agent_config.documents.is_empty() && !config.read().info_flag {
            // Documents exist but no RAG store yet: offer to build one.
            let mut ans = false;
            if *IS_STDOUT_TERMINAL {
                ans = Confirm::new("The agent has documents attached, init RAG?")
                    .with_default(true)
                    .prompt()?;
            }
            if ans {
                // Normalize document references: keep URLs as-is, resolve
                // loader-protocol and relative paths against the agent data dir.
                let mut document_paths = vec![];
                for path in &agent_config.documents {
                    if is_url(path) {
                        document_paths.push(path.to_string());
                    } else if is_loader_protocol(&loaders, path) {
                        let (protocol, document_path) = path
                            .split_once(':')
                            .with_context(|| "Invalid loader protocol path")?;
                        let resolved_path = resolve_home_dir(document_path);
                        let new_path = if Path::new(&resolved_path).is_relative() {
                            safe_join_path(&agent_data_dir, resolved_path)
                                .ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?
                        } else {
                            PathBuf::from(&resolved_path)
                        };
                        document_paths.push(format!("{}:{}", protocol, new_path.display()));
                    } else if Path::new(&resolve_home_dir(path)).is_relative() {
                        let new_path = safe_join_path(&agent_data_dir, path)
                            .ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
                        document_paths.push(new_path.display().to_string())
                    } else {
                        document_paths.push(path.to_string())
                    }
                }
                // Use the shared constant for the RAG name, consistent with
                // the `Rag::load` call above.
                let rag =
                    Rag::init(config, DEFAULT_AGENT_NAME, &rag_path, &document_paths, abort_signal).await?;
                Some(Arc::new(rag))
            } else {
                None
            }
        } else {
            None
        };
        Ok(Self {
            name: name.to_string(),
            config: agent_config,
            shared_variables: Default::default(),
            session_variables: None,
            shared_dynamic_instructions: None,
            session_dynamic_instructions: None,
            functions,
            rag,
            model,
        })
    }
    /// Resolve values for the agent's declared variables: defaults are taken
    /// as-is, the rest are prompted for interactively (when on a terminal).
    /// Errors listing the missing variables when interaction isn't possible.
    pub fn init_agent_variables(
        agent_variables: &[AgentVariable],
        no_interaction: bool,
    ) -> Result<AgentVariables> {
        let mut output = IndexMap::new();
        if agent_variables.is_empty() {
            return Ok(output);
        }
        let mut printed = false;
        let mut unset_variables = vec![];
        for agent_variable in agent_variables {
            let key = agent_variable.name.clone();
            if let Some(value) = agent_variable.default.clone() {
                output.insert(key, value);
                continue;
            }
            if no_interaction {
                continue;
            }
            if *IS_STDOUT_TERMINAL {
                if !printed {
                    println!("⚙ Init agent variables...");
                    printed = true;
                }
                let value = Text::new(&format!(
                    "{} ({}):",
                    agent_variable.name, agent_variable.description
                ))
                .with_validator(|input: &str| {
                    if input.trim().is_empty() {
                        Ok(Validation::Invalid("This field is required".into()))
                    } else {
                        Ok(Validation::Valid)
                    }
                })
                .prompt()?;
                output.insert(key, value);
            } else {
                unset_variables.push(agent_variable)
            }
        }
        if !unset_variables.is_empty() {
            bail!(
                "The following agent variables are required:\n{}",
                unset_variables
                    .iter()
                    .map(|v| format!("  - {}: {}", v.name, v.description))
                    .collect::<Vec<_>>()
                    .join("\n")
            )
        }
        Ok(output)
    }
    /// Serialize the agent's state (name, variables, config, interpolated
    /// definition, data paths) to YAML for display/export.
    pub fn export(&self) -> Result<String> {
        let mut value = json!({});
        value["name"] = json!(self.name());
        let variables = self.variables();
        if !variables.is_empty() {
            value["variables"] = serde_json::to_value(variables)?;
        }
        value["config"] = json!(self.config);
        // "definition" is the config with instructions fully interpolated.
        let mut config = self.config.clone();
        config.instructions = self.interpolated_instructions();
        value["definition"] = json!(config);
        value["data_dir"] = Config::agent_data_dir(&self.name)
            .display()
            .to_string()
            .into();
        value["config_file"] = Config::agent_config_file(&self.name)
            .display()
            .to_string()
            .into();
        let data = serde_yaml::to_string(&value)?;
        Ok(data)
    }
    /// Markdown banner shown when the agent starts.
    pub fn banner(&self) -> String {
        self.config.banner()
    }
    pub fn name(&self) -> &str {
        &self.name
    }
    pub fn functions(&self) -> &Functions {
        &self.functions
    }
    pub fn rag(&self) -> Option<Arc<Rag>> {
        self.rag.clone()
    }
    pub fn conversation_starters(&self) -> &[String] {
        &self.config.conversation_starters
    }
    /// Instructions with `{{var}}` placeholders replaced by the current
    /// variable values; session dynamic instructions take precedence over
    /// shared ones, which take precedence over the static config text.
    pub fn interpolated_instructions(&self) -> String {
        let mut output = self
            .session_dynamic_instructions
            .clone()
            .or_else(|| self.shared_dynamic_instructions.clone())
            .unwrap_or_else(|| self.config.instructions.clone());
        for (k, v) in self.variables() {
            output = output.replace(&format!("{{{{{k}}}}}"), v)
        }
        interpolate_variables(&mut output);
        output
    }
    pub fn agent_prelude(&self) -> Option<&str> {
        self.config.agent_prelude.as_deref()
    }
    /// Active variable set: session overrides when present, else shared.
    pub fn variables(&self) -> &AgentVariables {
        match &self.session_variables {
            Some(variables) => variables,
            None => &self.shared_variables,
        }
    }
    /// Variables exported as `LLM_AGENT_VAR_*` environment entries for
    /// agent function invocations.
    pub fn variable_envs(&self) -> HashMap<String, String> {
        self.variables()
            .iter()
            .map(|(k, v)| {
                (
                    format!("LLM_AGENT_VAR_{}", normalize_env_name(k)),
                    v.clone(),
                )
            })
            .collect()
    }
    pub fn shared_variables(&self) -> &AgentVariables {
        &self.shared_variables
    }
    pub fn set_shared_variables(&mut self, shared_variables: AgentVariables) {
        self.shared_variables = shared_variables;
    }
    pub fn set_session_variables(&mut self, session_variables: AgentVariables) {
        self.session_variables = Some(session_variables);
    }
    /// Variables declared in the agent config (not their resolved values).
    pub fn defined_variables(&self) -> &[AgentVariable] {
        &self.config.variables
    }
    /// Drop all session-scoped state.
    pub fn exit_session(&mut self) {
        self.session_variables = None;
        self.session_dynamic_instructions = None;
    }
    pub fn is_dynamic_instructions(&self) -> bool {
        self.config.dynamic_instructions
    }
    /// Refresh the shared dynamic instructions via the `_instructions`
    /// function (only when enabled; recomputed when `force` or unset).
    pub fn update_shared_dynamic_instructions(&mut self, force: bool) -> Result<()> {
        if self.is_dynamic_instructions() && (force || self.shared_dynamic_instructions.is_none()) {
            self.shared_dynamic_instructions = Some(self.run_instructions_fn()?);
        }
        Ok(())
    }
    /// Set session dynamic instructions, computing them when not supplied.
    pub fn update_session_dynamic_instructions(&mut self, value: Option<String>) -> Result<()> {
        if self.is_dynamic_instructions() {
            let value = match value {
                Some(v) => v,
                None => self.run_instructions_fn()?,
            };
            self.session_dynamic_instructions = Some(value);
        }
        Ok(())
    }
    /// Invoke the agent's `_instructions` function and return its output.
    fn run_instructions_fn(&self) -> Result<String> {
        let value = run_llm_function(
            self.name().to_string(),
            vec!["_instructions".into(), "{}".into()],
            self.variable_envs(),
        )?;
        match value {
            Some(v) => Ok(v),
            _ => bail!("No return value from '_instructions' function"),
        }
    }
}
impl RoleLike for Agent {
    /// Materialize the agent as a `Role` carrying its interpolated
    /// instructions and synced model/sampling/tool settings.
    fn to_role(&self) -> Role {
        let prompt = self.interpolated_instructions();
        let mut role = Role::new("", &prompt);
        role.sync(self);
        role
    }
    fn model(&self) -> &Model {
        &self.model
    }
    fn temperature(&self) -> Option<f64> {
        self.config.temperature
    }
    fn top_p(&self) -> Option<f64> {
        self.config.top_p
    }
    /// Comma-joined tool list (`Some("")` when no tools are configured).
    fn use_tools(&self) -> Option<String> {
        // `join` borrows the slice; the previous `.clone()` was a wasted
        // allocation of the whole Vec.
        Some(self.config.global_tools.join(","))
    }
    /// Comma-joined MCP server list (`Some("")` when none are configured).
    fn use_mcp_servers(&self) -> Option<String> {
        Some(self.config.mcp_servers.join(","))
    }
    fn set_model(&mut self, model: Model) {
        self.config.model_id = Some(model.id());
        self.model = model;
    }
    fn set_temperature(&mut self, value: Option<f64>) {
        self.config.temperature = value;
    }
    fn set_top_p(&mut self, value: Option<f64>) {
        self.config.top_p = value;
    }
    /// Replace the tool list from a comma-separated string; `None` clears it.
    fn set_use_tools(&mut self, value: Option<String>) {
        match value {
            Some(tools) => {
                let tools = tools
                    .split(',')
                    .map(|v| v.trim().to_string())
                    .filter(|v| !v.is_empty())
                    .collect::<Vec<_>>();
                self.config.global_tools = tools;
            }
            None => {
                self.config.global_tools.clear();
            }
        }
    }
    /// Replace the MCP server list from a comma-separated string; `None` clears it.
    fn set_use_mcp_servers(&mut self, value: Option<String>) {
        match value {
            Some(servers) => {
                let servers = servers
                    .split(',')
                    .map(|v| v.trim().to_string())
                    .filter(|v| !v.is_empty())
                    .collect::<Vec<_>>();
                self.config.mcp_servers = servers;
            }
            None => {
                self.config.mcp_servers.clear();
            }
        }
    }
}
/// On-disk agent definition (YAML), plus env-var overrides (see `load_envs`).
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentConfig {
    pub name: String,
    /// Serialized as `model` in the YAML file.
    #[serde(rename(serialize = "model", deserialize = "model"))]
    pub model_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_prelude: Option<String>,
    #[serde(default)]
    pub description: String,
    #[serde(default)]
    pub version: String,
    /// MCP servers this agent activates.
    #[serde(default)]
    pub mcp_servers: Vec<String>,
    /// Tools available to the agent.
    #[serde(default)]
    pub global_tools: Vec<String>,
    /// System instructions; may contain `{{var}}` and `{{__tools__}}` placeholders.
    #[serde(default)]
    pub instructions: bool_placeholder_free_instructions_see_note_below(),
    /// When true, instructions come from the agent's `_instructions` function.
    #[serde(default)]
    pub dynamic_instructions: bool,
    /// Variables the agent declares (prompted for or defaulted at startup).
    #[serde(default)]
    pub variables: Vec<AgentVariable>,
    #[serde(default)]
    pub conversation_starters: Vec<String>,
    /// Documents to index into the agent's RAG store.
    #[serde(default)]
    pub documents: Vec<String>,
}
impl AgentConfig {
    /// Parse an agent config from a YAML file.
    pub fn load(path: &Path) -> Result<Self> {
        let contents = read_to_string(path)
            .with_context(|| format!("Failed to read agent config file at '{}'", path.display()))?;
        let agent_config: Self = serde_yaml::from_str(&contents)
            .with_context(|| format!("Failed to load agent config at '{}'", path.display()))?;
        Ok(agent_config)
    }
    /// Apply global defaults and `<AGENT>_<FIELD>` environment overrides
    /// (model, temperature, top_p, agent_prelude, variables-as-JSON).
    fn load_envs(&mut self, config: &Config) {
        let name = &self.name;
        let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}"));
        // Global prelude is only a fallback; env var below can still override.
        if self.agent_prelude.is_none() {
            self.agent_prelude = config.agent_prelude.clone();
        }
        if let Some(v) = read_env_value::<String>(&with_prefix("model")) {
            self.model_id = v;
        }
        if let Some(v) = read_env_value::<f64>(&with_prefix("temperature")) {
            self.temperature = v;
        }
        if let Some(v) = read_env_value::<f64>(&with_prefix("top_p")) {
            self.top_p = v;
        }
        if let Some(v) = read_env_value::<String>(&with_prefix("agent_prelude")) {
            self.agent_prelude = v;
        }
        // Variables env var holds a JSON array; silently ignored if invalid.
        if let Ok(v) = env::var(with_prefix("variables")) {
            if let Ok(v) = serde_json::from_str(&v) {
                self.variables = v;
            }
        }
    }
    /// Markdown banner: title line, description, and optional starters list.
    fn banner(&self) -> String {
        let AgentConfig {
            name,
            description,
            version,
            conversation_starters,
            ..
        } = self;
        let starters = if conversation_starters.is_empty() {
            String::new()
        } else {
            let starters = conversation_starters
                .iter()
                .map(|v| format!("- {v}"))
                .collect::<Vec<_>>()
                .join("\n");
            format!(
                r#"

## Conversation Starters
{starters}"#
            )
        };
        format!(
            r#"# {name} {version}
{description}{starters}"#
        )
    }
    /// Replace the `{{__tools__}}` placeholder in the instructions with a
    /// numbered list of available tool names and first-line descriptions.
    fn replace_tools_placeholder(&mut self, functions: &Functions) {
        let tools_placeholder: &str = "{{__tools__}}";
        if self.instructions.contains(tools_placeholder) {
            let tools = functions
                .declarations()
                .iter()
                .enumerate()
                .map(|(i, v)| {
                    // Only the first line of a multi-line description is shown.
                    let description = match v.description.split_once('\n') {
                        Some((v, _)) => v,
                        None => &v.description,
                    };
                    format!("{}. {}: {description}", i + 1, v.name)
                })
                .collect::<Vec<String>>()
                .join("\n");
            self.instructions = self.instructions.replace(tools_placeholder, &tools);
        }
    }
}
/// A variable declared by an agent config; resolved to a value at startup.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentVariable {
    pub name: String,
    pub description: String,
    /// Optional default; variables without one are prompted for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
    /// Resolved value; never read from the config file.
    #[serde(skip_deserializing, default)]
    pub value: String,
}
/// Read the known agent names from `<config_dir>/agents.txt`.
///
/// Blank lines and `#` comment lines are skipped; a missing or unreadable
/// file yields an empty list.
pub fn list_agents() -> Vec<String> {
    let Ok(contents) = read_to_string(Config::config_dir().join("agents.txt")) else {
        return vec![];
    };
    contents
        .split('\n')
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .map(str::to_string)
        .collect()
}
/// Shell-completion candidates for agent variables: one `NAME=` entry per
/// variable declared in the agent's config, with its description (and
/// default, when present) as the help text. Missing or unparsable configs
/// yield no candidates.
pub fn complete_agent_variables(agent_name: &str) -> Vec<(String, Option<String>)> {
    let config_path = Config::agent_config_file(agent_name);
    if !config_path.exists() {
        return vec![];
    }
    let Ok(config) = AgentConfig::load(&config_path) else {
        return vec![];
    };
    config
        .variables
        .iter()
        .map(|variable| {
            let description = variable.default.as_ref().map_or_else(
                || variable.description.clone(),
                |default| format!("{} [default: {default}]", variable.description),
            );
            (format!("{}=", variable.name), Some(description))
        })
        .collect()
}
+545
View File
@@ -0,0 +1,545 @@
use super::*;
use crate::client::{
init_client, patch_messages, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
MessageContentPart, MessageContentToolCalls, MessageRole, Model,
};
use crate::function::ToolResult;
use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal};
use anyhow::{bail, Context, Result};
use indexmap::IndexSet;
use std::{collections::HashMap, fs::File, io::Read};
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};
// File extensions treated as images — presumably attached as media rather
// than inlined as text documents (usage not visible in this chunk).
const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
// Max display width for input summaries — NOTE(review): confirm against the
// summary-rendering code, which is outside this chunk.
const SUMMARY_MAX_WIDTH: usize = 80;
/// A single user input to the model: the assembled text plus media, tool
/// results, and the role/session/agent context captured at construction.
#[derive(Debug, Clone)]
pub struct Input {
    config: GlobalConfig,
    /// Fully assembled text (raw text plus any loaded documents).
    text: String,
    /// Original raw text and the raw path arguments it was built from.
    raw: (String, Vec<String>),
    /// RAG-augmented replacement for `text`, set by `use_embeddings`.
    patched_text: Option<String>,
    /// Last assistant reply, when this input references it.
    last_reply: Option<String>,
    /// Accumulated partial output when continuing a truncated response.
    continue_output: Option<String>,
    /// True when this input regenerates the previous turn.
    regenerate: bool,
    /// Media attachments (produced by `load_documents`; format not visible here).
    medias: Vec<String>,
    /// Mapping produced by `load_documents` — presumably data-URL -> source
    /// path; confirm against that function.
    data_urls: HashMap<String, String>,
    tool_calls: Option<MessageContentToolCalls>,
    /// Name of the RAG used to patch this input, if any.
    rag_name: Option<String>,
    role: Role,
    with_session: bool,
    with_agent: bool,
}
impl Input {
    /// Build an input from plain text, resolving the active role/session/agent
    /// context from config when no explicit role is given.
    pub fn from_str(config: &GlobalConfig, text: &str, role: Option<Role>) -> Self {
        let (role, with_session, with_agent) = resolve_role(&config.read(), role);
        Self {
            config: config.clone(),
            text: text.to_string(),
            raw: (text.to_string(), vec![]),
            patched_text: None,
            last_reply: None,
            continue_output: None,
            regenerate: false,
            medias: Default::default(),
            data_urls: Default::default(),
            tool_calls: None,
            role,
            rag_name: None,
            with_session,
            with_agent,
        }
    }
    /// Build an input from raw text plus file/URL/command/loader attachments.
    ///
    /// Loads each path into either a text document or a media attachment,
    /// optionally appends the last assistant reply, and concatenates
    /// everything into a single prompt text.
    pub async fn from_files(
        config: &GlobalConfig,
        raw_text: &str,
        paths: Vec<String>,
        role: Option<Role>,
    ) -> Result<Self> {
        let loaders = config.read().document_loaders.clone();
        let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
            resolve_paths(&loaders, paths)?;
        let mut last_reply = None;
        let (documents, medias, data_urls) = load_documents(
            &loaders,
            local_paths,
            remote_urls,
            external_cmds,
            protocol_paths,
        )
        .await
        .context("Failed to load files")?;
        let mut texts = vec![];
        if !raw_text.is_empty() {
            texts.push(raw_text.to_string());
        };
        if with_last_reply {
            // Prefer the stored output; fall back to the previous input's
            // recorded reply when the output is empty.
            if let Some(LastMessage { input, output, .. }) = config.read().last_message.as_ref() {
                if !output.is_empty() {
                    last_reply = Some(output.clone())
                } else if let Some(v) = input.last_reply.as_ref() {
                    last_reply = Some(v.clone());
                }
                if let Some(v) = last_reply.clone() {
                    texts.push(format!("\n{v}"));
                }
            }
            // A last-reply reference with nothing to attach is an error.
            if last_reply.is_none() && documents.is_empty() && medias.is_empty() {
                bail!("No last reply found");
            }
        }
        let documents_len = documents.len();
        for (kind, path, contents) in documents {
            // A single document with no raw text is inlined without a header.
            if documents_len == 1 && raw_text.is_empty() {
                texts.push(format!("\n{contents}"));
            } else {
                texts.push(format!(
                    "\n============ {kind}: {path} ============\n{contents}"
                ));
            }
        }
        let (role, with_session, with_agent) = resolve_role(&config.read(), role);
        Ok(Self {
            config: config.clone(),
            text: texts.join("\n"),
            raw: (raw_text.to_string(), raw_paths),
            patched_text: None,
            last_reply,
            continue_output: None,
            regenerate: false,
            medias,
            data_urls,
            tool_calls: Default::default(),
            role,
            rag_name: None,
            with_session,
            with_agent,
        })
    }
    /// Same as [`Input::from_files`], but shows an abortable spinner while
    /// the files are being loaded.
    pub async fn from_files_with_spinner(
        config: &GlobalConfig,
        raw_text: &str,
        paths: Vec<String>,
        role: Option<Role>,
        abort_signal: AbortSignal,
    ) -> Result<Self> {
        abortable_run_with_spinner(
            Input::from_files(config, raw_text, paths, role),
            "Loading files",
            abort_signal,
        )
        .await
    }
pub fn is_empty(&self) -> bool {
self.text.is_empty() && self.medias.is_empty()
}
pub fn data_urls(&self) -> HashMap<String, String> {
self.data_urls.clone()
}
pub fn tool_calls(&self) -> &Option<MessageContentToolCalls> {
&self.tool_calls
}
pub fn text(&self) -> String {
match self.patched_text.clone() {
Some(text) => text,
None => self.text.clone(),
}
}
pub fn clear_patch(&mut self) {
self.patched_text = None;
}
pub fn set_text(&mut self, text: String) {
self.text = text;
}
pub fn stream(&self) -> bool {
self.config.read().stream && !self.role().model().no_stream()
}
pub fn continue_output(&self) -> Option<&str> {
self.continue_output.as_deref()
}
/// Appends `output` to the accumulated continuation buffer, starting a new
/// buffer when none exists yet.
pub fn set_continue_output(&mut self, output: &str) {
    let merged = match self.continue_output.take() {
        Some(prev) => format!("{prev}{output}"),
        None => output.to_string(),
    };
    self.continue_output = Some(merged);
}
/// Whether this input is a regeneration of the previous answer.
pub fn regenerate(&self) -> bool {
    self.regenerate
}
/// Marks this input as a regeneration: refreshes the role (same-name only)
/// from current config and discards any attached tool calls.
pub fn set_regenerate(&mut self) {
    let role = self.config.read().extract_role();
    // Only adopt the refreshed role when it is still the same role by name.
    if role.name() == self.role().name() {
        self.role = role;
    }
    self.regenerate = true;
    self.tool_calls = None;
}
/// Runs RAG retrieval over the input text (when a RAG is active) and stores
/// the augmented prompt as the patched text.
pub async fn use_embeddings(&mut self, abort_signal: AbortSignal) -> Result<()> {
    if self.text.is_empty() {
        return Ok(());
    }
    let rag = self.config.read().rag.clone();
    if let Some(rag) = rag {
        let result = Config::search_rag(&self.config, &rag, &self.text, abort_signal).await?;
        self.patched_text = Some(result);
        self.rag_name = Some(rag.name().to_string());
    }
    Ok(())
}
/// Name of the RAG used to patch this input, if any.
pub fn rag_name(&self) -> Option<&str> {
    self.rag_name.as_deref()
}
/// Folds a batch of tool results (and the assistant output that produced
/// them) into this input, creating the tool-call container on first use.
pub fn merge_tool_results(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
    if let Some(existing) = self.tool_calls.as_mut() {
        existing.merge(tool_results, output);
    } else {
        self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output));
    }
    self
}
/// Instantiates an LLM client for this input's role model.
pub fn create_client(&self) -> Result<Box<dyn Client>> {
    init_client(&self.config, Some(self.role().model().clone()))
}
/// Performs a one-shot (non-streaming) chat completion and returns the
/// reply text with any think-tags stripped.
pub async fn fetch_chat_text(&self) -> Result<String> {
    let client = self.create_client()?;
    let text = client.chat_completions(self.clone()).await?.text;
    let text = strip_think_tag(&text).to_string();
    Ok(text)
}
/// Assembles the request payload: builds and patches the message list,
/// enforces the model's input-token limit, and attaches sampling options
/// plus the selected function definitions.
pub fn prepare_completion_data(
    &self,
    model: &Model,
    stream: bool,
) -> Result<ChatCompletionsData> {
    let mut messages = self.build_messages()?;
    patch_messages(&mut messages, model);
    model.guard_max_input_tokens(&messages)?;
    let (temperature, top_p) = (self.role().temperature(), self.role().top_p());
    let functions = self.config.read().select_functions(self.role());
    if let Some(vec) = &functions {
        for def in vec {
            debug!("Function definition: {:?}", def.name);
        }
    }
    Ok(ChatCompletionsData {
        messages,
        temperature,
        top_p,
        functions,
        stream,
    })
}
/// Builds the outgoing messages from the active session (when one is in
/// use) or the role, appending a pending tool-call message if present.
pub fn build_messages(&self) -> Result<Vec<Message>> {
    let mut messages = if let Some(session) = self.session(&self.config.read().session) {
        session.build_messages(self)
    } else {
        self.role().build_messages(self)
    };
    if let Some(tool_calls) = &self.tool_calls {
        messages.push(Message::new(
            MessageRole::Assistant,
            MessageContent::ToolCalls(tool_calls.clone()),
        ))
    }
    Ok(messages)
}
/// Echoes the messages that would be sent, delegating to the active
/// session when one is in use, otherwise to the role.
pub fn echo_messages(&self) -> String {
    match self.session(&self.config.read().session) {
        Some(session) => session.echo_messages(self),
        None => self.role().echo_messages(self),
    }
}
/// The role captured when this input was created.
pub fn role(&self) -> &Role {
    &self.role
}
/// Projects the global session for this input: `Some` only when the input
/// was created with a session attached.
pub fn session<'a>(&self, session: &'a Option<Session>) -> Option<&'a Session> {
    match self.with_session {
        true => session.as_ref(),
        false => None,
    }
}
/// Mutable variant of [`Self::session`].
pub fn session_mut<'a>(&self, session: &'a mut Option<Session>) -> Option<&'a mut Session> {
    match self.with_session {
        true => session.as_mut(),
        false => None,
    }
}
/// Whether this input was created while an agent was active.
pub fn with_agent(&self) -> bool {
    self.with_agent
}
/// One-line summary of the input text for history/titles: control characters
/// become spaces and the result is truncated (with `...`) to fit
/// `SUMMARY_MAX_WIDTH` display columns (CJK-aware width).
pub fn summary(&self) -> String {
    let text: String = self
        .text
        .trim()
        .chars()
        .map(|c| if c.is_control() { ' ' } else { c })
        .collect();
    if text.width_cjk() > SUMMARY_MAX_WIDTH {
        let mut sum_width = 0;
        let mut chars = vec![];
        for c in text.chars() {
            sum_width += c.width_cjk().unwrap_or(1);
            // Reserve 3 columns for the "..." ellipsis.
            if sum_width > SUMMARY_MAX_WIDTH - 3 {
                chars.extend(['.', '.', '.']);
                break;
            }
            chars.push(c);
        }
        chars.into_iter().collect()
    } else {
        text
    }
}
/// Reconstructs the original REPL command line: `.file <paths...> [-- <text>]`
/// when files were attached, or just the raw text otherwise.
pub fn raw(&self) -> String {
    let (text, files) = &self.raw;
    let mut parts: Vec<String> = Vec::new();
    if !files.is_empty() {
        parts.push(".file".into());
        parts.extend(files.iter().cloned());
    }
    if !text.is_empty() {
        if !parts.is_empty() {
            parts.push("--".into());
        }
        parts.push(text.clone());
    }
    parts.join(" ")
}
/// Renders the input for display: plain text when there is no media,
/// otherwise `.file <resolved-paths> [-- <text>]` with data-urls mapped
/// back to their original paths.
pub fn render(&self) -> String {
    let text = self.text();
    if self.medias.is_empty() {
        return text;
    }
    let resolved: Vec<String> = self
        .medias
        .iter()
        .map(|url| resolve_data_url(&self.data_urls, url.clone()))
        .collect();
    let mut out = format!(".file {}", resolved.join(" "));
    if !text.is_empty() {
        out.push_str(" -- ");
        out.push_str(&text);
    }
    out
}
/// Converts the input into API message content: plain text when there is no
/// media, otherwise an array of image parts optionally preceded by a text part.
pub fn message_content(&self) -> MessageContent {
    if self.medias.is_empty() {
        MessageContent::Text(self.text())
    } else {
        let mut list: Vec<MessageContentPart> = self
            .medias
            .iter()
            .cloned()
            .map(|url| MessageContentPart::ImageUrl {
                image_url: ImageUrl { url },
            })
            .collect();
        if !self.text.is_empty() {
            // Text goes first so providers see the prompt before the images.
            list.insert(0, MessageContentPart::Text { text: self.text() });
        }
        MessageContent::Array(list)
    }
}
}
/// Chooses the effective role for a new input.
///
/// An explicitly supplied role detaches from session/agent; otherwise the
/// role is derived from config and the session/agent flags reflect config state.
fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
    if let Some(role) = role {
        return (role, false, false);
    }
    (
        config.extract_role(),
        config.session.is_some(),
        config.agent.is_some(),
    )
}
/// Output of `resolve_paths`, in order: raw paths as typed, local file
/// paths, remote URLs, external commands, loader-protocol paths, and
/// whether `%%` requested the last reply.
type ResolvePathsOutput = (
    Vec<String>,
    Vec<String>,
    Vec<String>,
    Vec<String>,
    Vec<String>,
    bool,
);
/// Classifies user-supplied path arguments into raw/local/remote/command/
/// protocol buckets, deduplicating while preserving order.
///
/// Special forms: `%%` requests the last reply, and a backtick-wrapped
/// argument runs an external command.
fn resolve_paths(
    loaders: &HashMap<String, String>,
    paths: Vec<String>,
) -> Result<ResolvePathsOutput> {
    let mut raw_paths = IndexSet::new();
    let mut local_paths = IndexSet::new();
    let mut remote_urls = IndexSet::new();
    let mut external_cmds = IndexSet::new();
    let mut protocol_paths = IndexSet::new();
    let mut with_last_reply = false;
    for path in paths {
        if path == "%%" {
            // Sentinel meaning "include the last reply".
            with_last_reply = true;
            raw_paths.insert(path);
        } else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') {
            // `cmd` — run an external command and capture its output.
            external_cmds.insert(path[1..path.len() - 1].to_string());
            raw_paths.insert(path);
        } else if is_url(&path) {
            // `ends_with` states the intent directly (was `strip_suffix(..).is_some()`).
            if path.ends_with("**") {
                bail!("Invalid website '{path}'");
            }
            remote_urls.insert(path.clone());
            raw_paths.insert(path);
        } else if is_loader_protocol(loaders, &path) {
            protocol_paths.insert(path.clone());
            raw_paths.insert(path);
        } else {
            let resolved_path = resolve_home_dir(&path);
            let absolute_path = to_absolute_path(&resolved_path)
                .with_context(|| format!("Invalid path '{path}'"))?;
            local_paths.insert(resolved_path);
            raw_paths.insert(absolute_path);
        }
    }
    Ok((
        raw_paths.into_iter().collect(),
        local_paths.into_iter().collect(),
        remote_urls.into_iter().collect(),
        external_cmds.into_iter().collect(),
        protocol_paths.into_iter().collect(),
        with_last_reply,
    ))
}
/// Loads every resolved source into memory: runs external commands, reads
/// local files (globs expanded), fetches remote URLs, and invokes loader
/// protocols.
///
/// Returns `(documents, medias, data_urls)` where each document is a
/// `(kind, origin, contents)` triple and `data_urls` maps the sha256 of a
/// media data-url back to its original path/URL for display.
async fn load_documents(
    loaders: &HashMap<String, String>,
    local_paths: Vec<String>,
    remote_urls: Vec<String>,
    external_cmds: Vec<String>,
    protocol_paths: Vec<String>,
) -> Result<(
    Vec<(&'static str, String, String)>,
    Vec<String>,
    HashMap<String, String>,
)> {
    let mut files = vec![];
    let mut medias = vec![];
    let mut data_urls = HashMap::new();
    for cmd in external_cmds {
        // Best-effort: a failing command contributes its error text
        // instead of aborting the whole load.
        let output = duct::cmd(&SHELL.cmd, &[&SHELL.arg, &cmd])
            .stderr_to_stdout()
            .unchecked()
            .read()
            .unwrap_or_else(|err| err.to_string());
        files.push(("CMD", cmd, output));
    }
    let local_files = expand_glob_paths(&local_paths, true).await?;
    for file_path in local_files {
        if is_image(&file_path) {
            let contents = read_media_to_data_url(&file_path)
                .with_context(|| format!("Unable to read media '{file_path}'"))?;
            data_urls.insert(sha256(&contents), file_path);
            medias.push(contents)
        } else {
            let document = load_file(loaders, &file_path)
                .await
                .with_context(|| format!("Unable to read file '{file_path}'"))?;
            files.push(("FILE", file_path, document.contents));
        }
    }
    for file_url in remote_urls {
        let (contents, extension) = fetch_with_loaders(loaders, &file_url, true)
            .await
            .with_context(|| format!("Failed to load url '{file_url}'"))?;
        if extension == MEDIA_URL_EXTENSION {
            // The loader signals a media URL via a sentinel extension.
            data_urls.insert(sha256(&contents), file_url);
            medias.push(contents)
        } else {
            files.push(("URL", file_url, contents));
        }
    }
    for protocol_path in protocol_paths {
        let documents = load_protocol_path(loaders, &protocol_path)
            .with_context(|| format!("Failed to load from '{protocol_path}'"))?;
        files.extend(
            documents
                .into_iter()
                .map(|document| ("FROM", document.path, document.contents)),
        );
    }
    Ok((files, medias, data_urls))
}
/// Maps a `data:` URL back to its original file path/URL (looked up by its
/// sha256); anything else — or an unknown data URL — passes through unchanged.
pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -> String {
    if data_url.starts_with("data:") {
        if let Some(path) = data_urls.get(&sha256(&data_url)) {
            return path.clone();
        }
    }
    data_url
}
/// Whether a path has a recognized image extension.
fn is_image(path: &str) -> bool {
    match get_patch_extension(path) {
        Some(ext) => IMAGE_EXTS.contains(&ext.as_str()),
        None => false,
    }
}
/// Reads a local image file and encodes it as a base64 `data:` URL.
///
/// The MIME type is derived from the file extension; unsupported extensions
/// produce an error rather than a bogus data URL.
fn read_media_to_data_url(image_path: &str) -> Result<String> {
    let extension = get_patch_extension(image_path).unwrap_or_default();
    let mime_type = match extension.as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "webp" => "image/webp",
        "gif" => "image/gif",
        _ => bail!("Unexpected media type"),
    };
    // `fs::read` sizes the buffer from file metadata — simpler and faster
    // than the manual File::open + read_to_end sequence.
    let buffer = std::fs::read(image_path)?;
    let encoded_image = base64_encode(buffer);
    Ok(format!("data:{mime_type};base64,{encoded_image}"))
}
+3034
View File
File diff suppressed because it is too large Load Diff
+416
View File
@@ -0,0 +1,416 @@
use super::*;
use crate::client::{Message, MessageContent, MessageRole, Model};
use anyhow::Result;
use fancy_regex::Regex;
use rust_embed::Embed;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::LazyLock;
// Built-in role names used by REPL shortcuts and CLI flags.
pub const SHELL_ROLE: &str = "shell";
pub const EXPLAIN_SHELL_ROLE: &str = "explain-shell";
pub const CODE_ROLE: &str = "code";
pub const CREATE_TITLE_ROLE: &str = "create-title";
/// Placeholder inside a role prompt that is replaced with the user input.
pub const INPUT_PLACEHOLDER: &str = "__INPUT__";
/// Role markdown files embedded into the binary at compile time.
#[derive(Embed)]
#[folder = "assets/roles/"]
struct RolesAsset;
/// Matches an optional `---`-fenced metadata header (group 1) followed by
/// the prompt body (group 2).
static RE_METADATA: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?s)-{3,}\s*(.*?)\s*-{3,}\s*(.*)").unwrap());
/// Shared configuration surface for anything that can act as a role
/// (roles, sessions, agents): model selection plus sampling/tool options.
pub trait RoleLike {
    fn to_role(&self) -> Role;
    fn model(&self) -> &Model;
    fn temperature(&self) -> Option<f64>;
    fn top_p(&self) -> Option<f64>;
    fn use_tools(&self) -> Option<String>;
    fn use_mcp_servers(&self) -> Option<String>;
    fn set_model(&mut self, model: Model);
    fn set_temperature(&mut self, value: Option<f64>);
    fn set_top_p(&mut self, value: Option<f64>);
    fn set_use_tools(&mut self, value: Option<String>);
    fn set_use_mcp_servers(&mut self, value: Option<String>);
}
/// A named system prompt plus optional model/sampling/tool overrides.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Role {
    // Empty name marks a derived (anonymous) role.
    name: String,
    #[serde(default)]
    prompt: String,
    // Serialized as `model`; resolved into `model` at load time.
    #[serde(
        rename(serialize = "model", deserialize = "model"),
        skip_serializing_if = "Option::is_none"
    )]
    model_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    use_tools: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    use_mcp_servers: Option<String>,
    // Runtime-only resolved model; never serialized.
    #[serde(skip)]
    model: Model,
}
impl Role {
/// Builds a role from its name and raw file content.
///
/// The content may start with a `---`-fenced YAML metadata header
/// (`model`/`temperature`/`top_p`/`use_tools`/`use_mcp_servers`) followed
/// by the prompt body; `interpolate_variables` expands placeholders in the
/// prompt before it is stored.
pub fn new(name: &str, content: &str) -> Self {
    let mut metadata = "";
    let mut prompt = content.trim();
    if let Ok(Some(caps)) = RE_METADATA.captures(content) {
        if let (Some(metadata_value), Some(prompt_value)) = (caps.get(1), caps.get(2)) {
            metadata = metadata_value.as_str().trim();
            prompt = prompt_value.as_str().trim();
        }
    }
    let mut prompt = prompt.to_string();
    interpolate_variables(&mut prompt);
    let mut role = Self {
        name: name.to_string(),
        prompt,
        ..Default::default()
    };
    if !metadata.is_empty() {
        // Metadata parsing is best-effort: invalid YAML or non-mapping
        // values simply leave the overrides unset.
        if let Ok(value) = serde_yaml::from_str::<Value>(metadata) {
            if let Some(value) = value.as_object() {
                for (key, value) in value {
                    match key.as_str() {
                        "model" => role.model_id = value.as_str().map(|v| v.to_string()),
                        "temperature" => role.temperature = value.as_f64(),
                        "top_p" => role.top_p = value.as_f64(),
                        "use_tools" => role.use_tools = value.as_str().map(|v| v.to_string()),
                        "use_mcp_servers" => {
                            role.use_mcp_servers = value.as_str().map(|v| v.to_string())
                        }
                        // Unknown metadata keys are ignored.
                        _ => (),
                    }
                }
            }
        }
    }
    role
}
/// Loads a built-in role from the embedded `assets/roles/<name>.md`.
///
/// # Errors
/// Fails when the role does not exist or the embedded asset is not valid UTF-8.
pub fn builtin(name: &str) -> Result<Self> {
    let content = RolesAsset::get(&format!("{name}.md"))
        .ok_or_else(|| anyhow!("Unknown role `{name}`"))?;
    // Checked conversion: a corrupted embedded asset should surface as an
    // error instead of invoking undefined behavior via `from_utf8_unchecked`.
    let content = std::str::from_utf8(&content.data)
        .map_err(|err| anyhow!("Invalid role `{name}`: {err}"))?;
    Ok(Role::new(name, content))
}
/// Names of all built-in roles (embedded `.md` files, extension stripped).
pub fn list_builtin_role_names() -> Vec<String> {
    RolesAsset::iter()
        .filter_map(|v| v.strip_suffix(".md").map(|v| v.to_string()))
        .collect()
}
/// All built-in roles, parsed from the embedded assets.
pub fn list_builtin_roles() -> Vec<Self> {
    RolesAsset::iter()
        .filter_map(|v| Role::builtin(&v).ok())
        .collect()
}
/// Whether the role name carries `#`-separated arguments.
pub fn has_args(&self) -> bool {
    self.name.contains('#')
}
/// Serializes the role back to file format: an optional `---`-fenced
/// metadata header followed by the prompt body.
pub fn export(&self) -> String {
    let entries = [
        ("model", self.model_id().map(str::to_string)),
        ("temperature", self.temperature().map(|v| v.to_string())),
        ("top_p", self.top_p().map(|v| v.to_string())),
        ("use_tools", self.use_tools()),
        ("use_mcp_servers", self.use_mcp_servers()),
    ];
    let metadata: Vec<String> = entries
        .into_iter()
        .filter_map(|(key, value)| value.map(|v| format!("{key}: {v}")))
        .collect();
    match (metadata.is_empty(), self.prompt.is_empty()) {
        (true, _) => format!("{}\n", self.prompt),
        (false, true) => format!("---\n{}\n---\n", metadata.join("\n")),
        (false, false) => format!("---\n{}\n---\n\n{}\n", metadata.join("\n"), self.prompt),
    }
}
/// Writes the role to `role_path` in exported form, then adopts `role_name`.
///
/// `is_repl` controls whether a confirmation line is printed.
pub fn save(&mut self, role_name: &str, role_path: &Path, is_repl: bool) -> Result<()> {
    ensure_parent_exists(role_path)?;
    let content = self.export();
    std::fs::write(role_path, content).with_context(|| {
        format!(
            "Failed to write role {} to {}",
            self.name,
            role_path.display()
        )
    })?;
    if is_repl {
        println!("✓ Saved role to '{}'.", role_path.display());
    }
    if role_name != self.name {
        self.name = role_name.to_string();
    }
    Ok(())
}
/// Copies model and sampling/tool settings from another `RoleLike` into
/// this role (via `batch_set`, so `None` values are left untouched).
pub fn sync<T: RoleLike>(&mut self, role_like: &T) {
    self.batch_set(
        role_like.model(),
        role_like.temperature(),
        role_like.top_p(),
        role_like.use_tools(),
        role_like.use_mcp_servers(),
    );
}
/// Applies the given settings, treating `None` as "leave unchanged"
/// (the model, however, is always applied).
pub fn batch_set(
    &mut self,
    model: &Model,
    temperature: Option<f64>,
    top_p: Option<f64>,
    use_tools: Option<String>,
    use_mcp_servers: Option<String>,
) {
    self.set_model(model.clone());
    if temperature.is_some() {
        self.set_temperature(temperature);
    }
    if top_p.is_some() {
        self.set_top_p(top_p);
    }
    if use_tools.is_some() {
        self.set_use_tools(use_tools);
    }
    if use_mcp_servers.is_some() {
        self.set_use_mcp_servers(use_mcp_servers);
    }
}
/// A derived role is an anonymous one synthesized from current settings.
pub fn is_derived(&self) -> bool {
    self.name.is_empty()
}
pub fn name(&self) -> &str {
    &self.name
}
pub fn model_id(&self) -> Option<&str> {
    self.model_id.as_deref()
}
pub fn prompt(&self) -> &str {
    &self.prompt
}
pub fn is_empty_prompt(&self) -> bool {
    self.prompt.is_empty()
}
/// An embedded prompt contains `__INPUT__` and wraps the user input in place.
pub fn is_embedded_prompt(&self) -> bool {
    self.prompt.contains(INPUT_PLACEHOLDER)
}
/// Renders the prompt/input combination for echoing: just the input when
/// the prompt is empty, the input spliced into an embedded prompt, or the
/// prompt followed by the input.
pub fn echo_messages(&self, input: &Input) -> String {
    let rendered = input.render();
    match (self.is_empty_prompt(), self.is_embedded_prompt()) {
        (true, _) => rendered,
        (false, true) => self.prompt.replace(INPUT_PLACEHOLDER, &rendered),
        (false, false) => format!("{}\n\n{}", self.prompt, rendered),
    }
}
/// Builds the chat messages for this role around the given input.
///
/// Empty prompt → single user message; embedded prompt → input spliced into
/// the prompt; otherwise the prompt is parsed into an optional system
/// message plus few-shot input/output pairs. A pending `continue_output` is
/// appended as a partial assistant message.
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
    let mut content = input.message_content();
    let mut messages = if self.is_empty_prompt() {
        vec![Message::new(MessageRole::User, content)]
    } else if self.is_embedded_prompt() {
        content.merge_prompt(|v: &str| self.prompt.replace(INPUT_PLACEHOLDER, v));
        vec![Message::new(MessageRole::User, content)]
    } else {
        let mut messages = vec![];
        let (system, cases) = parse_structure_prompt(&self.prompt);
        if !system.is_empty() {
            messages.push(Message::new(
                MessageRole::System,
                MessageContent::Text(system.to_string()),
            ));
        }
        if !cases.is_empty() {
            // Each case becomes a user/assistant example pair.
            messages.extend(cases.into_iter().flat_map(|(i, o)| {
                vec![
                    Message::new(MessageRole::User, MessageContent::Text(i.to_string())),
                    Message::new(MessageRole::Assistant, MessageContent::Text(o.to_string())),
                ]
            }));
        }
        messages.push(Message::new(MessageRole::User, content));
        messages
    };
    if let Some(text) = input.continue_output() {
        messages.push(Message::new(
            MessageRole::Assistant,
            MessageContent::Text(text.into()),
        ));
    }
    messages
}
}
impl RoleLike for Role {
    fn to_role(&self) -> Role {
        self.clone()
    }
    fn model(&self) -> &Model {
        &self.model
    }
    fn temperature(&self) -> Option<f64> {
        self.temperature
    }
    fn top_p(&self) -> Option<f64> {
        self.top_p
    }
    fn use_tools(&self) -> Option<String> {
        self.use_tools.clone()
    }
    fn use_mcp_servers(&self) -> Option<String> {
        self.use_mcp_servers.clone()
    }
    fn set_model(&mut self, model: Model) {
        // NOTE(review): the guard checks the *current* model's id, so a role
        // whose model was never resolved keeps `model_id` untouched —
        // presumably intentional; confirm against role/config loading.
        if !self.model().id().is_empty() {
            self.model_id = Some(model.id().to_string());
        }
        self.model = model;
    }
    fn set_temperature(&mut self, value: Option<f64>) {
        self.temperature = value;
    }
    fn set_top_p(&mut self, value: Option<f64>) {
        self.top_p = value;
    }
    fn set_use_tools(&mut self, value: Option<String>) {
        self.use_tools = value;
    }
    fn set_use_mcp_servers(&mut self, value: Option<String>) {
        self.use_mcp_servers = value;
    }
}
/// Splits a role prompt into a system section plus alternating
/// `### INPUT:` / `### OUTPUT:` few-shot examples.
///
/// When the markers don't form complete input/output pairs (or there are
/// none), the whole prompt is returned as-is with no examples.
fn parse_structure_prompt(prompt: &str) -> (&str, Vec<(&str, &str)>) {
    const INPUT_MARK: &str = "### INPUT:";
    const OUTPUT_MARK: &str = "### OUTPUT:";
    let mut segments: Vec<&str> = vec![];
    let mut rest = prompt;
    let mut expect_input = true;
    loop {
        let marker = if expect_input { INPUT_MARK } else { OUTPUT_MARK };
        let Some(idx) = rest.find(marker) else {
            break;
        };
        segments.push(&rest[..idx]);
        rest = &rest[idx + marker.len()..];
        expect_input = !expect_input;
    }
    if !rest.is_empty() {
        segments.push(rest);
    }
    if let Some((system, examples)) = segments.split_first() {
        if !examples.is_empty() && examples.len() % 2 == 0 {
            let cases = examples
                .chunks_exact(2)
                .map(|pair| (pair[0].trim(), pair[1].trim()))
                .collect();
            return (system.trim(), cases);
        }
    }
    (prompt, vec![])
}
// Unit tests for `parse_structure_prompt`: system + one example, examples
// without a system message, and an unpaired INPUT section (returned as-is).
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_parse_structure_prompt1() {
        let prompt = r#"
System message
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
        assert_eq!(
            parse_structure_prompt(prompt),
            ("System message", vec![("Input 1", "Output 1")])
        );
    }
    #[test]
    fn test_parse_structure_prompt2() {
        let prompt = r#"
### INPUT:
Input 1
### OUTPUT:
Output 1
"#;
        assert_eq!(
            parse_structure_prompt(prompt),
            ("", vec![("Input 1", "Output 1")])
        );
    }
    #[test]
    fn test_parse_structure_prompt3() {
        let prompt = r#"
System message
### INPUT:
Input 1
"#;
        assert_eq!(parse_structure_prompt(prompt), (prompt, vec![]));
    }
}
+659
View File
@@ -0,0 +1,659 @@
use super::input::*;
use super::*;
use crate::client::{Message, MessageContent, MessageRole};
use crate::render::MarkdownRender;
use anyhow::{bail, Context, Result};
use fancy_regex::Regex;
use inquire::{validator::Validation, Confirm, Text};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::fs::{read_to_string, write};
use std::path::Path;
use std::sync::LazyLock;
/// Matches the timestamp prefix (`YYYYMMDDTHHMMSS-`) of auto-named session files.
static RE_AUTONAME_PREFIX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\d{8}T\d{6}-").unwrap());
/// A persistent chat conversation: model/sampling settings plus the
/// message history, serialized to a YAML file.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Session {
    // Serialized as `model`; resolved into `model` at load time.
    #[serde(rename(serialize = "model", deserialize = "model"))]
    model_id: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    use_tools: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    use_mcp_servers: Option<String>,
    // None = ask on exit; Some(true)/Some(false) = always/never save.
    #[serde(skip_serializing_if = "Option::is_none")]
    save_session: Option<bool>,
    // Per-session override of the global compression threshold.
    #[serde(skip_serializing_if = "Option::is_none")]
    compress_threshold: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role_name: Option<String>,
    #[serde(default, skip_serializing_if = "IndexMap::is_empty")]
    agent_variables: AgentVariables,
    #[serde(default, skip_serializing_if = "String::is_empty")]
    agent_instructions: String,
    // History moved aside by `compress()`; kept for context replay.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    compressed_messages: Vec<Message>,
    messages: Vec<Message>,
    // sha256(data-url) → original path, for readable rendering.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    data_urls: HashMap<String, String>,
    // Everything below is runtime-only state, rebuilt on load.
    #[serde(skip)]
    model: Model,
    #[serde(skip)]
    role_prompt: String,
    #[serde(skip)]
    name: String,
    #[serde(skip)]
    path: Option<String>,
    // True when there are unsaved changes.
    #[serde(skip)]
    dirty: bool,
    #[serde(skip)]
    save_session_this_time: bool,
    #[serde(skip)]
    compressing: bool,
    #[serde(skip)]
    autoname: Option<AutoName>,
    // Cached token count of the live messages.
    #[serde(skip)]
    tokens: usize,
}
impl Session {
/// Creates an in-memory session seeded from the current role/config.
pub fn new(config: &Config, name: &str) -> Self {
    let role = config.extract_role();
    let mut session = Self {
        name: name.to_string(),
        save_session: config.save_session,
        ..Default::default()
    };
    session.set_role(role);
    // A freshly created session has nothing worth persisting yet.
    session.dirty = false;
    session
}
/// Loads a session from a YAML file and restores its runtime-only state
/// (model handle, role prompt, token count).
///
/// Names under `_/` are temp sessions; a `YYYYMMDDTHHMMSS-` filename
/// prefix marks an auto-named one.
pub fn load(config: &Config, name: &str, path: &Path) -> Result<Self> {
    let content = read_to_string(path)
        .with_context(|| format!("Failed to load session {} at {}", name, path.display()))?;
    let mut session: Self =
        serde_yaml::from_str(&content).with_context(|| format!("Invalid session {name}"))?;
    session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?;
    if let Some(autoname) = name.strip_prefix("_/") {
        session.name = TEMP_SESSION_NAME.to_string();
        session.path = None;
        if let Ok(true) = RE_AUTONAME_PREFIX.is_match(autoname) {
            // 16 == len("YYYYMMDDTHHMMSS-"); assumes the regex matched at
            // the start of the name — NOTE(review): the pattern is
            // unanchored, confirm auto-named files always start with it.
            session.autoname = Some(AutoName::new(autoname[16..].to_string()));
        }
    } else {
        session.name = name.to_string();
        session.path = Some(path.display().to_string());
    }
    if let Some(role_name) = &session.role_name {
        if let Ok(role) = config.retrieve_role(role_name) {
            session.role_prompt = role.prompt().to_string();
        }
    }
    session.update_tokens();
    Ok(session)
}
/// `true` when there are no live or compressed messages.
pub fn is_empty(&self) -> bool {
    self.messages.is_empty() && self.compressed_messages.is_empty()
}
pub fn name(&self) -> &str {
    &self.name
}
pub fn role_name(&self) -> Option<&str> {
    self.role_name.as_deref()
}
/// Whether there are unsaved changes.
pub fn dirty(&self) -> bool {
    self.dirty
}
pub fn save_session(&self) -> Option<bool> {
    self.save_session
}
/// Cached token count of the live messages.
pub fn tokens(&self) -> usize {
    self.tokens
}
/// Recomputes the cached token count from the live messages.
pub fn update_tokens(&mut self) {
    self.tokens = self.model().total_tokens(&self.messages);
}
pub fn has_user_messages(&self) -> bool {
    self.messages.iter().any(|v| v.role.is_user())
}
pub fn user_messages_len(&self) -> usize {
    self.messages.iter().filter(|v| v.role.is_user()).count()
}
/// Serializes session info (settings, token usage, messages) to YAML for
/// display.
pub fn export(&self) -> Result<String> {
    let mut data = json!({
        "path": self.path,
        "model": self.model().id(),
    });
    if let Some(temperature) = self.temperature() {
        data["temperature"] = temperature.into();
    }
    if let Some(top_p) = self.top_p() {
        data["top_p"] = top_p.into();
    }
    if let Some(use_tools) = self.use_tools() {
        data["use_tools"] = use_tools.into();
    }
    if let Some(use_mcp_servers) = self.use_mcp_servers() {
        data["use_mcp_servers"] = use_mcp_servers.into();
    }
    if let Some(save_session) = self.save_session() {
        data["save_session"] = save_session.into();
    }
    let (tokens, percent) = self.tokens_usage();
    data["total_tokens"] = tokens.into();
    if let Some(max_input_tokens) = self.model().max_input_tokens() {
        data["max_input_tokens"] = max_input_tokens.into();
    }
    if percent != 0.0 {
        data["total/max"] = format!("{percent}%").into();
    }
    data["messages"] = json!(self.messages);
    let output = serde_yaml::to_string(&data)
        .with_context(|| format!("Unable to show info about session '{}'", &self.name))?;
    Ok(output)
}
/// Renders the session header (settings) and the message transcript as
/// markdown for display in the REPL.
pub fn render(
    &self,
    render: &mut MarkdownRender,
    agent_info: &Option<(String, Vec<String>)>,
) -> Result<String> {
    let mut items = vec![];
    if let Some(path) = &self.path {
        items.push(("path", path.to_string()));
    }
    if let Some(autoname) = self.autoname() {
        items.push(("autoname", autoname.to_string()));
    }
    items.push(("model", self.model().id()));
    if let Some(temperature) = self.temperature() {
        items.push(("temperature", temperature.to_string()));
    }
    if let Some(top_p) = self.top_p() {
        items.push(("top_p", top_p.to_string()));
    }
    if let Some(use_tools) = self.use_tools() {
        items.push(("use_tools", use_tools));
    }
    if let Some(use_mcp_servers) = self.use_mcp_servers() {
        items.push(("use_mcp_servers", use_mcp_servers));
    }
    if let Some(save_session) = self.save_session() {
        items.push(("save_session", save_session.to_string()));
    }
    if let Some(compress_threshold) = self.compress_threshold {
        items.push(("compress_threshold", compress_threshold.to_string()));
    }
    if let Some(max_input_tokens) = self.model().max_input_tokens() {
        items.push(("max_input_tokens", max_input_tokens.to_string()));
    }
    let mut lines: Vec<String> = items
        .iter()
        .map(|(name, value)| format!("{name:<20}{value}"))
        .collect();
    lines.push(String::new());
    if !self.is_empty() {
        // Map data-urls back to their original paths for readable output.
        let resolve_url_fn = |url: &str| resolve_data_url(&self.data_urls, url.to_string());
        for message in &self.messages {
            match message.role {
                MessageRole::System => {
                    lines.push(
                        render
                            .render(&message.content.render_input(resolve_url_fn, agent_info)),
                    );
                }
                MessageRole::Assistant => {
                    if let MessageContent::Text(text) = &message.content {
                        lines.push(render.render(text));
                    }
                    lines.push("".into());
                }
                MessageRole::User => {
                    lines.push(format!(
                        ">> {}",
                        message.content.render_input(resolve_url_fn, agent_info)
                    ));
                }
                MessageRole::Tool => {
                    lines.push(message.content.render_input(resolve_url_fn, agent_info));
                }
            }
        }
    }
    Ok(lines.join("\n"))
}
/// Returns `(tokens, percent-of-max)` where the percentage is rounded to
/// two decimals and `0.0` when the model has no input-token limit.
pub fn tokens_usage(&self) -> (usize, f32) {
    let tokens = self.tokens();
    let max_input_tokens = self.model().max_input_tokens().unwrap_or_default();
    let percent = match max_input_tokens {
        0 => 0.0,
        max => ((tokens as f32 / max as f32 * 100.0) * 100.0).round() / 100.0,
    };
    (tokens, percent)
}
/// Adopts a role's settings (model, sampling, tools, prompt) into the session.
pub fn set_role(&mut self, role: Role) {
    self.model_id = role.model().id();
    self.temperature = role.temperature();
    self.top_p = role.top_p();
    self.use_tools = role.use_tools();
    self.use_mcp_servers = role.use_mcp_servers();
    self.model = role.model().clone();
    self.role_name = convert_option_string(role.name());
    self.role_prompt = role.prompt().to_string();
    self.dirty = true;
    self.update_tokens();
}
/// Detaches the session from its role without touching other settings.
pub fn clear_role(&mut self) {
    self.role_name = None;
    self.role_prompt.clear();
}
/// Mirrors an agent's instructions/variables into the session state.
pub fn sync_agent(&mut self, agent: &Agent) {
    self.role_name = None;
    self.role_prompt = agent.interpolated_instructions();
    self.agent_variables = agent.variables().clone();
    self.agent_instructions = self.role_prompt.clone();
}
pub fn agent_variables(&self) -> &AgentVariables {
    &self.agent_variables
}
pub fn agent_instructions(&self) -> &str {
    &self.agent_instructions
}
// The setters below mark the session dirty only when the value changes.
pub fn set_save_session(&mut self, value: Option<bool>) {
    if self.save_session != value {
        self.save_session = value;
        self.dirty = true;
    }
}
/// Forces a save on exit regardless of the `save_session` policy.
pub fn set_save_session_this_time(&mut self) {
    self.save_session_this_time = true;
}
pub fn set_compress_threshold(&mut self, value: Option<usize>) {
    if self.compress_threshold != value {
        self.compress_threshold = value;
        self.dirty = true;
    }
}
/// Whether the history should be compressed: not already compressing, a
/// positive threshold is configured (session override wins), and the token
/// count exceeds it.
pub fn need_compress(&self, global_compress_threshold: usize) -> bool {
    if self.compressing {
        return false;
    }
    let threshold = self.compress_threshold.unwrap_or(global_compress_threshold);
    threshold >= 1 && self.tokens() > threshold
}
/// Whether a compression request is currently in flight.
pub fn compressing(&self) -> bool {
    self.compressing
}
pub fn set_compressing(&mut self, compressing: bool) {
    self.compressing = compressing;
}
/// Moves the live history into `compressed_messages` and replaces it with a
/// single system message (the compression summary), prefixed by the
/// original system prompt when one exists.
pub fn compress(&mut self, mut prompt: String) {
    if let Some(system_prompt) = self.messages.first().and_then(|v| {
        if MessageRole::System == v.role {
            let content = v.content.to_text();
            if !content.is_empty() {
                return Some(content);
            }
        }
        None
    }) {
        prompt = format!("{system_prompt}\n\n{prompt}",);
    }
    self.compressed_messages.append(&mut self.messages);
    self.messages.push(Message::new(
        MessageRole::System,
        MessageContent::Text(prompt),
    ));
    self.dirty = true;
    self.update_tokens();
}
/// Whether an auto-name still needs to be generated.
pub fn need_autoname(&self) -> bool {
    self.autoname.as_ref().map(|v| v.need()).unwrap_or_default()
}
pub fn set_autonaming(&mut self, naming: bool) {
    if let Some(v) = self.autoname.as_mut() {
        v.naming = naming;
    }
}
/// First-exchange transcript captured for the auto-naming prompt.
pub fn chat_history_for_autonaming(&self) -> Option<String> {
    self.autoname.as_ref().and_then(|v| v.chat_history.clone())
}
pub fn autoname(&self) -> Option<&str> {
    self.autoname.as_ref().and_then(|v| v.name.as_deref())
}
/// Stores a generated auto-name, replacing any non-alphanumeric character
/// with `-` so it is safe to use in a filename.
pub fn set_autoname(&mut self, value: &str) {
    let mut name = String::with_capacity(value.len());
    for ch in value.chars() {
        name.push(if ch.is_alphanumeric() { ch } else { '-' });
    }
    self.autoname = Some(AutoName::new(name));
}
/// Persists the session on shutdown according to the `save_session` policy.
///
/// `None` → ask (REPL only, prompting for a name when temp); `Some(true)`
/// on a temp session → auto-save under `_/` with a timestamped (and
/// possibly auto-generated) name; `Some(false)` → never save. Only dirty
/// sessions are written.
pub fn exit(&mut self, session_dir: &Path, is_repl: bool) -> Result<()> {
    let mut save_session = self.save_session();
    if self.save_session_this_time {
        save_session = Some(true);
    }
    if self.dirty && save_session != Some(false) {
        let mut session_dir = session_dir.to_path_buf();
        let mut session_name = self.name().to_string();
        if save_session.is_none() {
            // No policy set: interactively confirm (REPL only).
            if !is_repl {
                return Ok(());
            }
            let ans = Confirm::new("Save session?").with_default(false).prompt()?;
            if !ans {
                return Ok(());
            }
            if session_name == TEMP_SESSION_NAME {
                session_name = Text::new("Session name:")
                    .with_validator(|input: &str| {
                        let input = input.trim();
                        if input.is_empty() {
                            Ok(Validation::Invalid("This name is required".into()))
                        } else if input == TEMP_SESSION_NAME {
                            Ok(Validation::Invalid("This name is reserved".into()))
                        } else {
                            Ok(Validation::Valid)
                        }
                    })
                    .prompt()?;
            }
        } else if save_session == Some(true) && session_name == TEMP_SESSION_NAME {
            // Auto-save temp sessions under `_/` with a timestamped name.
            session_dir = session_dir.join("_");
            ensure_parent_exists(&session_dir).with_context(|| {
                format!("Failed to create directory '{}'", session_dir.display())
            })?;
            let now = chrono::Local::now();
            session_name = now.format("%Y%m%dT%H%M%S").to_string();
            if let Some(autoname) = self.autoname() {
                session_name = format!("{session_name}-{autoname}")
            }
        }
        let session_path = session_dir.join(format!("{session_name}.yaml"));
        self.save(&session_name, &session_path, is_repl)?;
    }
    Ok(())
}
/// Serializes the session to YAML at `session_path`, adopts `session_name`,
/// and clears the dirty flag.
pub fn save(&mut self, session_name: &str, session_path: &Path, is_repl: bool) -> Result<()> {
    ensure_parent_exists(session_path)?;
    self.path = Some(session_path.display().to_string());
    let content = serde_yaml::to_string(&self)
        .with_context(|| format!("Failed to serde session '{}'", self.name))?;
    write(session_path, content).with_context(|| {
        format!(
            "Failed to write session '{}' to '{}'",
            self.name,
            session_path.display()
        )
    })?;
    if is_repl {
        println!("✓ Saved the session to '{}'.", session_path.display());
    }
    if self.name() != session_name {
        self.name = session_name.to_string()
    }
    self.dirty = false;
    Ok(())
}
/// Errors unless the session has no messages (live or compressed).
pub fn guard_empty(&self) -> Result<()> {
    if self.is_empty() {
        Ok(())
    } else {
        bail!("Cannot perform this operation because the session has messages, please `.empty session` first.")
    }
}
/// Records one completed exchange.
///
/// Continuations append to the last assistant message; regenerations
/// replace it; otherwise the user message (plus optional tool-call message)
/// and the assistant reply are appended. The very first exchange of a
/// temp session may also seed auto-naming.
pub fn add_message(&mut self, input: &Input, output: &str) -> Result<()> {
    if input.continue_output().is_some() {
        // Continuation: extend the last assistant message in place.
        if let Some(message) = self.messages.last_mut() {
            if let MessageContent::Text(text) = &mut message.content {
                *text = format!("{text}{output}");
            }
        }
    } else if input.regenerate() {
        // Regeneration: overwrite the last assistant message.
        if let Some(message) = self.messages.last_mut() {
            if let MessageContent::Text(text) = &mut message.content {
                *text = output.to_string();
            }
        }
    } else {
        if self.messages.is_empty() {
            if self.name == TEMP_SESSION_NAME && self.save_session == Some(true) {
                let raw_input = input.raw();
                let chat_history = format!("USER: {raw_input}\nASSISTANT: {output}\n");
                self.autoname = Some(AutoName::new_from_chat_history(chat_history));
            }
            // First message: include the role's system/few-shot scaffolding.
            self.messages.extend(input.role().build_messages(input));
        } else {
            self.messages
                .push(Message::new(MessageRole::User, input.message_content()));
        }
        self.data_urls.extend(input.data_urls());
        if let Some(tool_calls) = input.tool_calls() {
            self.messages.push(Message::new(
                MessageRole::Tool,
                MessageContent::ToolCalls(tool_calls.clone()),
            ))
        }
        self.messages.push(Message::new(
            MessageRole::Assistant,
            MessageContent::Text(output.to_string()),
        ));
    }
    self.dirty = true;
    self.update_tokens();
    Ok(())
}
/// Wipes the whole history: live and compressed messages, media maps,
/// and any auto-naming state.
pub fn clear_messages(&mut self) {
    self.messages.clear();
    self.compressed_messages.clear();
    self.data_urls.clear();
    self.autoname = None;
    self.dirty = true;
    self.update_tokens();
}
/// Serializes the would-be outgoing messages to YAML for echoing.
pub fn echo_messages(&self, input: &Input) -> String {
    serde_yaml::to_string(&self.build_messages(input))
        .unwrap_or_else(|_| "Unable to echo message".into())
}
/// Builds the outgoing message list for a new request.
///
/// Continuations reuse the history as-is; regenerations drop trailing
/// non-user messages; an empty history is seeded from the role. Right
/// after compression, the latest compressed user turn is replayed so the
/// model keeps some conversational context.
pub fn build_messages(&self, input: &Input) -> Vec<Message> {
    let mut messages = self.messages.clone();
    if input.continue_output().is_some() {
        return messages;
    } else if input.regenerate() {
        // Strip assistant/tool messages so the last user turn is re-answered.
        while let Some(last) = messages.last() {
            if !last.role.is_user() {
                messages.pop();
            } else {
                break;
            }
        }
        return messages;
    }
    let mut need_add_msg = true;
    let len = messages.len();
    if len == 0 {
        messages = input.role().build_messages(input);
        need_add_msg = false;
    } else if len == 1 && self.compressed_messages.len() >= 2 {
        // Only the compression summary remains: replay the most recent
        // compressed user exchange.
        if let Some(index) = self
            .compressed_messages
            .iter()
            .rposition(|v| v.role == MessageRole::User)
        {
            messages.extend(self.compressed_messages[index..].to_vec());
        }
    }
    if need_add_msg {
        messages.push(Message::new(MessageRole::User, input.message_content()));
    }
    messages
}
}
/// `RoleLike` view over a session: exposes the session's model and sampling
/// settings; each setter marks the session dirty only when a value changes.
impl RoleLike for Session {
    fn to_role(&self) -> Role {
        // Materialize a role carrying this session's current settings.
        let mut role = Role::new(
            self.role_name.as_deref().unwrap_or_default(),
            &self.role_prompt,
        );
        role.sync(self);
        role
    }
    fn model(&self) -> &Model {
        &self.model
    }
    fn temperature(&self) -> Option<f64> {
        self.temperature
    }
    fn top_p(&self) -> Option<f64> {
        self.top_p
    }
    fn use_tools(&self) -> Option<String> {
        self.use_tools.clone()
    }
    fn use_mcp_servers(&self) -> Option<String> {
        self.use_mcp_servers.clone()
    }
    fn set_model(&mut self, model: Model) {
        if self.model().id() == model.id() {
            return;
        }
        self.model_id = model.id();
        self.model = model;
        self.dirty = true;
        // Switching models invalidates the token count.
        self.update_tokens();
    }
    fn set_temperature(&mut self, value: Option<f64>) {
        if self.temperature == value {
            return;
        }
        self.temperature = value;
        self.dirty = true;
    }
    fn set_top_p(&mut self, value: Option<f64>) {
        if self.top_p == value {
            return;
        }
        self.top_p = value;
        self.dirty = true;
    }
    fn set_use_tools(&mut self, value: Option<String>) {
        if self.use_tools == value {
            return;
        }
        self.use_tools = value;
        self.dirty = true;
    }
    fn set_use_mcp_servers(&mut self, value: Option<String>) {
        if self.use_mcp_servers == value {
            return;
        }
        self.use_mcp_servers = value;
        self.dirty = true;
    }
}
/// Tracks the pending auto-generated name of a session.
#[derive(Debug, Clone, Default)]
struct AutoName {
    // True while a naming request is in flight, to avoid duplicate requests.
    naming: bool,
    // Transcript used as the naming prompt; present until a name is produced.
    chat_history: Option<String>,
    // The resolved session name, once generated.
    name: Option<String>,
}
impl AutoName {
    /// Creates an entry that already has its final name (no naming needed).
    pub fn new(name: String) -> Self {
        Self {
            name: Some(name),
            ..Default::default()
        }
    }
    /// Creates an entry that still needs a name, seeded with the transcript
    /// to summarize.
    pub fn new_from_chat_history(chat_history: String) -> Self {
        Self {
            chat_history: Some(chat_history),
            ..Default::default()
        }
    }
    /// Whether a naming request should be issued: a transcript exists, no
    /// name has been produced, and no request is already running.
    pub fn need(&self) -> bool {
        !self.naming && self.chat_history.is_some() && self.name.is_none()
    }
}
+825
View File
@@ -0,0 +1,825 @@
use crate::{
config::{Agent, Config, GlobalConfig},
utils::*,
};
use crate::mcp::{MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX};
use crate::parsers::{bash, python};
use anyhow::{anyhow, bail, Context, Result};
use indexmap::IndexMap;
use indoc::formatdoc;
use rust_embed::Embed;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::ffi::OsStr;
use std::fs::File;
use std::io::Write;
use std::{
collections::{HashMap, HashSet},
env, fs, io,
path::{Path, PathBuf},
};
use strum_macros::AsRefStr;
// Embedded copies of the runner-script templates under assets/functions/.
#[derive(Embed)]
#[folder = "assets/functions/"]
struct FunctionAsset;
// Platform-specific separator used when prepending directories to PATH.
#[cfg(windows)]
const PATH_SEP: &str = ";";
#[cfg(not(windows))]
const PATH_SEP: &str = ":";
// Which kind of runnable wrapper is being generated.
#[derive(AsRefStr)]
enum BinaryType {
    Tool,
    Agent,
}
// Scripting language of a tool/agent source file, derived from its extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, AsRefStr)]
enum Language {
    Bash,
    Python,
    Javascript,
    Unsupported,
}
impl From<&String> for Language {
fn from(s: &String) -> Self {
match s.to_lowercase().as_str() {
"sh" => Language::Bash,
"py" => Language::Python,
"js" => Language::Javascript,
_ => Language::Unsupported,
}
}
}
#[cfg_attr(not(windows), expect(dead_code))]
impl Language {
fn to_cmd(self) -> &'static str {
match self {
Language::Bash => "bash",
Language::Python => "python",
Language::Javascript => "node",
Language::Unsupported => "sh",
}
}
fn to_extension(self) -> &'static str {
match self {
Language::Bash => "sh",
Language::Python => "py",
Language::Javascript => "js",
_ => "sh",
}
}
}
/// Evaluates a batch of tool calls sequentially and collects their results.
///
/// Duplicate call ids are removed first; if dedup strips everything, the
/// model is looping and the request is aborted. Null results are replaced
/// with "DONE" so every call gets a response, but a batch that produced only
/// nulls yields an empty result list.
pub async fn eval_tool_calls(
    config: &GlobalConfig,
    mut calls: Vec<ToolCall>,
) -> Result<Vec<ToolResult>> {
    if calls.is_empty() {
        return Ok(vec![]);
    }
    calls = ToolCall::dedup(calls);
    if calls.is_empty() {
        bail!("The request was aborted because an infinite loop of function calls was detected.")
    }
    let mut results = vec![];
    let mut saw_real_output = false;
    for call in calls {
        let mut value = call.eval(config).await?;
        if value.is_null() {
            value = json!("DONE");
        } else {
            saw_real_output = true;
        }
        results.push(ToolResult::new(call, value));
    }
    if !saw_real_output {
        results.clear();
    }
    Ok(results)
}
/// A tool call paired with the value it produced.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolResult {
    pub call: ToolCall,
    pub output: Value,
}
impl ToolResult {
    /// Bundles a call with its output value.
    pub fn new(call: ToolCall, output: Value) -> Self {
        Self { call, output }
    }
}
/// The set of function declarations available to the current chat context.
#[derive(Debug, Clone, Default)]
pub struct Functions {
    declarations: Vec<FunctionDeclaration>,
}
impl Functions {
    /// Loads the global tool declarations and rebuilds their runner binaries.
    ///
    /// Reads the enabled-tools list from the global tools file, parses each
    /// listed tool's declarations, and regenerates the wrapper binaries under
    /// the functions bin directory.
    pub fn init() -> Result<Self> {
        info!(
            "Initializing global functions from {}",
            Config::global_tools_file().display()
        );
        let declarations = Self {
            declarations: Self::build_global_tool_declarations_from_path(
                &Config::global_tools_file(),
            )?,
        };
        info!(
            "Building global function binaries in {}",
            Config::functions_bin_dir().display()
        );
        Self::build_global_function_binaries_from_path(Config::global_tools_file())?;
        Ok(declarations)
    }
    /// Loads the functions available to agent `name`.
    ///
    /// Combines two sources: the subset of global tools the agent enables
    /// (`global_tools`), and the agent's own functions script when one exists.
    /// Runner binaries are (re)built for both sources.
    pub fn init_agent(name: &str, global_tools: &[String]) -> Result<Self> {
        let global_tools_declarations = if !global_tools.is_empty() {
            // The declaration/binary builders consume a newline-separated list.
            let enabled_tools = global_tools.join("\n");
            info!("Loading global tools for agent: {name}: {enabled_tools}");
            let tools_declarations = Self::build_global_tool_declarations(&enabled_tools)?;
            info!(
                "Building global function binaries required by agent: {name} in {}",
                Config::functions_bin_dir().display()
            );
            Self::build_global_function_binaries(&enabled_tools)?;
            tools_declarations
        } else {
            debug!("No global tools found for agent: {}", name);
            Vec::new()
        };
        // The agent's own functions script is optional.
        let agent_script_declarations = match Config::agent_functions_file(name) {
            Ok(path) if path.exists() => {
                info!(
                    "Loading functions script for agent: {name} from {}",
                    path.display()
                );
                let script_declarations = Self::generate_declarations(&path)?;
                debug!("agent_declarations: {:#?}", script_declarations);
                info!(
                    "Building function binary for agent: {name} in {}",
                    Config::agent_bin_dir(name).display()
                );
                Self::build_agent_tool_binaries(name)?;
                script_declarations
            }
            _ => {
                debug!("No functions script found for agent: {}", name);
                Vec::new()
            }
        };
        let declarations = [global_tools_declarations, agent_script_declarations].concat();
        Ok(Self { declarations })
    }
pub fn find(&self, name: &str) -> Option<&FunctionDeclaration> {
self.declarations.iter().find(|v| v.name == name)
}
pub fn contains(&self, name: &str) -> bool {
self.declarations.iter().any(|v| v.name == name)
}
pub fn declarations(&self) -> &[FunctionDeclaration] {
&self.declarations
}
pub fn is_empty(&self) -> bool {
self.declarations.is_empty()
}
pub fn has_mcp_functions(&self) -> bool {
self.declarations.iter().any(|d| {
d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX)
|| d.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX)
})
}
pub fn clear_mcp_meta_functions(&mut self) {
self.declarations.retain(|d| {
!d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX)
&& !d.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX)
});
}
pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec<String>) {
let mut invoke_function_properties = IndexMap::new();
invoke_function_properties.insert(
"server".to_string(),
JsonSchema {
type_value: Some("string".to_string()),
..Default::default()
},
);
invoke_function_properties.insert(
"tool".to_string(),
JsonSchema {
type_value: Some("string".to_string()),
..Default::default()
},
);
invoke_function_properties.insert(
"arguments".to_string(),
JsonSchema {
type_value: Some("object".to_string()),
..Default::default()
},
);
for server in mcp_servers {
let invoke_function_name = format!("{}_{server}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX);
let invoke_function_declaration = FunctionDeclaration {
name: invoke_function_name.clone(),
description: formatdoc!(
r#"
Invoke the specified tool on the {server} MCP server. Always call {invoke_function_name} first to find the
correct names of tools before calling '{invoke_function_name}'.
"#
),
parameters: JsonSchema {
type_value: Some("object".to_string()),
properties: Some(invoke_function_properties.clone()),
required: Some(vec!["server".to_string(), "tool".to_string()]),
..Default::default()
},
agent: false,
};
let list_functions_declaration = FunctionDeclaration {
name: format!("{}_{}", MCP_LIST_META_FUNCTION_NAME_PREFIX, server),
description: format!("List all the available tools for the {server} MCP server"),
parameters: JsonSchema::default(),
agent: false,
};
self.declarations.push(invoke_function_declaration);
self.declarations.push(list_functions_declaration);
}
}
fn build_global_tool_declarations(enabled_tools: &str) -> Result<Vec<FunctionDeclaration>> {
let global_tools_directory = Config::global_tools_dir();
let mut function_declarations = Vec::new();
for line in enabled_tools.lines() {
if line.starts_with('#') {
continue;
}
let declaration = Self::generate_declarations(&global_tools_directory.join(line))?;
function_declarations.extend(declaration);
}
Ok(function_declarations)
}
fn build_global_tool_declarations_from_path(
tools_txt_path: &PathBuf,
) -> Result<Vec<FunctionDeclaration>> {
let enabled_tools = fs::read_to_string(tools_txt_path)
.with_context(|| format!("failed to load functions at {}", tools_txt_path.display()))?;
Self::build_global_tool_declarations(&enabled_tools)
}
    /// Parses a tool script file into function declarations.
    ///
    /// Dispatches to the bash or python declaration parser based on the file
    /// extension; JavaScript is recognized as a language but has no parser
    /// yet and is rejected alongside unknown extensions.
    fn generate_declarations(tools_file_path: &Path) -> Result<Vec<FunctionDeclaration>> {
        info!(
            "Loading tool definitions from {}",
            tools_file_path.display()
        );
        let file_name = tools_file_path
            .file_stem()
            .and_then(|s| s.to_str())
            .ok_or_else(|| {
                anyhow::format_err!("Unable to extract file name from path: {tools_file_path:?}")
            })?;
        match File::open(tools_file_path) {
            Ok(tool_file) => {
                // Determine the language from the (lowercased) file extension.
                let language = Language::from(
                    &tools_file_path
                        .extension()
                        .and_then(OsStr::to_str)
                        .map(|s| s.to_lowercase())
                        .ok_or_else(|| {
                            anyhow!("Unable to extract language from tool file: {file_name}")
                        })?,
                );
                match language {
                    Language::Bash => {
                        bash::generate_bash_declarations(tool_file, tools_file_path, file_name)
                    }
                    Language::Python => python::generate_python_declarations(
                        tool_file,
                        file_name,
                        tools_file_path.parent(),
                    ),
                    Language::Unsupported => {
                        bail!("Unsupported tool file extension: {}", language.as_ref())
                    }
                    // Currently only Javascript: recognized but not parseable.
                    _ => bail!("Unsupported tool language: {}", language.as_ref()),
                }
            }
            Err(err) if err.kind() == io::ErrorKind::NotFound => {
                bail!(
                    "Tool definition file not found: {}",
                    tools_file_path.display()
                );
            }
            Err(err) => bail!("Unable to open tool definition file. {}", err),
        }
    }
fn build_global_function_binaries(enabled_tools: &str) -> Result<()> {
let bin_dir = Config::functions_bin_dir();
if !bin_dir.exists() {
fs::create_dir_all(&bin_dir)?;
}
info!(
"Clearing existing function binaries in {}",
bin_dir.display()
);
clear_dir(&bin_dir)?;
for line in enabled_tools.lines() {
if line.starts_with('#') {
continue;
}
let language = Language::from(
&Path::new(line)
.extension()
.and_then(OsStr::to_str)
.map(|s| s.to_lowercase())
.ok_or_else(|| {
anyhow::format_err!("Unable to extract file extension from path: {line:?}")
})?,
);
let binary_name = Path::new(line)
.file_stem()
.and_then(OsStr::to_str)
.ok_or_else(|| {
anyhow::format_err!("Unable to extract file name from path: {line:?}")
})?;
if language == Language::Unsupported {
bail!("Unsupported tool file extension: {}", language.as_ref());
}
Self::build_binaries(binary_name, language, BinaryType::Tool)?;
}
Ok(())
}
    /// Reads the enabled-tools file and builds a runner binary for each tool
    /// it lists.
    fn build_global_function_binaries_from_path(tools_txt_path: PathBuf) -> Result<()> {
        let enabled_tools = fs::read_to_string(&tools_txt_path)
            .with_context(|| format!("failed to load functions at {}", tools_txt_path.display()))?;
        Self::build_global_function_binaries(&enabled_tools)
    }
    /// Rebuilds the runner binary for agent `name`'s functions script.
    ///
    /// The agent's bin directory is created (or cleared), the language is
    /// derived from the functions-script extension, and the wrapper binary is
    /// generated.
    fn build_agent_tool_binaries(name: &str) -> Result<()> {
        let agent_bin_directory = Config::agent_bin_dir(name);
        if !agent_bin_directory.exists() {
            debug!(
                "Creating agent bin directory: {}",
                agent_bin_directory.display()
            );
            fs::create_dir_all(&agent_bin_directory)?;
        } else {
            debug!(
                "Clearing existing agent bin directory: {}",
                agent_bin_directory.display()
            );
            clear_dir(&agent_bin_directory)?;
        }
        let language = Language::from(
            &Config::agent_functions_file(name)?
                .extension()
                .and_then(OsStr::to_str)
                .map(|s| s.to_lowercase())
                .ok_or_else(|| {
                    anyhow::format_err!("Unable to extract file extension from path: {name:?}")
                })?,
        );
        if language == Language::Unsupported {
            bail!("Unsupported tool file extension: {}", language.as_ref());
        }
        Self::build_binaries(name, language, BinaryType::Agent)
    }
#[cfg(windows)]
fn build_binaries(
binary_name: &str,
language: Language,
binary_type: BinaryType,
) -> Result<()> {
use native::runtime;
let (binary_file, binary_script_file) = match binary_type {
BinaryType::Tool => (
Config::functions_bin_dir().join(format!("{binary_name}.cmd")),
Config::functions_bin_dir()
.join(format!("run-{binary_name}.{}", language.to_extension())),
),
BinaryType::Agent => (
Config::agent_bin_dir(binary_name).join(format!("{binary_name}.cmd")),
Config::agent_bin_dir(binary_name)
.join(format!("run-{binary_name}.{}", language.to_extension())),
),
};
info!(
"Building binary runner for function: {} ({})",
binary_name,
binary_script_file.display(),
);
let embedded_file = FunctionAsset::get(&format!(
"scripts/run-{}.{}",
binary_type.as_ref().to_lowercase(),
language.to_extension()
))
.ok_or_else(|| {
anyhow!(
"Failed to load embedded script for run-{}.{}",
binary_type.as_ref().to_lowercase(),
language.to_extension()
)
})?;
let content_template = unsafe { std::str::from_utf8_unchecked(&embedded_file.data) };
let content = match binary_type {
BinaryType::Tool => content_template.replace("{function_name}", binary_name),
BinaryType::Agent => content_template.replace("{agent_name}", binary_name),
}
.replace("{config_dir}", &Config::config_dir().to_string_lossy());
if binary_script_file.exists() {
fs::remove_file(&binary_script_file)?;
}
let mut script_file = File::create(&binary_script_file)?;
script_file.write_all(content.as_bytes())?;
info!(
"Building binary for function: {} ({})",
binary_name,
binary_file.display()
);
let run = match language {
Language::Bash => {
let shell = runtime::bash_path().ok_or_else(|| anyhow!("Shell not found"))?;
format!("{shell} --noprofile --norc")
}
Language::Python if Path::new(".venv").exists() => {
let executable_path = env::current_dir()?
.join(".venv")
.join("Scripts")
.join("activate.bat");
let canonicalized_path = fs::canonicalize(&executable_path)?;
format!(
"call \"{}\" && {}",
canonicalized_path.to_string_lossy(),
language.to_cmd()
)
}
Language::Javascript => runtime::which(language.to_cmd())
.ok_or_else(|| anyhow!("Unable to find {} in PATH", language.to_cmd()))?,
_ => bail!("Unsupported language: {}", language.as_ref()),
};
let bin_dir = binary_file
.parent()
.expect("Failed to get parent directory of binary file")
.canonicalize()?
.to_string_lossy()
.into_owned();
let wrapper_binary = binary_script_file
.canonicalize()?
.to_string_lossy()
.into_owned();
let content = formatdoc!(
r#"
@echo off
setlocal
set "bin_dir={bin_dir}"
{run} "{wrapper_binary}" %*"#,
);
let mut file = File::create(&binary_file)?;
file.write_all(content.as_bytes())?;
Ok(())
}
#[cfg(not(windows))]
fn build_binaries(
binary_name: &str,
language: Language,
binary_type: BinaryType,
) -> Result<()> {
use std::os::unix::prelude::PermissionsExt;
let binary_file = match binary_type {
BinaryType::Tool => Config::functions_bin_dir().join(binary_name),
BinaryType::Agent => Config::agent_bin_dir(binary_name).join(binary_name),
};
info!(
"Building binary for function: {} ({})",
binary_name,
binary_file.display()
);
let embedded_file = FunctionAsset::get(&format!(
"scripts/run-{}.{}",
binary_type.as_ref().to_lowercase(),
language.to_extension()
))
.ok_or_else(|| {
anyhow!(
"Failed to load embedded script for run-{}.{}",
binary_type.as_ref().to_lowercase(),
language.to_extension()
)
})?;
let content_template = unsafe { std::str::from_utf8_unchecked(&embedded_file.data) };
let content = match binary_type {
BinaryType::Tool => content_template.replace("{function_name}", binary_name),
BinaryType::Agent => content_template.replace("{agent_name}", binary_name),
}
.replace("{config_dir}", &Config::config_dir().to_string_lossy());
if binary_file.exists() {
fs::remove_file(&binary_file)?;
}
let mut file = File::create(&binary_file)?;
file.write_all(content.as_bytes())?;
fs::set_permissions(&binary_file, fs::Permissions::from_mode(0o755))?;
Ok(())
}
}
/// A callable function exposed to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: JsonSchema,
    // True when the function belongs to an agent (run through the agent
    // binary); not serialized into the model-facing declaration.
    #[serde(skip_serializing, default)]
    pub agent: bool,
}
/// A (subset of) JSON Schema used to describe function parameters.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct JsonSchema {
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub type_value: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<IndexMap<String, JsonSchema>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<JsonSchema>>,
    #[serde(rename = "anyOf", skip_serializing_if = "Option::is_none")]
    pub any_of: Option<Vec<JsonSchema>>,
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enum_value: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
impl JsonSchema {
    /// True when the schema declares no properties at all — `properties` is
    /// either absent or an empty map.
    pub fn is_empty_properties(&self) -> bool {
        self.properties.as_ref().map_or(true, |v| v.is_empty())
    }
}
/// A function/tool call requested by the model.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ToolCall {
    // Function name to invoke.
    pub name: String,
    // JSON arguments: an object, or a string holding JSON-encoded object.
    pub arguments: Value,
    // Provider-assigned call id, when present.
    pub id: Option<String>,
}
// (call display name, command name, command args, extra environment vars)
type CallConfig = (String, String, Vec<String>, HashMap<String, String>);
impl ToolCall {
pub fn dedup(calls: Vec<Self>) -> Vec<Self> {
let mut new_calls = vec![];
let mut seen_ids = HashSet::new();
for call in calls.into_iter().rev() {
if let Some(id) = &call.id {
if !seen_ids.contains(id) {
seen_ids.insert(id.clone());
new_calls.push(call);
}
} else {
new_calls.push(call);
}
}
new_calls.reverse();
new_calls
}
    /// Creates a tool call from its name, raw arguments, and optional id.
    pub fn new(name: String, arguments: Value, id: Option<String>) -> Self {
        Self {
            name,
            arguments,
            id,
        }
    }
    /// Executes this tool call and returns its JSON result.
    ///
    /// The call is resolved to a command (agent-aware), its arguments are
    /// normalized into a JSON object, and then one of three paths runs:
    /// an MCP "list tools" meta-call, an MCP "invoke tool" meta-call, or a
    /// local function binary. A null value is returned when a local function
    /// produced no output.
    pub async fn eval(&self, config: &GlobalConfig) -> Result<Value> {
        let (call_name, cmd_name, mut cmd_args, envs) = match &config.read().agent {
            Some(agent) => self.extract_call_config_from_agent(config, agent)?,
            None => self.extract_call_config_from_config(config)?,
        };
        // Normalize arguments: accept either a JSON object or a string
        // containing JSON; anything else is rejected.
        let json_data = if self.arguments.is_object() {
            self.arguments.clone()
        } else if let Some(arguments) = self.arguments.as_str() {
            let arguments: Value = serde_json::from_str(arguments).map_err(|_| {
                anyhow!("The call '{call_name}' has invalid arguments: {arguments}")
            })?;
            arguments
        } else {
            bail!(
                "The call '{call_name}' has invalid arguments: {}",
                self.arguments
            );
        };
        // The serialized arguments are always the final command argument.
        cmd_args.push(json_data.to_string());
        let prompt = format!("Call {cmd_name} {}", cmd_args.join(" "));
        if *IS_STDOUT_TERMINAL {
            println!("{}", dimmed_text(&prompt));
        }
        let output = match cmd_name.as_str() {
            // MCP meta-call: list the tools of a server.
            _ if cmd_name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) => {
                // Clone the registry handle inside a short read-lock scope so
                // the lock is not held across the await.
                let registry_arc = {
                    let cfg = config.read();
                    cfg.mcp_registry
                        .clone()
                        .with_context(|| "MCP is not configured")?
                };
                registry_arc.catalog().await?
            }
            // MCP meta-call: invoke a named tool on a server.
            _ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => {
                let server = json_data
                    .get("server")
                    .ok_or_else(|| anyhow!("Missing 'server' in arguments"))?
                    .as_str()
                    .ok_or_else(|| anyhow!("Invalid 'server' in arguments"))?;
                let tool = json_data
                    .get("tool")
                    .ok_or_else(|| anyhow!("Missing 'tool' in arguments"))?
                    .as_str()
                    .ok_or_else(|| anyhow!("Invalid 'tool' in arguments"))?;
                // 'arguments' is optional; default to an empty object.
                let arguments = json_data
                    .get("arguments")
                    .cloned()
                    .unwrap_or_else(|| json!({}));
                let registry_arc = {
                    let cfg = config.read();
                    cfg.mcp_registry
                        .clone()
                        .with_context(|| "MCP is not configured")?
                };
                let result = registry_arc.invoke(server, tool, arguments).await?;
                serde_json::to_value(result)?
            }
            // Local function binary: parse its output as JSON when possible,
            // otherwise wrap the raw text.
            _ => match run_llm_function(cmd_name, cmd_args, envs)? {
                Some(contents) => serde_json::from_str(&contents)
                    .ok()
                    .unwrap_or_else(|| json!({"output": contents})),
                None => Value::Null,
            },
        };
        Ok(output)
    }
fn extract_call_config_from_agent(
&self,
config: &GlobalConfig,
agent: &Agent,
) -> Result<CallConfig> {
let function_name = self.name.clone();
match agent.functions().find(&function_name) {
Some(function) => {
let agent_name = agent.name().to_string();
if function.agent {
Ok((
format!("{agent_name}-{function_name}"),
agent_name,
vec![function_name],
agent.variable_envs(),
))
} else {
Ok((
function_name.clone(),
function_name,
vec![],
Default::default(),
))
}
}
None => self.extract_call_config_from_config(config),
}
}
fn extract_call_config_from_config(&self, config: &GlobalConfig) -> Result<CallConfig> {
let function_name = self.name.clone();
match config.read().functions.contains(&function_name) {
true => Ok((
function_name.clone(),
function_name,
vec![],
Default::default(),
)),
false => bail!("Unexpected call: {function_name} {}", self.arguments),
}
}
}
/// Runs a function binary and returns the contents it wrote to `LLM_OUTPUT`.
///
/// The relevant bin directories are prepended to PATH, a temp file path is
/// exported via `LLM_OUTPUT` for the tool to write its result to, and the
/// command is executed. Returns `None` when the tool wrote nothing.
///
/// # Errors
/// Fails when PATH is unset, the command cannot be run, it exits non-zero,
/// or the output file cannot be read.
pub fn run_llm_function(
    cmd_name: String,
    cmd_args: Vec<String>,
    mut envs: HashMap<String, String>,
) -> Result<Option<String>> {
    let mut bin_dirs: Vec<PathBuf> = vec![];
    // More than one arg means an agent call (function name + JSON payload),
    // so prefer the agent's own bin directory — TODO confirm against callers.
    if cmd_args.len() > 1 {
        let dir = Config::agent_bin_dir(&cmd_name);
        if dir.exists() {
            bin_dirs.push(dir);
        }
    }
    bin_dirs.push(Config::functions_bin_dir());
    let current_path = env::var("PATH").context("No PATH environment variable")?;
    let prepend_path = bin_dirs
        .iter()
        .map(|v| format!("{}{PATH_SEP}", v.display()))
        .collect::<Vec<_>>()
        .join("");
    envs.insert("PATH".into(), format!("{prepend_path}{current_path}"));
    // The tool communicates its result by writing to this file.
    // NOTE(review): the temp file does not appear to be deleted afterwards.
    let temp_file = temp_file("-eval-", "");
    envs.insert("LLM_OUTPUT".into(), temp_file.display().to_string());
    #[cfg(windows)]
    let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);
    let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
        .map_err(|err| anyhow!("Unable to run {cmd_name}, {err}"))?;
    if exit_code != 0 {
        bail!("Tool call exited with {exit_code}");
    }
    let mut output = None;
    if temp_file.exists() {
        let contents =
            fs::read_to_string(temp_file).context("Failed to retrieve tool call output")?;
        if !contents.is_empty() {
            debug!("Tool {cmd_name} output: {}", contents);
            output = Some(contents);
        }
    };
    Ok(output)
}
/// Resolves a bare command name to an executable file name on Windows.
///
/// Windows cannot run extension-less scripts, so each extension in `PATHEXT`
/// is tried against every candidate bin directory; the first existing match
/// is returned, otherwise the original name unchanged.
#[cfg(windows)]
fn polyfill_cmd_name<T: AsRef<Path>>(cmd_name: &str, bin_dir: &[T]) -> String {
    if let Ok(exts) = env::var("PATHEXT") {
        for name in exts.split(';').map(|ext| format!("{cmd_name}{ext}")) {
            for dir in bin_dir {
                if dir.as_ref().join(&name).exists() {
                    // `name` is already owned; the previous `name.to_string()`
                    // performed a redundant clone.
                    return name;
                }
            }
        }
    }
    cmd_name.to_string()
}
+496
View File
@@ -0,0 +1,496 @@
mod cli;
mod client;
mod config;
mod function;
mod rag;
mod render;
mod repl;
mod serve;
#[macro_use]
mod utils;
mod mcp;
mod parsers;
#[macro_use]
extern crate log;
use crate::cli::Cli;
use crate::client::{
call_chat_completions, call_chat_completions_streaming, list_models, ModelType,
};
use crate::config::{
ensure_parent_exists, list_agents, load_env_file, macro_execute, Agent, Config, GlobalConfig,
Input, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, TEMP_SESSION_NAME,
};
use crate::render::render_error;
use crate::repl::Repl;
use crate::utils::*;
use anyhow::{bail, Result};
use clap::{CommandFactory, Parser};
use clap_complete::CompleteEnv;
use inquire::Text;
use log::LevelFilter;
use log4rs::append::console::ConsoleAppender;
use log4rs::append::file::FileAppender;
use log4rs::config::{Appender, Logger, Root};
use log4rs::encode::pattern::PatternEncoder;
use parking_lot::RwLock;
use std::path::PathBuf;
use std::{env, mem, process, sync::Arc};
/// Program entry point: loads the environment, parses the CLI, initializes
/// config/logging, and dispatches to `run`.
#[tokio::main]
async fn main() -> Result<()> {
    load_env_file()?;
    // Support dynamic shell completions before normal argument parsing.
    CompleteEnv::with_factory(Cli::command).complete();
    let cli = Cli::parse();
    // `--tail-logs` short-circuits everything else.
    if cli.tail_logs {
        tail_logs(cli.disable_log_colors).await;
        return Ok(());
    }
    let text = cli.text()?;
    // Choose the working mode: serving, one-shot command, or interactive REPL.
    let working_mode = if cli.serve.is_some() {
        WorkingMode::Serve
    } else if text.is_none() && cli.file.is_empty() {
        WorkingMode::Repl
    } else {
        WorkingMode::Cmd
    };
    // Pure informational flags: config init can skip heavyweight setup.
    let info_flag = cli.info
        || cli.sync_models
        || cli.list_models
        || cli.list_roles
        || cli.list_agents
        || cli.list_rags
        || cli.list_macros
        || cli.list_sessions;
    let log_path = setup_logger(working_mode.is_serve())?;
    let abort_signal = create_abort_signal();
    // Agents and roles manage their own MCP servers, so only auto-start
    // the configured ones when neither is requested.
    let start_mcp_servers = cli.agent.is_none() && cli.role.is_none();
    let config = Arc::new(RwLock::new(
        Config::init(
            working_mode,
            info_flag,
            start_mcp_servers,
            log_path,
            abort_signal.clone(),
        )
        .await?,
    ));
    if let Err(err) = run(config, cli, text, abort_signal).await {
        render_error(err);
        process::exit(1);
    }
    Ok(())
}
/// Applies all CLI flags in order and dispatches to the selected mode:
/// an informational listing, the HTTP server, a macro, one-shot shell/code
/// execution, a single directive, or the interactive REPL.
async fn run(
    config: GlobalConfig,
    cli: Cli,
    text: Option<String>,
    abort_signal: AbortSignal,
) -> Result<()> {
    // Informational flags return immediately after printing.
    if cli.sync_models {
        let url = config.read().sync_models_url();
        return Config::sync_models(&url, abort_signal.clone()).await;
    }
    if cli.list_models {
        for model in list_models(&config.read(), ModelType::Chat) {
            println!("{}", model.id());
        }
        return Ok(());
    }
    if cli.list_roles {
        let roles = Config::list_roles(true).join("\n");
        println!("{roles}");
        return Ok(());
    }
    if cli.list_agents {
        let agents = list_agents().join("\n");
        println!("{agents}");
        return Ok(());
    }
    if cli.list_rags {
        let rags = Config::list_rags().join("\n");
        println!("{rags}");
        return Ok(());
    }
    if cli.list_macros {
        let macros = Config::list_macros().join("\n");
        println!("{macros}");
        return Ok(());
    }
    if cli.dry_run {
        config.write().dry_run = true;
    }
    if let Some(agent) = &cli.agent {
        // `--build-tools` with an agent only (re)builds its tool binaries.
        if cli.build_tools {
            info!("Building tools for agent '{agent}'...");
            Agent::init(&config, agent, abort_signal.clone()).await?;
            return Ok(());
        }
        // `-s` without a name joins the temporary session.
        let session = cli.session.as_ref().map(|v| match v {
            Some(v) => v.as_str(),
            None => TEMP_SESSION_NAME,
        });
        // NAME/VALUE pairs arrive flattened; regroup them in twos.
        if !cli.agent_variable.is_empty() {
            config.write().agent_variables = Some(
                cli.agent_variable
                    .chunks(2)
                    .map(|v| (v[0].to_string(), v[1].to_string()))
                    .collect(),
            );
        }
        // Clear the variables even when use_agent failed, then propagate.
        let ret = Config::use_agent(&config, agent, session, abort_signal.clone()).await;
        config.write().agent_variables = None;
        ret?;
    } else {
        // Without an agent: prompt/role/shell/code selection, then
        // session and RAG activation.
        if let Some(prompt) = &cli.prompt {
            config.write().use_prompt(prompt)?;
        } else if let Some(name) = &cli.role {
            Config::use_role_safely(&config, name, abort_signal.clone()).await?;
        } else if cli.execute {
            Config::use_role_safely(&config, SHELL_ROLE, abort_signal.clone()).await?;
        } else if cli.code {
            Config::use_role_safely(&config, CODE_ROLE, abort_signal.clone()).await?;
        }
        if let Some(session) = &cli.session {
            config
                .write()
                .use_session(session.as_ref().map(|v| v.as_str()))?;
        }
        if let Some(rag) = &cli.rag {
            Config::use_rag(&config, Some(rag), abort_signal.clone()).await?;
        }
    }
    if cli.build_tools {
        return Ok(());
    }
    // Listed here (not with the other list flags) so an active agent's
    // sessions are included.
    if cli.list_sessions {
        let sessions = config.read().list_sessions().join("\n");
        println!("{sessions}");
        return Ok(());
    }
    if let Some(model_id) = &cli.model {
        config.write().set_model(model_id)?;
    }
    if cli.no_stream {
        config.write().stream = false;
    }
    if cli.empty_session {
        config.write().empty_session()?;
    }
    if cli.save_session {
        config.write().set_save_session_this_time()?;
    }
    if cli.info {
        let info = config.read().info()?;
        println!("{info}");
        return Ok(());
    }
    if let Some(addr) = cli.serve {
        return serve::run(config, addr).await;
    }
    let is_repl = config.read().working_mode.is_repl();
    // In REPL mode a rebuild is a standalone action; in cmd mode it precedes
    // the request.
    if cli.rebuild_rag {
        Config::rebuild_rag(&config, abort_signal.clone()).await?;
        if is_repl {
            return Ok(());
        }
    }
    if let Some(name) = &cli.macro_name {
        macro_execute(&config, name, text.as_deref(), abort_signal.clone()).await?;
        return Ok(());
    }
    // `-e` one-shot: generate and interactively run a shell command.
    if cli.execute && !is_repl {
        let input = create_input(&config, text, &cli.file, abort_signal.clone()).await?;
        shell_execute(&config, &SHELL, input, abort_signal.clone()).await?;
        return Ok(());
    }
    apply_prelude_safely(&config, abort_signal.clone()).await?;
    match is_repl {
        false => {
            let mut input = create_input(&config, text, &cli.file, abort_signal.clone()).await?;
            input.use_embeddings(abort_signal.clone()).await?;
            start_directive(&config, input, cli.code, abort_signal).await
        }
        true => {
            if !*IS_STDOUT_TERMINAL {
                bail!("No TTY for REPL")
            }
            start_interactive(&config).await
        }
    }
}
/// Applies the config prelude without holding the config lock across `.await`.
///
/// The config is moved out of the lock, mutated asynchronously, and then put
/// back. Fix: the config is now restored even when `apply_prelude` fails —
/// previously the `?` returned early and left a default-constructed `Config`
/// in the lock, silently discarding all loaded state.
async fn apply_prelude_safely(config: &RwLock<Config>, abort_signal: AbortSignal) -> Result<()> {
    let mut cfg = {
        let mut guard = config.write();
        mem::take(&mut *guard)
    };
    // Run the async work while the lock is free.
    let result = cfg.apply_prelude(abort_signal.clone()).await;
    {
        let mut guard = config.write();
        *guard = cfg;
    }
    result
}
/// Runs a single (non-REPL) chat request, recursing while the model keeps
/// issuing tool calls, then closes the session.
///
/// Streaming is used unless disabled or code extraction is needed (code
/// extraction requires the full response when stdout is not a terminal).
#[async_recursion::async_recursion]
async fn start_directive(
    config: &GlobalConfig,
    input: Input,
    code_mode: bool,
    abort_signal: AbortSignal,
) -> Result<()> {
    let client = input.create_client()?;
    // Piped code-mode output: extract just the code block from the reply.
    let extract_code = !*IS_STDOUT_TERMINAL && code_mode;
    config.write().before_chat_completion(&input)?;
    let (output, tool_results) = if !input.stream() || extract_code {
        call_chat_completions(
            &input,
            true,
            extract_code,
            client.as_ref(),
            abort_signal.clone(),
        )
        .await?
    } else {
        call_chat_completions_streaming(&input, client.as_ref(), abort_signal.clone()).await?
    };
    config
        .write()
        .after_chat_completion(&input, &output, &tool_results)?;
    // Tool calls require a follow-up round with their results merged in.
    if !tool_results.is_empty() {
        start_directive(
            config,
            input.merge_tool_results(output, tool_results),
            code_mode,
            abort_signal,
        )
        .await?;
    }
    config.write().exit_session()?;
    Ok(())
}
/// Launches the interactive REPL and runs it to completion.
async fn start_interactive(config: &GlobalConfig) -> Result<()> {
    let mut repl = Repl::init(config)?;
    repl.run().await
}
/// Generates a shell command from `input` and lets the user act on it.
///
/// With a terminal, an execute/revise/describe/copy/quit menu is shown:
/// execute runs the command (and may record shell history), revise recurses
/// with the user's amendment appended, describe explains the command via the
/// explain-shell role, copy puts it on the clipboard. Without a terminal the
/// command is just printed. Dry-run renders it as markdown instead.
#[async_recursion::async_recursion]
async fn shell_execute(
    config: &GlobalConfig,
    shell: &Shell,
    mut input: Input,
    abort_signal: AbortSignal,
) -> Result<()> {
    let client = input.create_client()?;
    config.write().before_chat_completion(&input)?;
    // Non-streaming call with code extraction: we need the bare command text.
    let (eval_str, _) =
        call_chat_completions(&input, false, true, client.as_ref(), abort_signal.clone()).await?;
    config
        .write()
        .after_chat_completion(&input, &eval_str, &[])?;
    if eval_str.is_empty() {
        bail!("No command generated");
    }
    if config.read().dry_run {
        config.read().print_markdown(&eval_str)?;
        return Ok(());
    }
    if *IS_STDOUT_TERMINAL {
        let options = ["execute", "revise", "describe", "copy", "quit"];
        let command = color_text(eval_str.trim(), nu_ansi_term::Color::Rgb(255, 165, 0));
        let first_letter_color = nu_ansi_term::Color::Cyan;
        // Highlight each option's first letter to hint the shortcut key.
        let prompt_text = options
            .iter()
            .map(|v| format!("{}{}", color_text(&v[0..1], first_letter_color), &v[1..]))
            .collect::<Vec<String>>()
            .join(&dimmed_text(" | "));
        // Loop so 'describe' and 'copy' can return to the menu.
        loop {
            println!("{command}");
            let answer_char =
                read_single_key(&['e', 'r', 'd', 'c', 'q'], 'e', &format!("{prompt_text}: "))?;
            match answer_char {
                'e' => {
                    debug!("{} {:?}", shell.cmd, &[&shell.arg, &eval_str]);
                    let code = run_command(&shell.cmd, &[&shell.arg, &eval_str], None)?;
                    if code == 0 && config.read().save_shell_history {
                        let _ = append_to_shell_history(&shell.name, &eval_str, code);
                    }
                    // Exit with the executed command's status.
                    process::exit(code);
                }
                'r' => {
                    // Append the revision to the original request and retry.
                    let revision = Text::new("Enter your revision:").prompt()?;
                    let text = format!("{}\n{revision}", input.text());
                    input.set_text(text);
                    return shell_execute(config, shell, input, abort_signal.clone()).await;
                }
                'd' => {
                    // Explain the generated command, then show the menu again.
                    let role = config.read().retrieve_role(EXPLAIN_SHELL_ROLE)?;
                    let input = Input::from_str(config, &eval_str, Some(role));
                    if input.stream() {
                        call_chat_completions_streaming(
                            &input,
                            client.as_ref(),
                            abort_signal.clone(),
                        )
                        .await?;
                    } else {
                        call_chat_completions(
                            &input,
                            true,
                            false,
                            client.as_ref(),
                            abort_signal.clone(),
                        )
                        .await?;
                    }
                    println!();
                    continue;
                }
                'c' => {
                    set_text(&eval_str)?;
                    println!("{}", dimmed_text("✓ Copied the command."));
                }
                _ => {}
            }
            break;
        }
    } else {
        // Not a terminal: emit the command for the caller to use.
        println!("{eval_str}");
    }
    Ok(())
}
/// Assembles the initial `Input` from command-line text and/or file args.
///
/// With no files, the (possibly empty) text becomes the input directly;
/// otherwise the files are loaded with a progress spinner. A fully empty
/// input is rejected.
async fn create_input(
    config: &GlobalConfig,
    text: Option<String>,
    file: &[String],
    abort_signal: AbortSignal,
) -> Result<Input> {
    let text = text.unwrap_or_default();
    let input = if file.is_empty() {
        Input::from_str(config, &text, None)
    } else {
        Input::from_files_with_spinner(config, &text, file.to_vec(), None, abort_signal).await?
    };
    if input.is_empty() {
        bail!("No input");
    }
    Ok(input)
}
/// Initializes log4rs from the config and returns the log file path, if any.
///
/// A file appender is preferred when a path is configured; if creating it
/// fails the logger silently falls back to the console. An optional module
/// filter (from the env, defaulting to the serve module in serve mode)
/// restricts which module logs.
fn setup_logger(is_serve: bool) -> Result<Option<PathBuf>> {
    let (log_level, log_path) = Config::log_config(is_serve)?;
    if log_level == LevelFilter::Off {
        return Ok(None);
    }
    let encoder = Box::new(PatternEncoder::new(
        "{d(%Y-%m-%d %H:%M:%S%.3f)(utc)} <{i}> [{l}] {f}:{L} - {m}{n}",
    ));
    let log_filter = match env::var(get_env_name("log_filter")) {
        Ok(v) => Some(v),
        Err(_) => match is_serve {
            true => Some(format!("{}::serve", env!("CARGO_CRATE_NAME"))),
            false => None,
        },
    };
    match log_path.clone() {
        None => {
            let console_appender = ConsoleAppender::builder().encoder(encoder).build();
            log4rs::init_config(init_console_logger(log_level, log_filter, console_appender))?;
        }
        Some(path) => {
            ensure_parent_exists(&path)?;
            let file_appender = FileAppender::builder().encoder(encoder.clone()).build(path);
            match file_appender {
                Ok(appender) => {
                    log4rs::init_config(init_file_logger(log_level, log_filter, appender))?
                }
                // File appender failed (e.g. unwritable path): fall back to
                // console logging rather than aborting.
                Err(_) => {
                    let console_appender = ConsoleAppender::builder().encoder(encoder).build();
                    log4rs::init_config(init_console_logger(
                        log_level,
                        log_filter,
                        console_appender,
                    ))?
                }
            };
        }
    }
    Ok(log_path)
}
/// Build a log4rs config that writes to the given file appender.
///
/// When a module filter is set, only that logger emits at `log_level` and the
/// root logger is silenced; otherwise the root logs at `log_level`.
fn init_file_logger(
    log_level: LevelFilter,
    log_filter: Option<String>,
    file_appender: FileAppender,
) -> log4rs::Config {
    let root_level = match &log_filter {
        Some(_) => LevelFilter::Off,
        None => log_level,
    };
    let mut builder = log4rs::Config::builder()
        .appender(Appender::builder().build("logfile", Box::new(file_appender)));
    if let Some(target) = log_filter {
        builder = builder.logger(Logger::builder().build(target, log_level));
    }
    builder
        .build(Root::builder().appender("logfile").build(root_level))
        .unwrap()
}
/// Build a log4rs config that writes to the given console appender.
///
/// Mirrors `init_file_logger`: with a module filter, only that logger emits at
/// `log_level` and the root is silenced.
fn init_console_logger(
    log_level: LevelFilter,
    log_filter: Option<String>,
    console_appender: ConsoleAppender,
) -> log4rs::Config {
    let root_level = match &log_filter {
        Some(_) => LevelFilter::Off,
        None => log_level,
    };
    let mut builder = log4rs::Config::builder()
        .appender(Appender::builder().build("console", Box::new(console_appender)));
    if let Some(target) = log_filter {
        builder = builder.logger(Logger::builder().build(target, log_level));
    }
    builder
        .build(Root::builder().appender("console").build(root_level))
        .unwrap()
}
+290
View File
@@ -0,0 +1,290 @@
use crate::config::Config;
use crate::utils::{abortable_run_with_spinner, AbortSignal};
use anyhow::{anyhow, Context, Result};
use futures_util::future::BoxFuture;
use futures_util::{stream, StreamExt, TryStreamExt};
use rmcp::model::{CallToolRequestParam, CallToolResult};
use rmcp::service::RunningService;
use rmcp::transport::TokioChildProcess;
use rmcp::{RoleClient, ServiceExt};
use serde::Deserialize;
use serde_json::{json, Value};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::fs::OpenOptions;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use tokio::process::Command;
/// Name prefix for the MCP "invoke" meta-function.
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
/// Name prefix for the MCP "list" meta-function.
pub const MCP_LIST_META_FUNCTION_NAME_PREFIX: &str = "mcp_list";
/// A running MCP client connection with a unit (`()`) service handler.
type ConnectedServer = RunningService<RoleClient, ()>;
/// Top-level shape of the MCP config file: a `"mcpServers"` map of
/// server id -> launch spec.
#[derive(Debug, Clone, Deserialize)]
struct McpServersConfig {
    #[serde(rename = "mcpServers")]
    mcp_servers: HashMap<String, McpServer>,
}
/// Launch spec for a single MCP server child process.
#[derive(Debug, Clone, Deserialize)]
struct McpServer {
    /// Executable to run.
    command: String,
    /// Optional command-line arguments.
    args: Option<Vec<String>>,
    /// Optional environment variables; scalar values are stringified at spawn.
    env: Option<HashMap<String, JsonField>>,
    /// Optional working directory for the child process.
    cwd: Option<String>,
}
/// Scalar JSON value accepted for env entries; untagged, so plain strings,
/// booleans, and integers in the config file all deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum JsonField {
    Str(String),
    Bool(bool),
    Int(i64),
}
/// Holds the parsed MCP config and the currently running server connections.
#[derive(Debug, Clone, Default)]
pub struct McpRegistry {
    /// Where child-process stderr is appended when logging to a file.
    log_path: Option<PathBuf>,
    /// Parsed config; `None` when the config file does not exist.
    config: Option<McpServersConfig>,
    /// Connected servers, keyed by their config id.
    servers: HashMap<String, Arc<RunningService<RoleClient, ()>>>,
}
impl McpRegistry {
    /// Load the MCP config file (if present) and optionally start the selected
    /// servers, showing an abortable spinner while they connect.
    ///
    /// A missing config file is not an error: it just leaves MCP disabled.
    pub async fn init(
        log_path: Option<PathBuf>,
        start_mcp_servers: bool,
        use_mcp_servers: Option<String>,
        abort_signal: AbortSignal,
    ) -> Result<Self> {
        let mut registry = Self {
            log_path,
            ..Default::default()
        };
        if !Config::mcp_config_file().try_exists().with_context(|| {
            format!(
                "Failed to check MCP config file at {}",
                Config::mcp_config_file().display()
            )
        })? {
            debug!(
                "MCP config file does not exist at {}, skipping MCP initialization",
                Config::mcp_config_file().display()
            );
            return Ok(registry);
        }
        let err = || {
            format!(
                "Failed to load MCP config file at {}",
                Config::mcp_config_file().display()
            )
        };
        let content = tokio::fs::read_to_string(Config::mcp_config_file())
            .await
            .with_context(err)?;
        let config: McpServersConfig = serde_json::from_str(&content).with_context(err)?;
        registry.config = Some(config);
        if start_mcp_servers {
            abortable_run_with_spinner(
                registry.start_select_mcp_servers(use_mcp_servers),
                "Loading MCP servers",
                abort_signal,
            )
            .await?;
        }
        Ok(registry)
    }
    /// Stop every running server, then start the newly selected set.
    /// Consumes the old registry and returns the restarted one.
    pub async fn reinit(
        registry: McpRegistry,
        use_mcp_servers: Option<String>,
        abort_signal: AbortSignal,
    ) -> Result<Self> {
        debug!("Reinitializing MCP registry");
        debug!("Stopping all MCP servers");
        let mut new_registry = abortable_run_with_spinner(
            registry.stop_all_servers(),
            "Stopping MCP servers",
            abort_signal.clone(),
        )
        .await?;
        abortable_run_with_spinner(
            new_registry.start_select_mcp_servers(use_mcp_servers),
            "Loading MCP servers",
            abort_signal,
        )
        .await?;
        Ok(new_registry)
    }
    /// Start the servers named in `use_mcp_servers` (comma-separated ids, or
    /// the sentinel `"all"`). Does nothing when no config was loaded or when
    /// `use_mcp_servers` is `None`.
    async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option<String>) -> Result<()> {
        if self.config.is_none() {
            debug!("MCP config is not present; assuming MCP servers are disabled globally. skipping MCP initialization");
            return Ok(());
        }
        debug!("Starting selected MCP servers: {:?}", use_mcp_servers);
        if let Some(servers) = use_mcp_servers {
            let config = self
                .config
                .as_ref()
                .with_context(|| "MCP Config not defined. Cannot start servers")?;
            let mcp_servers = config.mcp_servers.clone();
            let enabled_servers: HashSet<String> =
                servers.split(',').map(|s| s.trim().to_string()).collect();
            let server_ids: Vec<String> = if servers == "all" {
                mcp_servers.into_keys().collect()
            } else {
                mcp_servers
                    .into_keys()
                    .filter(|id| enabled_servers.contains(id))
                    .collect()
            };
            // Start servers concurrently, bounded by the CPU count; the first
            // failure aborts the whole collection via `try_collect`.
            let results: Vec<(String, Arc<_>)> = stream::iter(
                server_ids
                    .into_iter()
                    .map(|id| async { self.start_server(id).await }),
            )
            .buffer_unordered(num_cpus::get())
            .try_collect()
            .await?;
            self.servers = results.into_iter().collect();
        }
        Ok(())
    }
    /// Spawn one configured server as a child process and perform the MCP
    /// handshake, returning its id and connection handle.
    async fn start_server(&self, id: String) -> Result<(String, Arc<ConnectedServer>)> {
        let server = self
            .config
            .as_ref()
            .and_then(|c| c.mcp_servers.get(&id))
            .with_context(|| format!("MCP server not found in config: {id}"))?;
        let mut cmd = Command::new(&server.command);
        if let Some(args) = &server.args {
            cmd.args(args);
        }
        if let Some(env) = &server.env {
            // Stringify scalar env values (string/bool/int) for the child.
            let env: HashMap<String, String> = env
                .iter()
                .map(|(k, v)| match v {
                    JsonField::Str(s) => (k.clone(), s.clone()),
                    JsonField::Bool(b) => (k.clone(), b.to_string()),
                    JsonField::Int(i) => (k.clone(), i.to_string()),
                })
                .collect();
            cmd.envs(env);
        }
        if let Some(cwd) = &server.cwd {
            cmd.current_dir(cwd);
        }
        // When a log path is configured, append the child's stderr to it;
        // stdin/stdout carry the MCP protocol.
        let transport = if let Some(log_path) = self.log_path.as_ref() {
            cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
            let log_file = OpenOptions::new()
                .create(true)
                .append(true)
                .open(log_path)?;
            let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
            transport
        } else {
            TokioChildProcess::new(cmd)?
        };
        let service = Arc::new(
            ().serve(transport)
                .await
                .with_context(|| format!("Failed to start MCP server: {}", &server.command))?,
        );
        debug!(
            "Available tools for MCP server {id}: {:?}",
            service.list_tools(None).await?
        );
        info!("Started MCP server: {id}");
        Ok((id.to_string(), service))
    }
    /// Cancel every running server and return the (now empty) registry.
    ///
    /// Requires each `Arc` handle to be unique (`Arc::try_unwrap`), i.e. no
    /// outstanding clones from `catalog`/`invoke` futures.
    pub async fn stop_all_servers(mut self) -> Result<Self> {
        for (id, server) in self.servers {
            Arc::try_unwrap(server)
                .map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
                .cancel()
                .await
                .with_context(|| format!("Failed to stop MCP server: {id}"))?;
            info!("Stopped MCP server: {id}");
        }
        self.servers = HashMap::new();
        Ok(self)
    }
    /// Ids of the currently running servers.
    pub fn list_servers(&self) -> Vec<String> {
        self.servers.keys().cloned().collect()
    }
    /// Future resolving to a JSON array of `{server, tools, resources}` for
    /// every running server. Handles are cloned up front so the future is
    /// `'static` and does not borrow `self`.
    pub fn catalog(&self) -> BoxFuture<'static, Result<Value>> {
        let servers: Vec<(String, Arc<ConnectedServer>)> = self
            .servers
            .iter()
            .map(|(id, s)| (id.clone(), s.clone()))
            .collect();
        Box::pin(async move {
            let mut out = Vec::with_capacity(servers.len());
            for (id, server) in servers {
                let tools = server.list_tools(None).await?;
                // Resource listing is best-effort; servers without resources
                // fall back to the default (empty) list.
                let resources = server.list_resources(None).await.unwrap_or_default();
                // TODO implement prompt sampling for MCP servers
                // let prompts = server.service.list_prompts(None).await.unwrap_or_default();
                out.push(json!({
                    "server": id,
                    "tools": tools,
                    "resources": resources,
                }));
            }
            Ok(Value::Array(out))
        })
    }
    /// Future that calls `tool` on `server` with the given JSON arguments.
    /// A bad server name is reported when the future is awaited, keeping this
    /// method's signature infallible.
    pub fn invoke(
        &self,
        server: &str,
        tool: &str,
        arguments: Value,
    ) -> BoxFuture<'static, Result<CallToolResult>> {
        let server = self
            .servers
            .get(server)
            .cloned()
            .with_context(|| format!("Invoked MCP server does not exist: {server}"));
        let tool = tool.to_owned();
        Box::pin(async move {
            let server = server?;
            // NOTE(review): `tool.to_owned()` here copies an already-owned
            // String; moving `tool` into the Cow would avoid the extra clone.
            let call_tool_request = CallToolRequestParam {
                name: Cow::Owned(tool.to_owned()),
                arguments: arguments.as_object().cloned(),
            };
            let result = server.call_tool(call_tool_request).await?;
            Ok(result)
        })
    }
    /// True when no servers are running.
    pub fn is_empty(&self) -> bool {
        self.servers.is_empty()
    }
}
+149
View File
@@ -0,0 +1,149 @@
use crate::function::{FunctionDeclaration, JsonSchema};
use anyhow::{bail, Context, Result};
use argc::{ChoiceValue, CommandValue, FlagOptionValue};
use indexmap::IndexMap;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::{env, fs};
/// Build an argc-based shell script and derive function declarations from it.
///
/// Reads the script from `tool_file`, compiles it with argc, writes the built
/// script to `tools_file_path`, then exports the command metadata. A script
/// without subcommands yields a single declaration; otherwise each public
/// subcommand (plus the special `_instructions`) becomes an agent declaration.
///
/// # Errors
/// Fails on I/O problems, argc build/export failures, or when a (sub)command
/// has a missing or empty description.
pub fn generate_bash_declarations(
    mut tool_file: File,
    tools_file_path: &Path,
    file_name: &str,
) -> Result<Vec<FunctionDeclaration>> {
    let mut src = String::new();
    // Report the script's name in error contexts; Debug-printing the `File`
    // handle produced platform-dependent fd/handle noise instead of a name.
    tool_file
        .read_to_string(&mut src)
        .with_context(|| format!("Failed to load script '{file_name}'"))?;
    debug!("Building script '{file_name}'");
    let build_script = argc::build(
        &src,
        "",
        env::var("TERM_WIDTH").ok().and_then(|v| v.parse().ok()),
    )?;
    fs::write(tools_file_path, &build_script)
        .with_context(|| format!("Failed to write built script to '{tools_file_path:?}'"))?;
    let command_value = argc::export(&build_script, file_name)
        .with_context(|| format!("Failed to parse script '{file_name}'"))?;
    if command_value.subcommands.is_empty() {
        let function_declaration =
            command_to_function_declaration(&command_value).ok_or_else(|| {
                anyhow::format_err!("Tool definition missing or empty description: {file_name}")
            })?;
        Ok(vec![function_declaration])
    } else {
        let mut declarations = vec![];
        for subcommand in &command_value.subcommands {
            // Skip private subcommands, except the `_instructions` hook.
            if subcommand.name.starts_with('_') && subcommand.name != "_instructions" {
                continue;
            }
            if let Some(mut function_declaration) = command_to_function_declaration(subcommand) {
                function_declaration.agent = true;
                declarations.push(function_declaration);
            } else {
                bail!(
                    "Tool definition missing or empty description: {} {}",
                    file_name,
                    subcommand.name
                );
            }
        }
        Ok(declarations)
    }
}
/// Turn one argc command into a declaration; `None` when it has no description.
fn command_to_function_declaration(cmd: &CommandValue) -> Option<FunctionDeclaration> {
    (!cmd.describe.is_empty()).then(|| FunctionDeclaration {
        name: underscore(&cmd.name),
        description: cmd.describe.clone(),
        parameters: parse_parameters_schema(&cmd.flag_options),
        agent: false,
    })
}
/// Replace dashes with underscores so the name is a valid identifier.
fn underscore(s: &str) -> String {
    s.chars().map(|c| if c == '-' { '_' } else { c }).collect()
}
fn schema_ty(t: &str) -> JsonSchema {
JsonSchema {
type_value: Some(t.to_string()),
description: None,
properties: None,
items: None,
any_of: None,
enum_value: None,
default: None,
required: None,
}
}
/// Attach a description to the schema, leaving it unset for empty text.
fn with_description(mut schema: JsonSchema, describe: &str) -> JsonSchema {
    if describe.is_empty() {
        return schema;
    }
    schema.description = Some(describe.to_owned());
    schema
}
/// Attach enum values from an argc choice list, if there are any.
fn with_enum(mut schema: JsonSchema, choice: &Option<ChoiceValue>) -> JsonSchema {
    match choice {
        Some(ChoiceValue::Values(values)) if !values.is_empty() => {
            schema.enum_value = Some(values.clone());
        }
        _ => {}
    }
    schema
}
/// Map one argc flag/option to a JSON-schema property.
///
/// Flags are booleans, repeatable options become string arrays, and the first
/// value notation selects integer/number; everything else is a string. The
/// flag's description and choice list are applied on top.
fn parse_property(flag: &FlagOptionValue) -> JsonSchema {
    let base = if flag.flag {
        schema_ty("boolean")
    } else if flag.multiple_occurs {
        let mut list = schema_ty("array");
        list.items = Some(Box::new(schema_ty("string")));
        list
    } else {
        match flag.notations.first().map(|s| s.as_str()) {
            Some("INT") => schema_ty("integer"),
            Some("NUM") => schema_ty("number"),
            _ => schema_ty("string"),
        }
    };
    with_enum(with_description(base, &flag.describe), &flag.choice)
}
/// Build a JSON-schema "object" describing a command's flags/options.
///
/// The built-in `--help`/`--version` flags are excluded. Required flags are
/// collected into `required`; when none are required the field is omitted
/// entirely (`None`), matching `build_parameters_schema` in the Python
/// generator, instead of serializing an empty list.
fn parse_parameters_schema(flags: &[FlagOptionValue]) -> JsonSchema {
    let filtered = flags.iter().filter(|f| f.id != "help" && f.id != "version");
    let mut props: IndexMap<String, JsonSchema> = IndexMap::new();
    let mut required: Vec<String> = Vec::new();
    for f in filtered {
        let key = underscore(&f.id);
        if f.required {
            required.push(key.clone());
        }
        props.insert(key, parse_property(f));
    }
    JsonSchema {
        type_value: Some("object".to_string()),
        description: None,
        properties: Some(props),
        items: None,
        any_of: None,
        enum_value: None,
        default: None,
        required: if required.is_empty() { None } else { Some(required) },
    }
}
+2
View File
@@ -0,0 +1,2 @@
pub(crate) mod bash;
pub(crate) mod python;
+420
View File
@@ -0,0 +1,420 @@
use crate::function::{FunctionDeclaration, JsonSchema};
use anyhow::{bail, Context, Result};
use ast::{Stmt, StmtFunctionDef};
use indexmap::IndexMap;
use rustpython_ast::{Constant, Expr, UnaryOp};
use rustpython_parser::{ast, Mode};
use serde_json::Value;
use std::fs::File;
use std::io::Read;
use std::path::Path;
/// Intermediate description of one Python function parameter, merged from the
/// signature (annotation/default) and the docstring's "Args:" section.
#[derive(Debug)]
struct Param {
    name: String,
    /// Type from the annotation (e.g. "str", "list[int]", "literal:a|b"); may be empty.
    ty_hint: String,
    /// True when the parameter has no default and is not marked optional.
    required: bool,
    /// `Some(Value::Null)` when the signature has a default (the value itself
    /// is not captured).
    default: Option<Value>,
    /// Type string parsed from the docstring, if any.
    doc_type: Option<String>,
    /// Description parsed from the docstring, if any.
    doc_desc: Option<String>,
}
/// Parse a Python script and derive function declarations from its top-level
/// functions.
///
/// When the script lives in a `tools` directory it is treated as a tool: only
/// its `run` function is exported, and the resulting declarations are marked
/// as agent functions.
///
/// # Errors
/// Fails when the script cannot be read or parsed, or when a selected
/// function lacks a docstring description.
pub fn generate_python_declarations(
    mut tool_file: File,
    file_name: &str,
    parent: Option<&Path>,
) -> Result<Vec<FunctionDeclaration>> {
    let mut src = String::new();
    // Report the script's name in error context; Debug-printing the `File`
    // handle produced platform-dependent fd/handle noise instead of a name.
    tool_file
        .read_to_string(&mut src)
        .with_context(|| format!("Failed to load script '{file_name}'"))?;
    let suite = parse_suite(&src, file_name)?;
    let is_tool = parent
        .and_then(|p| p.file_name())
        .is_some_and(|n| n == "tools");
    let mut declarations = python_to_function_declarations(file_name, &suite, is_tool)?;
    if is_tool {
        for d in &mut declarations {
            d.agent = true;
        }
    }
    Ok(declarations)
}
/// Parse Python source into a statement suite; rejects bare expressions.
fn parse_suite(src: &str, filename: &str) -> Result<ast::Suite> {
    let parsed =
        rustpython_parser::parse(src, Mode::Module, filename).context("failed to parse python")?;
    match parsed {
        ast::Mod::Module(m) => Ok(m.body),
        ast::Mod::Interactive(m) => Ok(m.body),
        ast::Mod::Expression(_) => bail!("expected a module; got a single expression"),
        _ => bail!("unexpected parse mode/AST variant"),
    }
}
/// Convert top-level Python function definitions into function declarations.
///
/// Skips private functions (leading '_', except `_instructions`); in tool
/// mode only the `run` function is considered, and it is renamed after the
/// file rather than the function.
///
/// # Errors
/// Fails when a selected function has a missing or empty docstring.
fn python_to_function_declarations(
    file_name: &str,
    module: &ast::Suite,
    is_tool: bool,
) -> Result<Vec<FunctionDeclaration>> {
    let mut out = Vec::new();
    for stmt in module {
        if let Stmt::FunctionDef(fd) = stmt {
            let func_name = fd.name.to_string();
            if func_name.starts_with('_') && func_name != "_instructions" {
                continue;
            }
            if is_tool && func_name != "run" {
                continue;
            }
            let description = get_docstring_from_body(&fd.body).unwrap_or_default();
            let params = collect_params(fd);
            let schema = build_parameters_schema(&params, &description);
            // Tool entry points take the file's name, not the function's.
            let name = if is_tool && func_name == "run" {
                underscore(file_name)
            } else {
                underscore(&func_name)
            };
            let desc_trim = description.trim().to_string();
            if desc_trim.is_empty() {
                bail!("Missing or empty description on function: {func_name}");
            }
            // NOTE(review): `agent` is set to `!is_tool` here, but the caller
            // flips it to `true` when `is_tool` — net effect is always `true`;
            // confirm this is intended.
            out.push(FunctionDeclaration {
                name,
                description: desc_trim,
                parameters: schema,
                agent: !is_tool,
            });
        }
    }
    Ok(out)
}
/// Extract a docstring: the body's first statement must be a bare string
/// constant expression; anything else yields `None`.
fn get_docstring_from_body(body: &[Stmt]) -> Option<String> {
    match body.first()? {
        Stmt::Expr(expr_stmt) => match &*expr_stmt.value {
            Expr::Constant(constant) => match &constant.value {
                Constant::Str(s) => Some(s.clone()),
                _ => None,
            },
            _ => None,
        },
        _ => None,
    }
}
/// Collect parameter descriptions from a function's signature and docstring.
///
/// Handles positional-only, regular, and keyword-only parameters uniformly
/// (via `make_param`), then `*args` (as a list) and `**kwargs` (as an object),
/// and finally overlays types/descriptions parsed from the docstring's
/// "Args:" section.
fn collect_params(fd: &StmtFunctionDef) -> Vec<Param> {
    let mut out = Vec::new();
    // Positional-only + regular parameters, then keyword-only ones; the
    // normalization is identical, so it lives in `make_param`.
    for a in fd.args.posonlyargs.iter().chain(fd.args.args.iter()) {
        out.push(make_param(
            a.def.arg.to_string(),
            get_arg_type(a.def.annotation.as_deref()),
            a.default.is_some(),
        ));
    }
    for a in &fd.args.kwonlyargs {
        out.push(make_param(
            a.def.arg.to_string(),
            get_arg_type(a.def.annotation.as_deref()),
            a.default.is_some(),
        ));
    }
    if let Some(vararg) = &fd.args.vararg {
        // *args becomes a list; element type from the annotation, default str.
        let name = vararg.arg.to_string();
        let inner = get_arg_type(vararg.annotation.as_deref());
        let ty = if inner.is_empty() {
            "list[str]".into()
        } else {
            format!("list[{inner}]")
        };
        out.push(Param {
            name,
            ty_hint: ty,
            required: false,
            default: None,
            doc_type: None,
            doc_desc: None,
        });
    }
    if let Some(kwarg) = &fd.args.kwarg {
        // **kwargs becomes a free-form object.
        let name = kwarg.arg.to_string();
        out.push(Param {
            name,
            ty_hint: "object".into(),
            required: false,
            default: None,
            doc_type: None,
            doc_desc: None,
        });
    }
    // Overlay docstring metadata; a trailing '?' on the documented type also
    // marks the parameter optional.
    if let Some(doc) = get_docstring_from_body(&fd.body) {
        let meta = parse_docstring_args(&doc);
        for p in &mut out {
            if let Some((t, d)) = meta.get(&p.name) {
                if !t.is_empty() {
                    p.doc_type = Some(t.clone());
                }
                if !d.is_empty() {
                    p.doc_desc = Some(d.clone());
                }
                if t.ends_with('?') {
                    p.required = false;
                }
            }
        }
    }
    out
}

/// Normalize one signature parameter: a trailing '?' on the type marks it
/// optional; a default value also makes it optional (recorded as `Null`).
fn make_param(name: String, mut ty_hint: String, has_default: bool) -> Param {
    let mut required = !has_default;
    if ty_hint.ends_with('?') {
        ty_hint.pop();
        required = false;
    }
    Param {
        name,
        ty_hint,
        required,
        default: if has_default { Some(Value::Null) } else { None },
        doc_type: None,
        doc_desc: None,
    }
}
/// Render a Python annotation as a compact type string.
///
/// Returns "" for no annotation, the bare name for `Name` annotations,
/// "<inner>?" for `Optional[...]`, "list[<inner>]" for `List[...]`,
/// "literal:a|b" for `Literal[...]`, and "any" for anything else.
fn get_arg_type(annotation: Option<&Expr>) -> String {
    match annotation {
        None => "".to_string(),
        Some(Expr::Name(n)) => n.id.to_string(),
        Some(Expr::Subscript(sub)) => match &*sub.value {
            Expr::Name(name) if &name.id == "Optional" => {
                let inner = get_arg_type(Some(&sub.slice));
                format!("{inner}?")
            }
            Expr::Name(name) if &name.id == "List" => {
                let inner = get_arg_type(Some(&sub.slice));
                format!("list[{inner}]")
            }
            Expr::Name(name) if &name.id == "Literal" => {
                let vals = literal_members(&sub.slice);
                format!("literal:{}", vals.join("|"))
            }
            _ => "any".to_string(),
        },
        _ => "any".to_string(),
    }
}
/// Best-effort rendering of a constant-ish expression as a string (used for
/// `Literal[...]` members); unrecognized shapes collapse to "any".
fn expr_to_str(e: &Expr) -> String {
    match e {
        Expr::Constant(c) => match &c.value {
            Constant::Str(s) => s.clone(),
            Constant::Int(i) => i.to_string(),
            Constant::Float(f) => f.to_string(),
            Constant::Bool(b) => b.to_string(),
            Constant::None => "None".to_string(),
            Constant::Ellipsis => "...".to_string(),
            Constant::Bytes(b) => String::from_utf8_lossy(b).into_owned(),
            Constant::Complex { real, imag } => format!("{real}+{imag}j"),
            _ => "any".to_string(),
        },
        Expr::Name(n) => n.id.to_string(),
        Expr::UnaryOp(u) => {
            // Only negative numeric literals survive (e.g. -1, -2.5).
            if matches!(u.op, UnaryOp::USub) {
                let inner = expr_to_str(&u.operand);
                if inner.parse::<f64>().is_ok() || inner.chars().all(|c| c.is_ascii_digit()) {
                    return format!("-{inner}");
                }
            }
            "any".to_string()
        }
        // Tuples flatten to a comma-separated list of their members.
        Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect::<Vec<_>>().join(","),
        _ => "any".to_string(),
    }
}
/// Members of a `Literal[...]` subscript: tuples expand element-wise,
/// anything else yields a single member.
fn literal_members(e: &Expr) -> Vec<String> {
    if let Expr::Tuple(t) = e {
        t.elts.iter().map(expr_to_str).collect()
    } else {
        vec![expr_to_str(e)]
    }
}
/// Parse the "Args:" section of a docstring into `name -> (type, description)`.
///
/// Lines before the "Args:" marker are ignored; the section ends at the first
/// non-indented line. Entries look like `name (type): description`; a type in
/// parentheses containing "optional" gets a '?' suffix.
fn parse_docstring_args(doc: &str) -> IndexMap<String, (String, String)> {
    let mut args = IndexMap::new();
    let mut lines = doc.lines();
    // Skip everything up to and including the "Args:" marker line.
    for line in lines.by_ref() {
        if line.trim_start().starts_with("Args:") {
            break;
        }
    }
    for line in lines {
        // The Args section ends at the first non-indented line.
        if !line.starts_with([' ', '\t']) {
            break;
        }
        let trimmed = line.trim();
        let Some((left, desc)) = trimmed.split_once(':') else {
            continue;
        };
        let left = left.trim();
        let mut name = left.to_string();
        let mut ty = String::new();
        if let Some((n, t)) = left.split_once(' ') {
            name = n.trim().to_string();
            ty = t.trim().to_string();
            if ty.starts_with('(') && ty.ends_with(')') {
                let mut inner = ty[1..ty.len() - 1].to_string();
                if inner.to_lowercase().contains("optional") && !inner.ends_with('?') {
                    inner.push('?');
                }
                ty = inner;
            }
        }
        args.insert(name, (ty, desc.trim().to_string()));
    }
    args
}
/// Slugify a name: lowercase ASCII alphanumerics, everything else becomes an
/// underscore, and runs of underscores (including leading/trailing) collapse.
fn underscore(s: &str) -> String {
    let mapped: String = s
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();
    let parts: Vec<&str> = mapped.split('_').filter(|part| !part.is_empty()).collect();
    parts.join("_")
}
/// Build a JSON-schema "object" from collected parameters.
///
/// The signature type hint wins over the docstring type, falling back to
/// "str". Parameters with no default that are required go into `required`;
/// when none qualify the field is omitted.
fn build_parameters_schema(params: &[Param], _description: &str) -> JsonSchema {
    let mut properties: IndexMap<String, JsonSchema> = IndexMap::new();
    let mut required_keys: Vec<String> = Vec::new();
    for param in params {
        let key = param.name.replace('-', "_");
        let mut field = JsonSchema::default();
        let ty = if !param.ty_hint.is_empty() {
            param.ty_hint.as_str()
        } else {
            param.doc_type.as_deref().unwrap_or("str")
        };
        if let Some(desc) = param.doc_desc.as_deref() {
            if !desc.is_empty() {
                field.description = Some(desc.to_string());
            }
        }
        apply_type_to_schema(ty, &mut field);
        if param.required && param.default.is_none() {
            required_keys.push(key.clone());
        }
        properties.insert(key, field);
    }
    JsonSchema {
        type_value: Some("object".into()),
        description: None,
        properties: Some(properties),
        items: None,
        any_of: None,
        enum_value: None,
        default: None,
        required: if required_keys.is_empty() {
            None
        } else {
            Some(required_keys)
        },
    }
}
/// Translate a compact type string (as produced by `get_arg_type`) into JSON
/// schema fields on `s`. A trailing '?' (optionality) is ignored here.
fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) {
    let base = ty.trim_end_matches('?');
    if let Some(inner_ty) = base.strip_prefix("list[") {
        // Arrays recurse on the element type, defaulting items to string.
        let mut element = JsonSchema::default();
        apply_type_to_schema(inner_ty.trim_end_matches(']'), &mut element);
        if element.type_value.is_none() {
            element.type_value = Some("string".into());
        }
        s.type_value = Some("array".into());
        s.items = Some(Box::new(element));
    } else if let Some(members) = base.strip_prefix("literal:") {
        // Literals become string enums; quotes around members are stripped.
        s.type_value = Some("string".into());
        let values: Vec<String> = members
            .split('|')
            .map(|m| m.trim().trim_matches('"').trim_matches('\'').to_string())
            .collect();
        if !values.is_empty() {
            s.enum_value = Some(values);
        }
    } else {
        let json_ty = match base {
            "bool" => "boolean",
            "int" => "integer",
            "float" => "number",
            _ => "string",
        };
        s.type_value = Some(json_ty.into());
    }
}
+1013
View File
File diff suppressed because it is too large Load Diff
+66
View File
@@ -0,0 +1,66 @@
use super::*;
use base64::{engine::general_purpose::STANDARD, Engine};
use serde::{de, Deserializer, Serializer};
/// Serialize embeddings as a map of `"<high>-<low>"` id keys to base64 of the
/// vectors' raw f32 bytes (native endianness, matching `deserialize`).
///
/// Uses `f32::to_ne_bytes` instead of the previous `unsafe` raw-parts cast:
/// byte-for-byte identical output, no unsafe code.
pub fn serialize<S>(
    vectors: &IndexMap<DocumentId, Vec<f32>>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let encoded_map: IndexMap<String, String> = vectors
        .iter()
        .map(|(id, vec)| {
            let (h, l) = id.split();
            let bytes: Vec<u8> = vec.iter().flat_map(|v| v.to_ne_bytes()).collect();
            (format!("{h}-{l}"), STANDARD.encode(bytes))
        })
        .collect();
    encoded_map.serialize(serializer)
}
/// Deserialize embeddings from a map of `"<high>-<low>"` keys to base64 data.
///
/// Rejects malformed keys and payloads whose length is not a multiple of the
/// f32 size. Uses safe `chunks_exact` + `f32::from_ne_bytes` instead of the
/// previous `unsafe` pointer copy; native endianness matches `serialize`.
pub fn deserialize<'de, D>(deserializer: D) -> Result<IndexMap<DocumentId, Vec<f32>>, D::Error>
where
    D: Deserializer<'de>,
{
    let encoded_map: IndexMap<String, String> =
        IndexMap::<String, String>::deserialize(deserializer)?;
    let mut decoded_map = IndexMap::new();
    for (key, base64_str) in encoded_map {
        // Keys are "<high>-<low>" pairs produced by `serialize`.
        let decoded_key: DocumentId = key
            .split_once('-')
            .and_then(|(h, l)| {
                let h = h.parse::<usize>().ok()?;
                let l = l.parse::<usize>().ok()?;
                Some(DocumentId::new(h, l))
            })
            .ok_or_else(|| de::Error::custom(format!("Invalid key '{key}'")))?;
        let decoded_data = STANDARD.decode(&base64_str).map_err(de::Error::custom)?;
        if decoded_data.len() % size_of::<f32>() != 0 {
            return Err(de::Error::custom(format!("Invalid vector at '{key}'")));
        }
        let vec_f32: Vec<f32> = decoded_data
            .chunks_exact(size_of::<f32>())
            .map(|chunk| f32::from_ne_bytes(chunk.try_into().expect("chunk is 4 bytes")))
            .collect();
        decoded_map.insert(decoded_key, vec_f32);
    }
    Ok(decoded_map)
}
+235
View File
@@ -0,0 +1,235 @@
/// Languages with tailored separator lists for recursive text splitting.
#[derive(PartialEq, Eq, Hash)]
pub enum Language {
    Cpp,
    Go,
    Java,
    Js,
    Php,
    Proto,
    Python,
    Rst,
    Ruby,
    Rust,
    Scala,
    Swift,
    Markdown,
    Latex,
    Html,
    Sol,
}
impl Language {
    /// Ordered separator list for this language, from coarse (declaration
    /// keywords) to fine ("\n\n", "\n", " ", ""). The splitter tries entries
    /// in order, so earlier entries produce more semantically aligned chunks.
    pub fn separators(&self) -> Vec<&str> {
        match self {
            Language::Cpp => vec![
                "\nclass ",
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Go => vec![
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Java => vec![
                "\nclass ",
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Js => vec![
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\nclass ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Php => vec![
                "\nfunction ",
                "\nclass ",
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Proto => vec![
                "\nmessage ",
                "\nservice ",
                "\nenum ",
                "\noption ",
                "\nimport ",
                "\nsyntax ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Python => vec!["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""],
            Language::Rst => vec![
                "\n===\n", "\n---\n", "\n***\n", "\n.. ", "\n\n", "\n", " ", "",
            ],
            Language::Ruby => vec![
                "\ndef ",
                "\nclass ",
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            // NOTE(review): "\nconst " appears twice in this list — likely an
            // accidental duplicate; confirm before removing, since it changes
            // the fallback separator sequence used during recursion.
            Language::Rust => vec![
                "\nfn ", "\nconst ", "\nlet ", "\nif ", "\nwhile ", "\nfor ", "\nloop ",
                "\nmatch ", "\nconst ", "\n\n", "\n", " ", "",
            ],
            Language::Scala => vec![
                "\nclass ",
                "\nobject ",
                "\ndef ",
                "\nval ",
                "\nvar ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Swift => vec![
                "\nfunc ",
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Markdown => vec![
                "\n## ",
                "\n### ",
                "\n#### ",
                "\n##### ",
                "\n###### ",
                "```\n\n",
                "\n\n***\n\n",
                "\n\n---\n\n",
                "\n\n___\n\n",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Latex => vec![
                "\n\\chapter{",
                "\n\\section{",
                "\n\\subsection{",
                "\n\\subsubsection{",
                "\n\\begin{enumerate}",
                "\n\\begin{itemize}",
                "\n\\begin{description}",
                "\n\\begin{list}",
                "\n\\begin{quote}",
                "\n\\begin{quotation}",
                "\n\\begin{verse}",
                "\n\\begin{verbatim}",
                "\n\\begin{align}",
                "$$",
                "$",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Language::Html => vec![
                "<body>", "<div>", "<p>", "<br>", "<li>", "<h1>", "<h2>", "<h3>", "<h4>", "<h5>",
                "<h6>", "<span>", "<table>", "<tr>", "<td>", "<th>", "<ul>", "<ol>", "<header>",
                "<footer>", "<nav>", "<head>", "<style>", "<script>", "<meta>", "<title>", " ", "",
            ],
            Language::Sol => vec![
                "\npragma ",
                "\nusing ",
                "\ncontract ",
                "\ninterface ",
                "\nlibrary ",
                "\nconstructor ",
                "\ntype ",
                "\nfunction ",
                "\nevent ",
                "\nmodifier ",
                "\nerror ",
                "\nstruct ",
                "\nenum ",
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo while ",
                "\nassembly ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
        }
    }
}
+475
View File
@@ -0,0 +1,475 @@
mod language;
pub use self::language::*;
use super::{DocumentMetadata, RagDocument};
/// Fallback separators: paragraph, line, word, then character granularity.
pub const DEFAULT_SEPARATORS: [&str; 4] = ["\n\n", "\n", " ", ""];
/// Map a file extension to its language-specific separator list, falling back
/// to `DEFAULT_SEPARATORS` for unknown extensions.
pub fn get_separators(extension: &str) -> Vec<&'static str> {
    let language = match extension {
        "c" | "cc" | "cpp" => Language::Cpp,
        "go" => Language::Go,
        "java" => Language::Java,
        "js" | "mjs" | "cjs" => Language::Js,
        "php" => Language::Php,
        "proto" => Language::Proto,
        "py" => Language::Python,
        "rst" => Language::Rst,
        "rb" => Language::Ruby,
        "rs" => Language::Rust,
        "scala" => Language::Scala,
        "swift" => Language::Swift,
        "md" | "mkd" => Language::Markdown,
        "tex" => Language::Latex,
        "htm" | "html" => Language::Html,
        "sol" => Language::Sol,
        _ => return DEFAULT_SEPARATORS.to_vec(),
    };
    language.separators()
}
/// Recursively splits text on an ordered list of separators, then merges the
/// pieces into chunks of at most `chunk_size` (as measured by
/// `length_function`) with `chunk_overlap` carried between chunks.
pub struct RecursiveCharacterTextSplitter {
    /// Target maximum chunk length.
    pub chunk_size: usize,
    /// Amount of trailing content repeated at the start of the next chunk.
    pub chunk_overlap: usize,
    /// Separators tried in order; later entries are finer-grained fallbacks.
    pub separators: Vec<String>,
    /// Measures text length (byte length by default).
    pub length_function: Box<dyn Fn(&str) -> usize + Send + Sync>,
}
impl Default for RecursiveCharacterTextSplitter {
    /// 1000-char chunks, 20-char overlap, the generic separator list, and
    /// byte length as the length function.
    fn default() -> Self {
        let separators = DEFAULT_SEPARATORS.iter().map(|s| s.to_string()).collect();
        Self {
            chunk_size: 1000,
            chunk_overlap: 20,
            separators,
            length_function: Box::new(str::len),
        }
    }
}
impl RecursiveCharacterTextSplitter {
    /// Convenience constructor combining the three builder setters.
    pub fn new(chunk_size: usize, chunk_overlap: usize, separators: &[&str]) -> Self {
        Self::default()
            .with_chunk_size(chunk_size)
            .with_chunk_overlap(chunk_overlap)
            .with_separators(separators)
    }
    /// Builder-style setter for `chunk_size`.
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }
    /// Builder-style setter for `chunk_overlap`.
    pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
        self.chunk_overlap = chunk_overlap;
        self
    }
    /// Builder-style setter for `separators`.
    pub fn with_separators(mut self, separators: &[&str]) -> Self {
        self.separators = separators.iter().map(|v| v.to_string()).collect();
        self
    }
    /// Split each non-empty document's content into chunk documents, copying
    /// each source document's metadata onto its chunks.
    pub fn split_documents(
        &self,
        documents: &[RagDocument],
        chunk_header_options: &SplitterChunkHeaderOptions,
    ) -> Vec<RagDocument> {
        let mut texts: Vec<String> = Vec::new();
        let mut metadatas: Vec<DocumentMetadata> = Vec::new();
        documents.iter().for_each(|d| {
            if !d.page_content.is_empty() {
                texts.push(d.page_content.clone());
                metadatas.push(d.metadata.clone());
            }
        });
        self.create_documents(&texts, &metadatas, chunk_header_options)
    }
    /// Split each text into chunks and wrap them as documents, prepending the
    /// configured chunk header (and, for chunks after the first, the overlap
    /// header). `texts` and `metadatas` are parallel slices.
    pub fn create_documents(
        &self,
        texts: &[String],
        metadatas: &[DocumentMetadata],
        chunk_header_options: &SplitterChunkHeaderOptions,
    ) -> Vec<RagDocument> {
        let SplitterChunkHeaderOptions {
            chunk_header,
            chunk_overlap_header,
        } = chunk_header_options;
        let mut documents = Vec::new();
        for (i, text) in texts.iter().enumerate() {
            let mut prev_chunk: Option<String> = None;
            // Byte offset of the previous chunk in `text`; -1 means "none/not found".
            let mut index_prev_chunk = -1;
            for chunk in self.split_text(text) {
                let mut page_content = chunk_header.clone();
                // Locate this chunk in the source text, searching from just
                // past the previous chunk's start so overlapping chunks are
                // found in order.
                let index_chunk = if index_prev_chunk < 0 {
                    text.find(&chunk).map(|i| i as i32).unwrap_or(-1)
                } else {
                    match text[(index_prev_chunk as usize)..].chars().next() {
                        Some(c) => {
                            // Advance by one char to respect UTF-8 boundaries.
                            let offset = (index_prev_chunk as usize) + c.len_utf8();
                            text[offset..]
                                .find(&chunk)
                                .map(|i| (i + offset) as i32)
                                .unwrap_or(-1)
                        }
                        None => -1,
                    }
                };
                if prev_chunk.is_some() {
                    if let Some(chunk_overlap_header) = chunk_overlap_header {
                        page_content += chunk_overlap_header;
                    }
                }
                let metadata = metadatas[i].clone();
                page_content += &chunk;
                documents.push(RagDocument {
                    page_content,
                    metadata,
                });
                prev_chunk = Some(chunk);
                index_prev_chunk = index_chunk;
            }
        }
        documents
    }
    /// Split `text` using the configured separators. Separators containing
    /// non-whitespace (e.g. code keywords) are kept at the start of the
    /// following piece so chunks begin at meaningful boundaries.
    pub fn split_text(&self, text: &str) -> Vec<String> {
        let keep_separator = self
            .separators
            .iter()
            .any(|v| v.chars().any(|v| !v.is_whitespace()));
        self.split_text_impl(text, &self.separators, keep_separator)
    }
    /// Recursive core: pick the first separator present in `text`, split on
    /// it, merge small pieces up to `chunk_size`, and recurse with the
    /// remaining (finer) separators on pieces that are still too long.
    fn split_text_impl(
        &self,
        text: &str,
        separators: &[String],
        keep_separator: bool,
    ) -> Vec<String> {
        let mut final_chunks = Vec::new();
        let mut separator: String = separators.last().cloned().unwrap_or_default();
        let mut new_separators: Vec<String> = vec![];
        for (i, s) in separators.iter().enumerate() {
            if s.is_empty() {
                separator.clone_from(s);
                break;
            }
            if text.contains(s) {
                separator.clone_from(s);
                new_separators = separators[i + 1..].to_vec();
                break;
            }
        }
        // Now that we have the separator, split the text
        let splits = split_on_separator(text, &separator, keep_separator);
        // Now go merging things, recursively splitting longer texts.
        let mut good_splits = Vec::new();
        // When separators are kept inline, merge without re-inserting them.
        let _separator = if keep_separator { "" } else { &separator };
        for s in splits {
            if (self.length_function)(s) < self.chunk_size {
                good_splits.push(s.to_string());
            } else {
                if !good_splits.is_empty() {
                    let merged_text = self.merge_splits(&good_splits, _separator);
                    final_chunks.extend(merged_text);
                    good_splits.clear();
                }
                if new_separators.is_empty() {
                    // No finer separator left: emit the oversized piece as-is.
                    final_chunks.push(s.to_string());
                } else {
                    let other_info = self.split_text_impl(s, &new_separators, keep_separator);
                    final_chunks.extend(other_info);
                }
            }
        }
        if !good_splits.is_empty() {
            let merged_text = self.merge_splits(&good_splits, _separator);
            final_chunks.extend(merged_text);
        }
        final_chunks
    }
    /// Greedily pack `splits` into chunks of at most `chunk_size`, retaining
    /// up to `chunk_overlap` worth of trailing pieces as the start of the
    /// next chunk. `total` tracks the combined length of `current_doc`.
    fn merge_splits(&self, splits: &[String], separator: &str) -> Vec<String> {
        let mut docs = Vec::new();
        let mut current_doc = Vec::new();
        let mut total = 0;
        for d in splits {
            let _len = (self.length_function)(d);
            if total + _len + current_doc.len() * separator.len() > self.chunk_size {
                if total > self.chunk_size {
                    // warn!("Warning: Created a chunk of size {}, which is longer than the specified {}", total, self.chunk_size);
                }
                if !current_doc.is_empty() {
                    let doc = self.join_docs(&current_doc, separator);
                    if let Some(doc) = doc {
                        docs.push(doc);
                    }
                    // Keep on popping if:
                    // - we have a larger chunk than in the chunk overlap
                    // - or if we still have any chunks and the length is long
                    while total > self.chunk_overlap
                        || (total + _len + current_doc.len() * separator.len() > self.chunk_size
                            && total > 0)
                    {
                        total -= (self.length_function)(&current_doc[0]);
                        current_doc.remove(0);
                    }
                }
            }
            current_doc.push(d.to_string());
            total += _len;
        }
        let doc = self.join_docs(&current_doc, separator);
        if let Some(doc) = doc {
            docs.push(doc);
        }
        docs
    }
    /// Join pieces with the separator and trim; `None` when the result is empty.
    fn join_docs(&self, docs: &[String], separator: &str) -> Option<String> {
        let text = docs.join(separator).trim().to_string();
        if text.is_empty() {
            None
        } else {
            Some(text)
        }
    }
}
/// Options controlling the header text prepended to each generated chunk.
///
/// The manual `Default` impl was redundant — `#[derive(Default)]` yields the
/// exact same values (empty header, no overlap header).
#[derive(Default)]
pub struct SplitterChunkHeaderOptions {
    /// Header prepended to every chunk (empty by default).
    pub chunk_header: String,
    /// Extra marker inserted after `chunk_header` on continuation chunks of
    /// the same document, e.g. "(cont'd) "; `None` disables it.
    pub chunk_overlap_header: Option<String>,
}
impl SplitterChunkHeaderOptions {
    /// Set the value of `chunk_header`.
    #[allow(unused)]
    pub fn with_chunk_header(mut self, header: &str) -> Self {
        self.chunk_header = header.to_string();
        self
    }
    /// Set the value of `chunk_overlap_header`.
    #[allow(unused)]
    pub fn with_chunk_overlap_header(mut self, overlap_header: &str) -> Self {
        self.chunk_overlap_header = Some(overlap_header.to_string());
        self
    }
}
/// Split `text` on `separator`, discarding empty pieces.
///
/// With `keep_separator`, each piece after the first keeps the separator
/// that preceded it (so rejoining with "" reconstructs the text minus any
/// leading separator). An empty separator splits into individual characters.
fn split_on_separator<'a>(text: &'a str, separator: &str, keep_separator: bool) -> Vec<&'a str> {
    let raw: Vec<&str> = if separator.is_empty() {
        text.split("").collect()
    } else if !keep_separator {
        text.split(separator).collect()
    } else {
        let mut pieces = Vec::new();
        let mut start = 0;
        let sep_len = separator.len();
        while let Some(offset) = text[start..].find(separator) {
            // `start - sep_len` is the beginning of the previous separator
            // occurrence (saturating for the very first piece).
            pieces.push(&text[start.saturating_sub(sep_len)..start + offset]);
            start += offset + sep_len;
        }
        if start < text.len() {
            pieces.push(&text[start.saturating_sub(sep_len)..]);
        }
        pieces
    };
    raw.into_iter().filter(|s| !s.is_empty()).collect()
}
#[cfg(test)]
mod tests {
use super::*;
use indexmap::IndexMap;
use pretty_assertions::assert_eq;
use serde_json::{json, Value};
fn build_metadata(source: &str) -> Value {
json!({ "source": source })
}
#[test]
fn test_split_text() {
let splitter = RecursiveCharacterTextSplitter {
chunk_size: 7,
chunk_overlap: 3,
separators: vec![" ".into()],
..Default::default()
};
let output = splitter.split_text("foo bar baz 123");
assert_eq!(output, vec!["foo bar", "bar baz", "baz 123"]);
}
#[test]
fn test_create_document() {
let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
let chunk_header_options = SplitterChunkHeaderOptions::default();
let mut metadata1 = IndexMap::new();
metadata1.insert("source".into(), "1".into());
let mut metadata2 = IndexMap::new();
metadata2.insert("source".into(), "2".into());
let output = splitter.create_documents(
&["foo bar".into(), "baz".into()],
&[metadata1, metadata2],
&chunk_header_options,
);
let output = json!(output);
assert_eq!(
output,
json!([
{
"page_content": "foo",
"metadata": build_metadata("1"),
},
{
"page_content": "bar",
"metadata": build_metadata("1"),
},
{
"page_content": "baz",
"metadata": build_metadata("2"),
},
])
);
}
#[test]
fn test_chunk_header() {
let splitter = RecursiveCharacterTextSplitter::new(3, 0, &[" "]);
let chunk_header_options = SplitterChunkHeaderOptions::default()
.with_chunk_header("SOURCE NAME: testing\n-----\n")
.with_chunk_overlap_header("(cont'd) ");
let mut metadata1 = IndexMap::new();
metadata1.insert("source".into(), "1".into());
let mut metadata2 = IndexMap::new();
metadata2.insert("source".into(), "2".into());
let output = splitter.create_documents(
&["foo bar".into(), "baz".into()],
&[metadata1, metadata2],
&chunk_header_options,
);
let output = json!(output);
assert_eq!(
output,
json!([
{
"page_content": "SOURCE NAME: testing\n-----\nfoo",
"metadata": build_metadata("1"),
},
{
"page_content": "SOURCE NAME: testing\n-----\n(cont'd) bar",
"metadata": build_metadata("1"),
},
{
"page_content": "SOURCE NAME: testing\n-----\nbaz",
"metadata": build_metadata("2"),
},
])
);
}
#[test]
fn test_markdown_splitter() {
let text = r#"# 🦜️🔗 LangChain
⚡ Building applications with LLMs through composability ⚡
## Quick Install
```bash
# Hopefully this code block isn't split
pip install langchain
```
As an open source project in a rapidly developing field, we are extremely open to contributions."#;
let splitter =
RecursiveCharacterTextSplitter::new(100, 0, &Language::Markdown.separators());
let output = splitter.split_text(text);
let expected_output = vec![
"# 🦜️🔗 LangChain\n\n⚡ Building applications with LLMs through composability ⚡",
"## Quick Install\n\n```bash\n# Hopefully this code block isn't split\npip install langchain",
"```",
"As an open source project in a rapidly developing field, we are extremely open to contributions.",
];
assert_eq!(output, expected_output);
}
#[test]
fn test_html_splitter() {
let text = r#"<!DOCTYPE html>
<html>
<head>
<title>🦜️🔗 LangChain</title>
<style>
body {
font-family: Arial, sans-serif;
}
h1 {
color: darkblue;
}
</style>
</head>
<body>
<div>
<h1>🦜️🔗 LangChain</h1>
<p>⚡ Building applications with LLMs through composability ⚡</p>
</div>
<div>
As an open source project in a rapidly developing field, we are extremely open to contributions.
</div>
</body>
</html>"#;
let splitter = RecursiveCharacterTextSplitter::new(175, 20, &Language::Html.separators());
let output = splitter.split_text(text);
let expected_output = vec![
"<!DOCTYPE html>\n<html>",
"<head>\n <title>🦜️🔗 LangChain</title>",
r#"<style>
body {
font-family: Arial, sans-serif;
}
h1 {
color: darkblue;
}
</style>
</head>"#,
r#"<body>
<div>
<h1>🦜️🔗 LangChain</h1>
<p>⚡ Building applications with LLMs through composability ⚡</p>
</div>"#,
r#"<div>
As an open source project in a rapidly developing field, we are extremely open to contributions.
</div>
</body>
</html>"#,
];
assert_eq!(output, expected_output);
}
}
+393
View File
@@ -0,0 +1,393 @@
use crate::utils::decode_bin;
use ansi_colours::AsRGB;
use anyhow::{anyhow, Context, Result};
use crossterm::style::{Color, Stylize};
use crossterm::terminal;
use std::collections::HashMap;
use std::sync::LazyLock;
use syntect::highlighting::{Color as SyntectColor, FontStyle, Style, Theme};
use syntect::parsing::SyntaxSet;
use syntect::{easy::HighlightLines, parsing::SyntaxReference};
/// Comes from <https://github.com/sharkdp/bat/raw/5e77ca37e89c873e4490b42ff556370dc5c6ba4f/assets/syntaxes.bin>
const SYNTAXES: &[u8] = include_bytes!("../../assets/syntaxes.bin");
/// Language tags that syntect cannot resolve by token, mapped (lowercased)
/// to the exact syntax names used in the bundled syntax set.
static LANG_MAPS: LazyLock<HashMap<String, String>> = LazyLock::new(|| {
    HashMap::from([
        ("csharp".to_string(), "C#".to_string()),
        ("php".to_string(), "PHP Source".to_string()),
    ])
});
/// Stateful markdown-to-ANSI renderer for terminal output.
///
/// Tracks whether the current line sits inside a fenced code block so that
/// streaming input can be highlighted line by line.
pub struct MarkdownRender {
    /// Rendering options (theme, wrapping, truecolor).
    options: RenderOptions,
    /// Syntax definitions decoded from the embedded binary asset.
    syntax_set: SyntaxSet,
    /// Fallback color for code when no syntax can be resolved.
    code_color: Option<Color>,
    /// Syntax used for non-code (markdown) lines.
    md_syntax: SyntaxReference,
    /// Syntax of the fenced code block currently open, if any.
    code_syntax: Option<SyntaxReference>,
    /// Classification of the previously rendered line (fence state machine).
    prev_line_type: LineType,
    /// Wrap width in columns; `None` disables wrapping.
    wrap_width: Option<u16>,
}
impl MarkdownRender {
    /// Build a renderer: decode the embedded syntax set, resolve the fallback
    /// code color from the theme, and compute the wrap width.
    ///
    /// # Errors
    /// Fails if the embedded syntax binary cannot be decoded or if
    /// `options.wrap` is neither "auto" nor a valid column count.
    pub fn init(options: RenderOptions) -> Result<Self> {
        let syntax_set: SyntaxSet =
            decode_bin(SYNTAXES).with_context(|| "MarkdownRender: invalid syntaxes binary")?;
        let code_color = options
            .theme
            .as_ref()
            .map(|theme| get_code_color(theme, options.truecolor));
        // The bundled syntax set is expected to always contain "md".
        let md_syntax = syntax_set.find_syntax_by_extension("md").unwrap().clone();
        let line_type = LineType::Normal;
        let wrap_width = match options.wrap.as_deref() {
            None => None,
            Some(value) => match terminal::size() {
                Ok((columns, _)) => {
                    if value == "auto" {
                        Some(columns)
                    } else {
                        let value = value
                            .parse::<u16>()
                            .map_err(|_| anyhow!("Invalid wrap value"))?;
                        // Never wrap wider than the terminal itself.
                        Some(columns.min(value))
                    }
                }
                // Not a terminal (e.g. piped): disable wrapping.
                Err(_) => None,
            },
        };
        Ok(Self {
            syntax_set,
            code_color,
            md_syntax,
            code_syntax: None,
            prev_line_type: line_type,
            wrap_width,
            options,
        })
    }
    /// Render multi-line markdown, advancing the fence state per line.
    pub fn render(&mut self, text: &str) -> String {
        text.split('\n')
            .map(|line| self.render_line_mut(line))
            .collect::<Vec<String>>()
            .join("\n")
    }
    /// Render a single line without mutating fence state (used to repaint an
    /// in-progress line during streaming).
    pub fn render_line(&self, line: &str) -> String {
        let (_, code_syntax, is_code) = self.check_line(line);
        if is_code {
            self.highlight_code_line(line, &code_syntax)
        } else {
            self.highlight_line(line, &self.md_syntax, false)
        }
    }
    /// Render a single line and advance the fence state machine.
    fn render_line_mut(&mut self, line: &str) -> String {
        let (line_type, code_syntax, is_code) = self.check_line(line);
        let output = if is_code {
            self.highlight_code_line(line, &code_syntax)
        } else {
            self.highlight_line(line, &self.md_syntax, false)
        };
        self.prev_line_type = line_type;
        self.code_syntax = code_syntax;
        output
    }
    /// Classify `line` relative to fenced code blocks: the new line type,
    /// the (possibly updated) code syntax, and whether the line should be
    /// highlighted as code.
    fn check_line(&self, line: &str) -> (LineType, Option<SyntaxReference>, bool) {
        let mut line_type = self.prev_line_type;
        let mut code_syntax = self.code_syntax.clone();
        let mut is_code = false;
        if let Some(lang) = detect_code_block(line) {
            match line_type {
                LineType::Normal | LineType::CodeEnd => {
                    // Opening fence: resolve the language tag to a syntax.
                    line_type = LineType::CodeBegin;
                    code_syntax = if lang.is_empty() {
                        None
                    } else {
                        self.find_syntax(&lang).cloned()
                    };
                }
                LineType::CodeBegin | LineType::CodeInner => {
                    // Closing fence.
                    line_type = LineType::CodeEnd;
                    code_syntax = None;
                }
            }
        } else {
            match line_type {
                LineType::Normal => {}
                LineType::CodeEnd => {
                    line_type = LineType::Normal;
                }
                LineType::CodeBegin => {
                    // First line of an untagged fence: try to guess the
                    // syntax from the line's content.
                    if code_syntax.is_none() {
                        if let Some(syntax) = self.syntax_set.find_syntax_by_first_line(line) {
                            code_syntax = Some(syntax.clone());
                        }
                    }
                    line_type = LineType::CodeInner;
                    is_code = true;
                }
                LineType::CodeInner => {
                    is_code = true;
                }
            }
        }
        (line_type, code_syntax, is_code)
    }
    /// Highlight one line with `syntax`, preserving its leading whitespace,
    /// then apply wrapping.
    fn highlight_line(&self, line: &str, syntax: &SyntaxReference, is_code: bool) -> String {
        // Leading whitespace is re-attached verbatim after highlighting.
        let ws: String = line.chars().take_while(|c| c.is_whitespace()).collect();
        let trimmed_line: &str = &line[ws.len()..];
        let mut line_highlighted = None;
        if let Some(theme) = &self.options.theme {
            let mut highlighter = HighlightLines::new(syntax, theme);
            if let Ok(ranges) = highlighter.highlight_line(trimmed_line, &self.syntax_set) {
                line_highlighted = Some(format!(
                    "{ws}{}",
                    as_terminal_escaped(&ranges, self.options.truecolor)
                ))
            }
        }
        // Fall back to the plain line when highlighting is disabled or fails.
        let line = line_highlighted.unwrap_or_else(|| line.into());
        self.wrap_line(line, is_code)
    }
    /// Highlight a code line, falling back to the theme's plain code color
    /// when the fence language has no known syntax.
    fn highlight_code_line(&self, line: &str, code_syntax: &Option<SyntaxReference>) -> String {
        if let Some(syntax) = code_syntax {
            self.highlight_line(line, syntax, true)
        } else {
            let line = match self.code_color {
                Some(color) => line.with(color).to_string(),
                None => line.to_string(),
            };
            self.wrap_line(line, true)
        }
    }
    /// Wrap a rendered line to the configured width; code lines are only
    /// wrapped when `wrap_code` is enabled.
    fn wrap_line(&self, line: String, is_code: bool) -> String {
        if let Some(width) = self.wrap_width {
            if is_code && !self.options.wrap_code {
                return line;
            }
            wrap(&line, width as usize)
        } else {
            line
        }
    }
    /// Resolve a fence language tag to a syntax: alias table first, then
    /// syntect's token and extension lookups.
    fn find_syntax(&self, lang: &str) -> Option<&SyntaxReference> {
        if let Some(new_lang) = LANG_MAPS.get(&lang.to_ascii_lowercase()) {
            self.syntax_set.find_syntax_by_name(new_lang)
        } else {
            self.syntax_set
                .find_syntax_by_token(lang)
                .or_else(|| self.syntax_set.find_syntax_by_extension(lang))
        }
    }
}
/// Wrap `text` to `width` columns, reusing the original leading spaces as the
/// indent of the first wrapped line (subsequent lines are not indented).
fn wrap(text: &str, width: usize) -> String {
    let indent = text.chars().take_while(|&c| c == ' ').count();
    // Leading spaces are ASCII, so the char count doubles as a byte offset.
    let (prefix, rest) = text.split_at(indent);
    let options = textwrap::Options::new(width)
        .wrap_algorithm(textwrap::WrapAlgorithm::FirstFit)
        .initial_indent(prefix);
    textwrap::wrap(rest, options).join("\n")
}
/// Options controlling `MarkdownRender` output.
#[derive(Debug, Clone, Default)]
pub struct RenderOptions {
    /// Syntax highlighting theme; `None` disables highlighting.
    pub theme: Option<Theme>,
    /// Wrap setting: `None` (no wrap), `"auto"` (terminal width), or a max
    /// column count as a string.
    pub wrap: Option<String>,
    /// Whether wrapping also applies to code lines.
    pub wrap_code: bool,
    /// Whether the terminal supports 24-bit color.
    pub truecolor: bool,
}
impl RenderOptions {
    /// Bundle the individual settings into a `RenderOptions`.
    pub(crate) fn new(
        theme: Option<Theme>,
        wrap: Option<String>,
        wrap_code: bool,
        truecolor: bool,
    ) -> Self {
        Self {
            theme,
            wrap,
            wrap_code,
            truecolor,
        }
    }
}
/// Classification of a line relative to fenced code blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LineType {
    /// Ordinary markdown text.
    Normal,
    /// An opening code fence.
    CodeBegin,
    /// A line inside an open code block.
    CodeInner,
    /// A closing code fence.
    CodeEnd,
}
/// Convert syntect highlight ranges into a string carrying ANSI escapes,
/// blending translucent foregrounds over the theme background first.
fn as_terminal_escaped(ranges: &[(Style, &str)], truecolor: bool) -> String {
    ranges
        .iter()
        .map(|(style, text)| {
            let fg = blend_fg_color(style.foreground, style.background);
            let mut styled = text.with(convert_color(fg, truecolor));
            if style.font_style.contains(FontStyle::BOLD) {
                styled = styled.bold();
            }
            if style.font_style.contains(FontStyle::UNDERLINE) {
                styled = styled.underlined();
            }
            styled.to_string()
        })
        .collect()
}
/// Map a syntect RGB color to a crossterm color, degrading to the ANSI-256
/// palette when the terminal lacks truecolor support.
fn convert_color(c: SyntectColor, truecolor: bool) -> Color {
    if truecolor {
        return Color::Rgb {
            r: c.r,
            g: c.g,
            b: c.b,
        };
    }
    let mut value = (c.r, c.g, c.b).to_ansi256();
    // Remap near-white palette entries to 252 to lower contrast.
    if matches!(value, 7 | 15 | 231 | 252..=255) {
        value = 252;
    }
    Color::AnsiValue(value)
}
/// Alpha-blend a (possibly translucent) foreground over the background,
/// returning an opaque color. Fully opaque foregrounds pass through.
fn blend_fg_color(fg: SyntectColor, bg: SyntectColor) -> SyntectColor {
    if fg.a == 0xff {
        return fg;
    }
    let alpha = u32::from(fg.a);
    // Standard integer alpha compositing; the result always fits in a u8.
    let mix = |f: u8, b: u8| -> u8 {
        let v = (u32::from(f) * alpha + u32::from(b) * (255 - alpha)) / 255;
        u8::try_from(v).unwrap_or(u8::MAX)
    };
    SyntectColor {
        r: mix(fg.r, bg.r),
        g: mix(fg.g, bg.g),
        b: mix(fg.b, bg.b),
        a: 255,
    }
}
/// Detect a code fence line. Returns the language tag (possibly empty) when
/// the trimmed line starts with "```", otherwise `None`.
fn detect_code_block(line: &str) -> Option<String> {
    let rest = line.trim_start().strip_prefix("```")?;
    Some(rest.chars().take_while(|c| !c.is_whitespace()).collect())
}
/// Pick the fallback color for code without a recognized syntax: the theme's
/// foreground for "string" scopes, or yellow when the theme defines none.
fn get_code_color(theme: &Theme, truecolor: bool) -> Color {
    let string_item = theme.scopes.iter().find(|item| {
        item.scope
            .selectors
            .iter()
            .any(|sel| sel.path.scopes.iter().any(|s| s.to_string() == "string"))
    });
    match string_item.and_then(|item| item.style.foreground) {
        Some(fg) => convert_color(fg, truecolor),
        None => Color::Yellow,
    }
}
#[cfg(test)]
mod tests {
use super::*;
const TEXT: &str = r#"
To unzip a file in Rust, you can use the `zip` crate. Here's an example code that shows how to unzip a file:
```rust
use std::fs::File;
fn unzip_file(path: &str, output_dir: &str) -> Result<(), Box<dyn std::error::Error>> {
todo!()
}
```
"#;
const TEXT_NO_WRAP_CODE: &str = r#"
To unzip a file in Rust, you can use the `zip` crate. Here's an example code
that shows how to unzip a file:
```rust
use std::fs::File;
fn unzip_file(path: &str, output_dir: &str) -> Result<(), Box<dyn std::error::Error>> {
todo!()
}
```
"#;
const TEXT_WRAP_ALL: &str = r#"
To unzip a file in Rust, you can use the `zip` crate. Here's an example code
that shows how to unzip a file:
```rust
use std::fs::File;
fn unzip_file(path: &str, output_dir: &str) -> Result<(), Box<dyn
std::error::Error>> {
todo!()
}
```
"#;
#[test]
fn test_render() {
let options = RenderOptions::default();
let render = MarkdownRender::init(options).unwrap();
assert!(render.find_syntax("csharp").is_some());
}
#[test]
fn no_theme() {
let options = RenderOptions::default();
let mut render = MarkdownRender::init(options).unwrap();
let output = render.render(TEXT);
assert_eq!(TEXT, output);
}
#[test]
fn no_wrap_code() {
let options = RenderOptions::default();
let mut render = MarkdownRender::init(options).unwrap();
render.wrap_width = Some(80);
let output = render.render(TEXT);
assert_eq!(TEXT_NO_WRAP_CODE, output);
}
#[test]
fn wrap_all() {
let options = RenderOptions {
wrap_code: true,
..Default::default()
};
let mut render = MarkdownRender::init(options).unwrap();
render.wrap_width = Some(80);
let output = render.render(TEXT);
assert_eq!(TEXT_WRAP_ALL, output);
}
#[test]
fn test_detect_code_block() {
assert_eq!(detect_code_block("```rust"), Some("rust".into()));
assert_eq!(detect_code_block("```c++"), Some("c++".into()));
assert_eq!(detect_code_block(" ```rust"), Some("rust".into()));
assert_eq!(detect_code_block("```"), Some("".into()));
assert_eq!(detect_code_block("``rust"), None);
}
}
+30
View File
@@ -0,0 +1,30 @@
mod markdown;
mod stream;
pub use self::markdown::{MarkdownRender, RenderOptions};
use self::stream::{markdown_stream, raw_stream};
use crate::utils::{error_text, pretty_error, AbortSignal, IS_STDOUT_TERMINAL};
use crate::{client::SseEvent, config::GlobalConfig};
use anyhow::Result;
use tokio::sync::mpsc::UnboundedReceiver;
/// Stream SSE reply events to stdout, as highlighted markdown when stdout is
/// a terminal and highlighting is enabled, otherwise as raw text.
pub async fn render_stream(
    rx: UnboundedReceiver<SseEvent>,
    config: &GlobalConfig,
    abort_signal: AbortSignal,
) -> Result<()> {
    let ret = if *IS_STDOUT_TERMINAL && config.read().highlight {
        let render_options = config.read().render_options()?;
        let mut render = MarkdownRender::init(render_options)?;
        markdown_stream(rx, &mut render, &abort_signal).await
    } else {
        raw_stream(rx, &abort_signal).await
    };
    // Fixed typo in the user-facing context message ("reader" -> "render").
    ret.map_err(|err| err.context("Failed to render stream"))
}
/// Print an error to stderr using the shared pretty-printing and error style.
pub fn render_error(err: anyhow::Error) {
    let message = pretty_error(&err);
    eprintln!("{}", error_text(&message));
}
+217
View File
@@ -0,0 +1,217 @@
use super::{MarkdownRender, SseEvent};
use crate::utils::{poll_abort_signal, spawn_spinner, AbortSignal};
use anyhow::Result;
use crossterm::{
cursor, queue, style,
terminal::{self, disable_raw_mode, enable_raw_mode},
};
use std::{
io::{stdout, Stdout, Write},
time::Duration,
};
use textwrap::core::display_width;
use tokio::sync::mpsc::UnboundedReceiver;
/// Render SSE events as markdown in raw terminal mode, restoring the normal
/// terminal mode afterwards regardless of the result.
pub async fn markdown_stream(
    rx: UnboundedReceiver<SseEvent>,
    render: &mut MarkdownRender,
    abort_signal: &AbortSignal,
) -> Result<()> {
    enable_raw_mode()?;
    let mut stdout = stdout();
    let ret = markdown_stream_inner(rx, render, abort_signal, &mut stdout).await;
    disable_raw_mode()?;
    // Leave the cursor on a fresh line after a failure.
    if ret.is_err() {
        println!();
    }
    ret
}
/// Print SSE text events verbatim to stdout until `Done`, an abort, or the
/// channel closing. Shows a spinner until the first event arrives.
pub async fn raw_stream(
    mut rx: UnboundedReceiver<SseEvent>,
    abort_signal: &AbortSignal,
) -> Result<()> {
    let mut spinner = Some(spawn_spinner("Generating"));
    loop {
        if abort_signal.aborted() {
            break;
        }
        match rx.recv().await {
            Some(evt) => {
                if let Some(spinner) = spinner.take() {
                    spinner.stop();
                }
                match evt {
                    SseEvent::Text(text) => {
                        print!("{text}");
                        stdout().flush()?;
                    }
                    SseEvent::Done => {
                        break;
                    }
                }
            }
            // Fix: the original ignored `None` and kept looping, busy-spinning
            // forever if the sender was dropped without emitting `Done`.
            None => break,
        }
    }
    if let Some(spinner) = spinner.take() {
        spinner.stop();
    }
    Ok(())
}
async fn markdown_stream_inner(
mut rx: UnboundedReceiver<SseEvent>,
render: &mut MarkdownRender,
abort_signal: &AbortSignal,
writer: &mut Stdout,
) -> Result<()> {
let mut buffer = String::new();
let mut buffer_rows = 1;
let columns = terminal::size()?.0;
let mut spinner = Some(spawn_spinner("Generating"));
'outer: loop {
if abort_signal.aborted() {
break;
}
for reply_event in gather_events(&mut rx).await {
if let Some(spinner) = spinner.take() {
spinner.stop();
}
match reply_event {
SseEvent::Text(mut text) => {
// tab width hacking
text = text.replace('\t', " ");
let mut attempts = 0;
let (col, mut row) = loop {
match cursor::position() {
Ok(pos) => break pos,
Err(_) if attempts < 3 => attempts += 1,
Err(e) => return Err(e.into()),
}
};
// Fix unexpected duplicate lines on kitty
if col == 0 && row > 0 && display_width(&buffer) == columns as usize {
row -= 1;
}
if row + 1 >= buffer_rows {
queue!(writer, cursor::MoveTo(0, row + 1 - buffer_rows),)?;
} else {
let scroll_rows = buffer_rows - row - 1;
queue!(
writer,
terminal::ScrollUp(scroll_rows),
cursor::MoveTo(0, 0),
)?;
}
// No guarantee that text returned by render will not be re-layouted, so it is better to clear it.
queue!(writer, terminal::Clear(terminal::ClearType::FromCursorDown))?;
if text.contains('\n') {
let text = format!("{buffer}{text}");
let (head, tail) = split_line_tail(&text);
let output = render.render(head);
print_block(writer, &output, columns)?;
buffer = tail.to_string();
} else {
buffer = format!("{buffer}{text}");
}
let output = render.render_line(&buffer);
if output.contains('\n') {
let (head, tail) = split_line_tail(&output);
buffer_rows = print_block(writer, head, columns)?;
queue!(writer, style::Print(&tail),)?;
// No guarantee the buffer width of the buffer will not exceed the number of columns.
// So we calculate the number of rows needed, rather than setting it directly to 1.
buffer_rows += need_rows(tail, columns);
} else {
queue!(writer, style::Print(&output))?;
buffer_rows = need_rows(&output, columns);
}
writer.flush()?;
}
SseEvent::Done => {
break 'outer;
}
}
}
if poll_abort_signal(abort_signal)? {
break;
}
}
if let Some(spinner) = spinner.take() {
spinner.stop();
}
Ok(())
}
/// Drain reply events for up to 50ms and coalesce all received `Text`
/// payloads into a single event, so the renderer repaints at most once per
/// batch instead of once per token.
async fn gather_events(rx: &mut UnboundedReceiver<SseEvent>) -> Vec<SseEvent> {
    let mut texts = vec![];
    let mut done = false;
    tokio::select! {
        _ = async {
            while let Some(reply_event) = rx.recv().await {
                match reply_event {
                    SseEvent::Text(v) => texts.push(v),
                    SseEvent::Done => {
                        done = true;
                        break;
                    }
                }
            }
        } => {}
        // Time-box the drain so rendering stays responsive.
        _ = tokio::time::sleep(Duration::from_millis(50)) => {}
    }
    let mut events = vec![];
    if !texts.is_empty() {
        events.push(SseEvent::Text(texts.join("")))
    }
    if done {
        events.push(SseEvent::Done)
    }
    events
}
/// Queue each line of `text` followed by a newline, moving the cursor back to
/// column 0 (required in raw mode, where "\n" does not imply "\r").
/// Returns the number of lines written.
fn print_block(writer: &mut Stdout, text: &str, columns: u16) -> Result<u16> {
    let mut rows: u16 = 0;
    for line in text.split('\n') {
        queue!(
            writer,
            style::Print(line),
            style::Print("\n"),
            cursor::MoveLeft(columns),
        )?;
        rows += 1;
    }
    Ok(rows)
}
/// Split off the text after the last newline: returns (head, tail) where
/// `tail` has no '\n'; if there is no newline, head is empty.
fn split_line_tail(text: &str) -> (&str, &str) {
    match text.rsplit_once('\n') {
        Some(parts) => parts,
        None => ("", text),
    }
}
/// Number of terminal rows needed to display `text` at `columns` width
/// (minimum 1 row, even for empty text).
/// NOTE(review): `as u16` would truncate for display widths above u16::MAX —
/// presumably unreachable for a single buffered line; confirm.
fn need_rows(text: &str, columns: u16) -> u16 {
    let buffer_width = display_width(text).max(1) as u16;
    buffer_width.div_ceil(columns)
}
+159
View File
@@ -0,0 +1,159 @@
use super::{ReplCommand, REPL_COMMANDS};
use crate::{config::GlobalConfig, utils::fuzzy_filter};
use reedline::{Completer, Span, Suggestion};
use std::collections::HashMap;
impl Completer for ReplCompleter {
    /// Produce completion suggestions for the text before the cursor.
    ///
    /// Only `.command` input is completed: first the command's own argument
    /// completions (via `repl_complete`), and when none apply, the
    /// fuzzy-filtered command names themselves.
    fn complete(&mut self, line: &str, pos: usize) -> Vec<Suggestion> {
        let mut suggestions = vec![];
        // Only consider text before the cursor.
        let line = &line[0..pos];
        let mut parts = split_line(line);
        if parts.is_empty() {
            return suggestions;
        }
        // Drop a leading ":::" token — presumably the REPL's multi-line input
        // marker; confirm against the editor module.
        if parts[0].0 == r#":::"# {
            parts.remove(0);
        }
        let parts_len = parts.len();
        if parts_len == 0 {
            return suggestions;
        }
        let (cmd, cmd_start) = parts[0];
        if !cmd.starts_with('.') {
            return suggestions;
        }
        let state = self.config.read().state();
        // Filter string: the command token plus at most its first argument.
        let command_filter = parts
            .iter()
            .take(2)
            .map(|(v, _)| *v)
            .collect::<Vec<&str>>()
            .join(" ");
        // Keep commands valid in the current state whose names share the
        // first two bytes with the input (a bare "." matches everything).
        let commands: Vec<_> = self
            .commands
            .iter()
            .filter(|cmd| {
                cmd.is_valid(state)
                    && (command_filter.len() == 1 || cmd.name.starts_with(&command_filter[..2]))
            })
            .collect();
        let commands = fuzzy_filter(commands, |v| v.name, &command_filter);
        if parts_len > 1 {
            // Complete the command's arguments; the span replaces the last token.
            let span = Span::new(parts[parts_len - 1].1, pos);
            let args_line = &line[parts[1].1..];
            let args: Vec<&str> = parts.iter().skip(1).map(|(v, _)| *v).collect();
            suggestions.extend(
                self.config
                    .read()
                    .repl_complete(cmd, &args, args_line)
                    .iter()
                    .map(|(value, description)| {
                        let description = description.as_deref().unwrap_or_default();
                        create_suggestion(value, description, span)
                    }),
            )
        }
        if suggestions.is_empty() {
            // Fall back to completing the command name itself.
            let span = Span::new(cmd_start, pos);
            suggestions.extend(commands.iter().map(|cmd| {
                let name = cmd.name;
                let description = cmd.description;
                // Grouped names (shared by several commands) get no trailing
                // space, so the user can keep typing the subcommand.
                let has_group = self.groups.get(name).map(|v| *v > 1).unwrap_or_default();
                let name = if has_group {
                    name.to_string()
                } else {
                    format!("{name} ")
                };
                create_suggestion(&name, description, span)
            }))
        }
        suggestions
    }
}
/// Tab-completion provider for the REPL.
pub struct ReplCompleter {
    config: GlobalConfig,
    /// All REPL commands (copied from `REPL_COMMANDS`).
    commands: Vec<ReplCommand>,
    /// Occurrence count per command name; names appearing more than once form
    /// a "group" and are completed without a trailing space.
    groups: HashMap<&'static str, usize>,
}
impl ReplCompleter {
pub fn new(config: &GlobalConfig) -> Self {
let mut groups = HashMap::new();
let commands: Vec<ReplCommand> = REPL_COMMANDS.to_vec();
for cmd in REPL_COMMANDS.iter() {
let name = cmd.name;
*groups.entry(name).or_insert(0) += 1;
}
Self {
config: config.clone(),
commands,
groups,
}
}
}
/// Build a reedline `Suggestion`; an empty description becomes `None`.
fn create_suggestion(value: &str, description: &str, span: Span) -> Suggestion {
    Suggestion {
        value: value.to_string(),
        description: (!description.is_empty()).then(|| description.to_string()),
        style: None,
        extra: None,
        span,
        append_whitespace: false,
    }
}
/// Split `line` on spaces into (token, byte offset) pairs.
///
/// A trailing space (or an empty line) yields a final empty token positioned
/// at the end of the line, so completion can target the "next" argument.
fn split_line(line: &str) -> Vec<(&str, usize)> {
    let mut parts = Vec::new();
    let mut start: Option<usize> = None;
    for (i, ch) in line.char_indices() {
        match (ch == ' ', start) {
            // Token just ended.
            (true, Some(s)) => {
                parts.push((&line[s..i], s));
                start = None;
            }
            // Token just started.
            (false, None) => start = Some(i),
            _ => {}
        }
    }
    match start {
        Some(s) => parts.push((&line[s..], s)),
        None => parts.push(("", line.len())),
    }
    parts
}
#[test]
fn test_split_line() {
assert_eq!(split_line(".role coder"), vec![(".role", 0), ("coder", 6)],);
assert_eq!(
split_line(" .role coder"),
vec![(".role", 1), ("coder", 9)],
);
assert_eq!(
split_line(".set highlight "),
vec![(".set", 0), ("highlight", 5), ("", 15)],
);
assert_eq!(
split_line(".set highlight t"),
vec![(".set", 0), ("highlight", 5), ("t", 15)],
);
}
+49
View File
@@ -0,0 +1,49 @@
use super::REPL_COMMANDS;
use crate::{config::GlobalConfig, utils::NO_COLOR};
use nu_ansi_term::{Color, Style};
use reedline::{Highlighter, StyledText};
/// Color for ordinary input text (terminal default foreground).
const DEFAULT_COLOR: Color = Color::Default;
/// Color for a recognized REPL command name.
const MATCH_COLOR: Color = Color::Green;
/// Live syntax highlighter for the REPL input line: paints recognized
/// `.command` names in green.
pub struct ReplHighlighter;
impl ReplHighlighter {
    /// The config is currently unused; the parameter keeps the constructor
    /// signature consistent with the other REPL components.
    pub fn new(_config: &GlobalConfig) -> Self {
        Self
    }
}
impl Highlighter for ReplHighlighter {
    /// Style the input line: the longest REPL command name contained in the
    /// line is painted `MATCH_COLOR`, everything else `DEFAULT_COLOR`.
    /// With NO_COLOR set, the whole line is returned unstyled.
    fn highlight(&self, line: &str, _cursor: usize) -> StyledText {
        let mut styled_text = StyledText::new();
        if *NO_COLOR {
            styled_text.push((Style::default(), line.to_string()));
            return styled_text;
        }
        // Longest command name present in the line; ties keep the first.
        let mut longest_match = "";
        for cmd in REPL_COMMANDS.iter() {
            if cmd.name.len() > longest_match.len() && line.contains(cmd.name) {
                longest_match = cmd.name;
            }
        }
        if longest_match.is_empty() {
            styled_text.push((Style::new().fg(DEFAULT_COLOR), line.to_string()));
        } else if let Some((before, after)) = line.split_once(longest_match) {
            styled_text.push((Style::new().fg(DEFAULT_COLOR), before.to_string()));
            styled_text.push((Style::new().fg(MATCH_COLOR), longest_match.to_string()));
            styled_text.push((Style::new().fg(DEFAULT_COLOR), after.to_string()));
        } else {
            // Unreachable: `longest_match` is known to occur in `line`.
            styled_text.push((Style::new().fg(DEFAULT_COLOR), line.to_string()));
        }
        styled_text
    }
}
+1014
View File
File diff suppressed because it is too large Load Diff
+51
View File
@@ -0,0 +1,51 @@
use crate::config::GlobalConfig;
use reedline::{Prompt, PromptHistorySearch, PromptHistorySearchStatus};
use std::borrow::Cow;
/// Reedline prompt whose text is delegated to the shared config state.
#[derive(Clone)]
pub struct ReplPrompt {
    config: GlobalConfig,
}
impl ReplPrompt {
    /// Create a prompt bound to the shared config.
    pub fn new(config: &GlobalConfig) -> Self {
        Self {
            config: config.clone(),
        }
    }
}
impl Prompt for ReplPrompt {
    /// Left prompt text is produced by the config (current session state).
    fn render_prompt_left(&self) -> Cow<'_, str> {
        Cow::Owned(self.config.read().render_prompt_left())
    }
    /// Right prompt text is produced by the config.
    fn render_prompt_right(&self) -> Cow<'_, str> {
        Cow::Owned(self.config.read().render_prompt_right())
    }
    /// No mode indicator is shown.
    fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow<'_, str> {
        Cow::Borrowed("")
    }
    /// Continuation marker for multi-line input.
    fn render_prompt_multiline_indicator(&self) -> Cow<'_, str> {
        Cow::Borrowed("... ")
    }
    /// Indicator shown during reverse history search (Ctrl-R style).
    fn render_prompt_history_search_indicator(
        &self,
        history_search: PromptHistorySearch,
    ) -> Cow<'_, str> {
        let prefix = match history_search.status {
            PromptHistorySearchStatus::Passing => "",
            PromptHistorySearchStatus::Failing => "failing ",
        };
        // NOTE: magic strings, given there is logic on how these compose I am not sure if it
        // is worth extracting in to static constant
        Cow::Owned(format!(
            "({}reverse-search: {}) ",
            prefix, history_search.term
        ))
    }
}
+935
View File
@@ -0,0 +1,935 @@
use crate::{client::*, config::*, function::*, rag::*, utils::*};
use anyhow::{anyhow, bail, Result};
use bytes::Bytes;
use chrono::{Timelike, Utc};
use futures_util::StreamExt;
use http::{Method, Response, StatusCode};
use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
use hyper::{
body::{Frame, Incoming},
service::service_fn,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use serde::Deserialize;
use serde_json::{json, Value};
use std::{
convert::Infallible,
net::IpAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tokio::{
net::TcpListener,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot,
},
};
use tokio_graceful::Shutdown;
use tokio_stream::wrappers::UnboundedReceiverStream;
/// Model alias that resolves to the currently configured default model.
const DEFAULT_MODEL_NAME: &str = "default";
/// Embedded HTML for the LLM playground page.
const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html");
/// Embedded HTML for the LLM arena page.
const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html");
/// Boxed response body type shared by all HTTP handlers.
type AppResponse = Response<BoxBody<Bytes, Infallible>>;
/// Start the local HTTP server and block until a shutdown signal arrives.
///
/// `addr` accepts a bare port (becomes `127.0.0.1:<port>`), a bare IP
/// (becomes `<ip>:8000`), or a full socket address string; `None` falls back
/// to the configured serve address.
pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
    let addr = match addr {
        Some(addr) => {
            if let Ok(port) = addr.parse::<u16>() {
                format!("127.0.0.1:{port}")
            } else if let Ok(ip) = addr.parse::<IpAddr>() {
                format!("{ip}:8000")
            } else {
                addr
            }
        }
        None => config.read().serve_addr(),
    };
    let server = Arc::new(Server::new(&config));
    let listener = TcpListener::bind(&addr).await?;
    let stop_server = server.run(listener).await?;
    println!("Chat Completions API: http://{addr}/v1/chat/completions");
    println!("Embeddings API: http://{addr}/v1/embeddings");
    println!("Rerank API: http://{addr}/v1/rerank");
    println!("LLM Playground: http://{addr}/playground");
    println!("LLM Arena: http://{addr}/arena?num=2");
    shutdown_signal().await;
    // Tell the accept loop to stop; ignore the error if it already exited.
    let _ = stop_server.send(());
    Ok(())
}
/// Immutable server state shared across request-handling tasks.
struct Server {
    /// Configuration snapshot taken at startup (functions disabled).
    config: Config,
    /// Precomputed OpenAI-style model list served by `/v1/models`.
    models: Vec<Value>,
    /// All roles, served by `/v1/roles`.
    roles: Vec<Role>,
    /// Names of available RAGs, served by `/v1/rags`.
    rags: Vec<String>,
}
impl Server {
    /// Snapshot the global config and precompute the JSON model list served
    /// by the listing endpoints.
    fn new(config: &GlobalConfig) -> Self {
        let mut config = config.read().clone();
        // Tool/function calling is not exposed through the server.
        config.functions = Functions::default();
        let mut models = list_all_models(&config);
        // Expose the configured model under the "default" alias, listed first.
        let mut default_model = config.model.clone();
        default_model.data_mut().name = DEFAULT_MODEL_NAME.into();
        models.insert(0, &default_model);
        let models: Vec<Value> = models
            .into_iter()
            .enumerate()
            .map(|(i, model)| {
                let id = if i == 0 {
                    DEFAULT_MODEL_NAME.into()
                } else {
                    model.id()
                };
                // Shape each entry like an OpenAI `/v1/models` item.
                let mut value = json!(model.data());
                if let Some(value_obj) = value.as_object_mut() {
                    value_obj.insert("id".into(), id.into());
                    value_obj.insert("object".into(), "model".into());
                    value_obj.insert("owned_by".into(), model.client_name().into());
                    value_obj.remove("name");
                }
                value
            })
            .collect();
        Self {
            config,
            models,
            roles: Config::all_roles(),
            rags: Config::list_rags(),
        }
    }
    /// Spawn the accept loop. Each connection is served on a task tracked by
    /// the graceful-shutdown guard; the returned oneshot sender stops the
    /// loop when fired (or when dropped).
    async fn run(self: Arc<Self>, listener: TcpListener) -> Result<oneshot::Sender<()>> {
        let (tx, rx) = oneshot::channel();
        tokio::spawn(async move {
            let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() });
            let guard = shutdown.guard_weak();
            loop {
                tokio::select! {
                    res = listener.accept() => {
                        // Skip failed accepts and keep listening.
                        let Ok((cnx, _)) = res else {
                            continue;
                        };
                        let stream = TokioIo::new(cnx);
                        let server = self.clone();
                        shutdown.spawn_task(async move {
                            let hyper_service = service_fn(move |request: hyper::Request<Incoming>| {
                                server.clone().handle(request)
                            });
                            let _ = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                                .serve_connection_with_upgrades(stream, hyper_service)
                                .await;
                        });
                    }
                    _ = guard.cancelled() => {
                        break;
                    }
                }
            }
        });
        Ok(tx)
    }
    /// Route an incoming request to the matching endpoint handler, attach
    /// CORS headers, and convert handler errors into error responses.
    async fn handle(
        self: Arc<Self>,
        req: hyper::Request<Incoming>,
    ) -> std::result::Result<AppResponse, hyper::Error> {
        let method = req.method().clone();
        let uri = req.uri().clone();
        let path = uri.path();
        // CORS preflight: reply 204 with CORS headers and no body.
        if method == Method::OPTIONS {
            let mut res = Response::default();
            *res.status_mut() = StatusCode::NO_CONTENT;
            set_cors_header(&mut res);
            return Ok(res);
        }
        let mut status = StatusCode::OK;
        let res = if path == "/v1/chat/completions" {
            self.chat_completions(req).await
        } else if path == "/v1/embeddings" {
            self.embeddings(req).await
        } else if path == "/v1/rerank" {
            self.rerank(req).await
        } else if path == "/v1/models" {
            self.list_models()
        } else if path == "/v1/roles" {
            self.list_roles()
        } else if path == "/v1/rags" {
            self.list_rags()
        } else if path == "/v1/rags/search" {
            self.search_rag(req).await
        } else if path == "/playground" || path == "/playground.html" {
            self.playground_page()
        } else if path == "/arena" || path == "/arena.html" {
            self.arena_page()
        } else {
            status = StatusCode::NOT_FOUND;
            Err(anyhow!("Not Found"))
        };
        let mut res = match res {
            Ok(res) => {
                info!("{method} {uri} {}", status.as_u16());
                res
            }
            Err(err) => {
                // Any handler error that didn't set a status becomes 400.
                if status == StatusCode::OK {
                    status = StatusCode::BAD_REQUEST;
                }
                error!("{method} {uri} {} {err}", status.as_u16());
                ret_err(err)
            }
        };
        *res.status_mut() = status;
        set_cors_header(&mut res);
        Ok(res)
    }
fn playground_page(&self) -> Result<AppResponse> {
let res = Response::builder()
.header("Content-Type", "text/html; charset=utf-8")
.body(Full::new(Bytes::from(PLAYGROUND_HTML)).boxed())?;
Ok(res)
}
fn arena_page(&self) -> Result<AppResponse> {
let res = Response::builder()
.header("Content-Type", "text/html; charset=utf-8")
.body(Full::new(Bytes::from(ARENA_HTML)).boxed())?;
Ok(res)
}
fn list_models(&self) -> Result<AppResponse> {
let data = json!({ "data": self.models });
let res = Response::builder()
.header("Content-Type", "application/json; charset=utf-8")
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
Ok(res)
}
fn list_roles(&self) -> Result<AppResponse> {
let data = json!({ "data": self.roles });
let res = Response::builder()
.header("Content-Type", "application/json; charset=utf-8")
.body(Full::new(Bytes::from(data.to_string())).boxed())?;
Ok(res)
}
/// Handle `GET /v1/rags`: return the available RAGs as `{"data": [...]}`.
fn list_rags(&self) -> Result<AppResponse> {
    let payload = json!({ "data": self.rags }).to_string();
    let res = Response::builder()
        .header("Content-Type", "application/json; charset=utf-8")
        .body(Full::new(Bytes::from(payload)).boxed())?;
    Ok(res)
}
/// Handle `POST /v1/rags/search`: load the named RAG, run a search with the
/// given input, and return `{"data": <search result>}` as JSON.
async fn search_rag(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
    // Buffer the full request body, then parse it as JSON.
    let req_body = req.collect().await?.to_bytes();
    let req_body: Value = serde_json::from_slice(&req_body)
        .map_err(|err| anyhow!("Invalid request json, {err}"))?;
    debug!("search rag request: {req_body}");
    let SearchRagReqBody { name, input } = serde_json::from_value(req_body)
        .map_err(|err| anyhow!("Invalid request body, {err}"))?;
    // Each request works on its own clone of the config so per-request state
    // does not leak across requests.
    let config = Arc::new(RwLock::new(self.config.clone()));
    let abort_signal = create_abort_signal();
    let rag_path = config.read().rag_file(&name);
    let rag = Rag::load(&config, &name, &rag_path)?;
    let rag_result = Config::search_rag(&config, &rag, &input, abort_signal).await?;
    let data = json!({ "data": rag_result });
    let res = Response::builder()
        .header("Content-Type", "application/json; charset=utf-8")
        .body(Full::new(Bytes::from(data.to_string())).boxed())?;
    Ok(res)
}
/// Handle `POST /v1/chat/completions` (OpenAI-compatible).
///
/// Parses and validates the request, activates the requested model when it
/// differs from the current one, then either streams the reply as SSE chunks
/// (`stream: true`) or returns a single JSON completion object.
async fn chat_completions(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
    // Buffer the full request body, then parse it as JSON.
    let req_body = req.collect().await?.to_bytes();
    let req_body: Value = serde_json::from_slice(&req_body)
        .map_err(|err| anyhow!("Invalid request json, {err}"))?;
    debug!("chat completions request: {req_body}");
    let req_body = serde_json::from_value(req_body)
        .map_err(|err| anyhow!("Invalid request body, {err}"))?;
    let ChatCompletionsReqBody {
        model,
        messages,
        temperature,
        top_p,
        max_tokens,
        stream,
        tools,
    } = req_body;
    let mut messages =
        parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;
    let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;
    // Per-request clone of the config so model switching does not affect
    // other requests.
    let config = self.config.clone();
    let default_model = config.model.clone();
    let config = Arc::new(RwLock::new(config));
    // `change` is true when the requested model must be activated via
    // `set_model` before building the client.
    let (model_name, change) = if model == DEFAULT_MODEL_NAME {
        (default_model.id(), true)
    } else if default_model.id() == model {
        (model, false)
    } else {
        (model, true)
    };
    if change {
        config.write().set_model(&model_name)?;
    }
    let mut client = init_client(&config, None)?;
    if max_tokens.is_some() {
        client.model_mut().set_max_tokens(max_tokens, true);
    }
    let abort_signal = create_abort_signal();
    let http_client = client.build_client()?;
    let completion_id = generate_completion_id();
    let created = Utc::now().timestamp();
    patch_messages(&mut messages, client.model());
    let data: ChatCompletionsData = ChatCompletionsData {
        messages,
        temperature,
        top_p,
        functions,
        stream,
    };
    if stream {
        // Streaming path: a spawned task runs the model call and forwards
        // `ResEvent`s over `tx`; this handler turns them into SSE frames.
        let (tx, mut rx) = unbounded_channel();
        tokio::spawn(async move {
            // `is_first` guards the one-shot `ResEvent::First` notification
            // used below to report early failures before streaming starts.
            let is_first = Arc::new(AtomicBool::new(true));
            let (sse_tx, sse_rx) = unbounded_channel();
            let mut handler = SseHandler::new(sse_tx, abort_signal);
            // Forward SSE handler events to the response channel, emitting
            // `First(None)` before the first event.
            async fn map_event(
                mut sse_rx: UnboundedReceiver<SseEvent>,
                tx: &UnboundedSender<ResEvent>,
                is_first: Arc<AtomicBool>,
            ) {
                while let Some(reply_event) = sse_rx.recv().await {
                    if is_first.load(Ordering::SeqCst) {
                        let _ = tx.send(ResEvent::First(None));
                        is_first.store(false, Ordering::SeqCst)
                    }
                    match reply_event {
                        SseEvent::Text(text) => {
                            let _ = tx.send(ResEvent::Text(text));
                        }
                        SseEvent::Done => {
                            let _ = tx.send(ResEvent::Done);
                            sse_rx.close();
                        }
                    }
                }
            }
            // Drive the actual model call. Models that cannot stream are
            // called once and their whole reply is forwarded as one Text event.
            async fn chat_completions(
                client: &dyn Client,
                http_client: &reqwest::Client,
                handler: &mut SseHandler,
                mut data: ChatCompletionsData,
                tx: &UnboundedSender<ResEvent>,
                is_first: Arc<AtomicBool>,
            ) {
                if client.model().no_stream() {
                    data.stream = false;
                    let ret = client.chat_completions_inner(http_client, data).await;
                    match ret {
                        Ok(output) => {
                            let ChatCompletionsOutput {
                                text, tool_calls, ..
                            } = output;
                            let _ = tx.send(ResEvent::First(None));
                            is_first.store(false, Ordering::SeqCst);
                            let _ = tx.send(ResEvent::Text(text));
                            if !tool_calls.is_empty() {
                                let _ = tx.send(ResEvent::ToolCalls(tool_calls));
                            }
                        }
                        Err(err) => {
                            // Report the failure via First so the handler can
                            // bail out before sending any SSE bytes.
                            let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
                            is_first.store(false, Ordering::SeqCst)
                        }
                    };
                } else {
                    let ret = client
                        .chat_completions_streaming_inner(http_client, handler, data)
                        .await;
                    let first = match ret {
                        Ok(()) => None,
                        Err(err) => Some(format!("{err:?}")),
                    };
                    // Only send First if map_event has not already done so.
                    if is_first.load(Ordering::SeqCst) {
                        let _ = tx.send(ResEvent::First(first));
                        is_first.store(false, Ordering::SeqCst)
                    }
                    let tool_calls = handler.tool_calls().to_vec();
                    if !tool_calls.is_empty() {
                        let _ = tx.send(ResEvent::ToolCalls(tool_calls));
                    }
                }
                handler.done();
            }
            // Run producer and forwarder concurrently until both finish.
            tokio::join!(
                map_event(sse_rx, &tx, is_first.clone()),
                chat_completions(
                    client.as_ref(),
                    &http_client,
                    &mut handler,
                    data,
                    &tx,
                    is_first
                ),
            );
        });
        // Wait for the First event: an error here becomes a normal HTTP
        // error response instead of a broken SSE stream.
        let first_event = rx.recv().await;
        if let Some(ResEvent::First(Some(err))) = first_event {
            bail!("{err}");
        }
        // Shared per-stream state: (id, model, created, has_tool_calls flag).
        let shared: Arc<(String, String, i64, AtomicBool)> =
            Arc::new((completion_id, model_name, created, AtomicBool::new(false)));
        let stream = UnboundedReceiverStream::new(rx);
        let stream = stream.filter_map(move |res_event| {
            let shared = shared.clone();
            async move {
                let (completion_id, model, created, has_tool_calls) = shared.as_ref();
                match res_event {
                    ResEvent::Text(text) => {
                        Some(Ok(create_text_frame(completion_id, model, *created, &text)))
                    }
                    ResEvent::ToolCalls(tool_calls) => {
                        // Remember tool calls so the final frame can report
                        // finish_reason "tool_calls".
                        has_tool_calls.store(true, Ordering::SeqCst);
                        Some(Ok(create_tool_calls_frame(
                            completion_id,
                            model,
                            *created,
                            &tool_calls,
                        )))
                    }
                    ResEvent::Done => Some(Ok(create_done_frame(
                        completion_id,
                        model,
                        *created,
                        has_tool_calls.load(Ordering::SeqCst),
                    ))),
                    // First events are consumed above; drop anything else.
                    _ => None,
                }
            }
        });
        let res = Response::builder()
            .status(StatusCode::OK)
            .header("Content-Type", "text/event-stream")
            .header("Cache-Control", "no-cache")
            .header("Connection", "keep-alive")
            .body(BodyExt::boxed(StreamBody::new(stream)))?;
        Ok(res)
    } else {
        // Non-streaming path: one call, one JSON body.
        let output = client.chat_completions_inner(&http_client, data).await?;
        let res = Response::builder()
            .header("Content-Type", "application/json")
            .body(
                Full::new(ret_non_stream(
                    &completion_id,
                    &model_name,
                    created,
                    &output,
                ))
                .boxed(),
            )?;
        Ok(res)
    }
}
/// Handle `POST /v1/embeddings` (OpenAI-compatible): embed one or more texts
/// with the requested embedding model.
async fn embeddings(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
    // Buffer the full request body, then parse it as JSON.
    let req_body = req.collect().await?.to_bytes();
    let req_body: Value = serde_json::from_slice(&req_body)
        .map_err(|err| anyhow!("Invalid request json, {err}"))?;
    debug!("embeddings request: {req_body}");
    let req_body = serde_json::from_value(req_body)
        .map_err(|err| anyhow!("Invalid request body, {err}"))?;
    let EmbeddingsReqBody {
        input,
        model: embedding_model_id,
    } = req_body;
    let config = Arc::new(RwLock::new(self.config.clone()));
    let embedding_model =
        Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;
    // Normalize the single-string form to the multi-string form.
    let texts = match input {
        EmbeddingsReqBodyInput::Single(v) => vec![v],
        EmbeddingsReqBodyInput::Multiple(v) => v,
    };
    let client = init_client(&config, Some(embedding_model))?;
    // `query: false` — these are document embeddings, not a search query.
    let data = client
        .embeddings(&EmbeddingsData {
            query: false,
            texts,
        })
        .await?;
    let data: Vec<_> = data
        .into_iter()
        .enumerate()
        .map(|(i, v)| {
            json!({
                "object": "embedding",
                "embedding": v,
                "index": i,
            })
        })
        .collect();
    // Token usage is not tracked for embeddings; report zeros.
    let output = json!({
        "object": "list",
        "data": data,
        "model": embedding_model_id,
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    });
    let res = Response::builder()
        .header("Content-Type", "application/json")
        .body(Full::new(Bytes::from(output.to_string())).boxed())?;
    Ok(res)
}
/// Handle `POST /v1/rerank`: score `documents` against `query` with the
/// requested reranker model, returning the top `top_n` results (all documents
/// when `top_n` is omitted).
async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
    // Buffer the full request body, then parse it as JSON.
    let req_body = req.collect().await?.to_bytes();
    let req_body: Value = serde_json::from_slice(&req_body)
        .map_err(|err| anyhow!("Invalid request json, {err}"))?;
    debug!("rerank request: {req_body}");
    let req_body = serde_json::from_value(req_body)
        .map_err(|err| anyhow!("Invalid request body, {err}"))?;
    let RerankReqBody {
        model: reranker_model_id,
        documents,
        query,
        top_n,
    } = req_body;
    let top_n = top_n.unwrap_or(documents.len());
    let config = Arc::new(RwLock::new(self.config.clone()));
    let reranker_model =
        Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?;
    let client = init_client(&config, Some(reranker_model))?;
    // The documents are cloned so the originals can be echoed back by index.
    let data = client
        .rerank(&RerankData {
            query,
            documents: documents.clone(),
            top_n,
        })
        .await?;
    let results: Vec<_> = data
        .into_iter()
        .map(|v| {
            json!({
                "index": v.index,
                "relevance_score": v.relevance_score,
                // Out-of-range indices degrade to null rather than failing.
                "document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
            })
        })
        .collect();
    let output = json!({
        "id": uuid::Uuid::new_v4().to_string(),
        "results": results,
    });
    let res = Response::builder()
        .header("Content-Type", "application/json")
        .body(Full::new(Bytes::from(output.to_string())).boxed())?;
    Ok(res)
}
}
/// Request body for `POST /v1/rags/search`.
#[derive(Debug, Deserialize)]
struct SearchRagReqBody {
    // Name of the RAG to search.
    name: String,
    // Search query text.
    input: String,
}
/// Request body for `POST /v1/chat/completions` (OpenAI-compatible subset).
#[derive(Debug, Deserialize)]
struct ChatCompletionsReqBody {
    // Model id, or DEFAULT_MODEL_NAME for the configured default.
    model: String,
    // Raw message objects; validated by `parse_messages`.
    messages: Vec<Value>,
    temperature: Option<f64>,
    top_p: Option<f64>,
    max_tokens: Option<isize>,
    // Defaults to false (non-streaming) when omitted.
    #[serde(default)]
    stream: bool,
    // Raw tool objects; validated by `parse_tools`.
    tools: Option<Vec<Value>>,
}
/// Request body for `POST /v1/embeddings`.
#[derive(Debug, Deserialize)]
struct EmbeddingsReqBody {
    // One text or a list of texts to embed.
    input: EmbeddingsReqBodyInput,
    // Embedding model id.
    model: String,
}
/// `input` field of the embeddings request: either a single string or a list,
/// distinguished structurally via `untagged` deserialization.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EmbeddingsReqBodyInput {
    Single(String),
    Multiple(Vec<String>),
}
/// Request body for `POST /v1/rerank`.
#[derive(Debug, Deserialize)]
struct RerankReqBody {
    // Documents to score against the query.
    documents: Vec<String>,
    query: String,
    // Reranker model id.
    model: String,
    // Number of results to return; defaults to all documents.
    top_n: Option<usize>,
}
/// Events passed from the model-calling task to the SSE response writer.
#[derive(Debug)]
enum ResEvent {
    // One-shot stream-start notification; carries an error message when the
    // call failed before any output was produced.
    First(Option<String>),
    // A text delta chunk.
    Text(String),
    // Tool calls requested by the model.
    ToolCalls(Vec<ToolCall>),
    // End of stream.
    Done,
}
/// Future that resolves when CTRL+C is received, used to shut the server down.
async fn shutdown_signal() {
    tokio::signal::ctrl_c()
        .await
        .expect("Failed to install CTRL+C signal handler")
}
/// Build a completion id in the OpenAI `chatcmpl-<n>` shape.
///
/// NOTE(review): the suffix is the sub-second nanosecond component of the
/// current time, not a random value — ids can repeat across different seconds.
/// Confirm nothing relies on global uniqueness before depending on it.
fn generate_completion_id() -> String {
    let random_id = Utc::now().nanosecond();
    format!("chatcmpl-{random_id}")
}
/// Add permissive CORS headers so browser clients on any origin can call the API.
fn set_cors_header(res: &mut AppResponse) {
    use hyper::header::{
        HeaderValue, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
        ACCESS_CONTROL_ALLOW_ORIGIN,
    };
    let headers = res.headers_mut();
    headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
    headers.insert(
        ACCESS_CONTROL_ALLOW_METHODS,
        HeaderValue::from_static("GET,POST,PUT,PATCH,DELETE"),
    );
    headers.insert(
        ACCESS_CONTROL_ALLOW_HEADERS,
        HeaderValue::from_static("Content-Type,Authorization"),
    );
}
/// Build one SSE `data:` frame carrying a text delta chunk.
///
/// Empty-content chunks also carry `"role": "assistant"` — presumably the
/// stream-opening chunk per the OpenAI streaming convention of announcing the
/// role before any content (confirm against the producer side).
fn create_text_frame(id: &str, model: &str, created: i64, content: &str) -> Frame<Bytes> {
    let delta = if content.is_empty() {
        json!({ "role": "assistant", "content": content })
    } else {
        json!({ "content": content })
    };
    let choice = json!({
        "index": 0,
        "delta": delta,
        "finish_reason": null,
    });
    let value = build_chat_completion_chunk_json(id, model, created, &choice);
    Frame::data(Bytes::from(format!("data: {value}\n\n")))
}
/// Build one SSE frame containing the delta chunks for a batch of tool calls.
///
/// Each tool call is emitted as two chat.completion.chunk deltas, mirroring
/// OpenAI streaming: first a header chunk with the call id, name, and empty
/// arguments, then a chunk carrying the full serialized arguments.
fn create_tool_calls_frame(
    id: &str,
    model: &str,
    created: i64,
    tool_calls: &[ToolCall],
) -> Frame<Bytes> {
    let chunks = tool_calls
        .iter()
        .enumerate()
        .flat_map(|(i, call)| {
            // Header chunk: announces call `i` with its id and function name.
            let choice1 = json!({
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [
                        {
                            "index": i,
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.name,
                                "arguments": ""
                            }
                        }
                    ]
                },
                "finish_reason": null
            });
            // Arguments chunk: delivers the serialized arguments in one piece.
            let choice2 = json!({
                "index": 0,
                "delta": {
                    "tool_calls": [
                        {
                            "index": i,
                            "function": {
                                "arguments": call.arguments.to_string(),
                            }
                        }
                    ]
                },
                "finish_reason": null
            });
            vec![
                build_chat_completion_chunk_json(id, model, created, &choice1),
                build_chat_completion_chunk_json(id, model, created, &choice2),
            ]
        })
        .map(|v| format!("data: {v}\n\n"))
        .collect::<Vec<String>>()
        .join("");
    Frame::data(Bytes::from(chunks))
}
/// Build the final SSE frame: a chunk whose `finish_reason` reflects whether
/// tool calls were emitted, followed by the OpenAI `data: [DONE]` terminator.
fn create_done_frame(id: &str, model: &str, created: i64, has_tool_calls: bool) -> Frame<Bytes> {
    let finish_reason = match has_tool_calls {
        true => "tool_calls",
        false => "stop",
    };
    let choice = json!({
        "index": 0,
        "delta": {},
        "finish_reason": finish_reason,
    });
    let chunk = build_chat_completion_chunk_json(id, model, created, &choice);
    let payload = format!("data: {chunk}\n\ndata: [DONE]\n\n");
    Frame::data(Bytes::from(payload))
}
/// Wrap a single `choice` object in the standard `chat.completion.chunk`
/// envelope used by every streaming frame.
fn build_chat_completion_chunk_json(id: &str, model: &str, created: i64, choice: &Value) -> Value {
    json!({
        "id": id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [choice],
    })
}
/// Serialize a non-streaming `chat.completion` response body.
///
/// Prefers the provider-supplied completion id when present; reports token
/// usage from the output (zeros when the provider did not supply counts) and
/// sets `finish_reason` to "tool_calls" when the reply contains tool calls.
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
    // Use the provider's id if it returned one, otherwise our generated id.
    let id = output.id.as_deref().unwrap_or(id);
    let input_tokens = output.input_tokens.unwrap_or_default();
    let output_tokens = output.output_tokens.unwrap_or_default();
    let total_tokens = input_tokens + output_tokens;
    let choice = if output.tool_calls.is_empty() {
        // Plain text reply.
        json!({
            "index": 0,
            "message": {
                "role": "assistant",
                "content": output.text,
            },
            "logprobs": null,
            "finish_reason": "stop",
        })
    } else {
        // Tool-call reply: content is null when there is no accompanying text.
        let content = if output.text.is_empty() {
            Value::Null
        } else {
            output.text.clone().into()
        };
        let tool_calls: Vec<_> = output
            .tool_calls
            .iter()
            .map(|call| {
                json!({
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.name,
                        "arguments": call.arguments.to_string(),
                    }
                })
            })
            .collect();
        json!({
            "index": 0,
            "message": {
                "role": "assistant",
                "content": content,
                "tool_calls": tool_calls,
            },
            "logprobs": null,
            "finish_reason": "tool_calls",
        })
    };
    let res_body = json!({
        "id": id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [choice],
        "usage": {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": total_tokens,
        },
    });
    Bytes::from(res_body.to_string())
}
/// Build an OpenAI-style JSON error response body (`{"error": {...}}`).
fn ret_err<T: std::fmt::Display>(err: T) -> AppResponse {
    let message = err.to_string();
    let body = json!({
        "error": {
            "message": message,
            "type": "invalid_request_error",
        },
    })
    .to_string();
    Response::builder()
        .header("Content-Type", "application/json")
        .body(Full::new(Bytes::from(body)).boxed())
        .unwrap()
}
/// Validate and convert raw OpenAI-style message objects into `Message`s.
///
/// Assistant messages carrying `tool_calls` are held back until a matching
/// "tool" message has arrived for every call; the whole group is then folded
/// into a single assistant message with `MessageContent::ToolCalls`.
/// Fails with the index of the first message that does not parse.
fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
    let mut output = vec![];
    // Pending tool-call group: (assistant text, requested calls, collected
    // tool results). Set by an assistant message with `tool_calls`, drained
    // once every matching "tool" result has been received.
    let mut tool_results = None;
    for (i, message) in message.into_iter().enumerate() {
        let err = || anyhow!("Failed to parse '.messages[{i}]'");
        let role = message["role"].as_str().ok_or_else(err)?;
        // Content may be a plain string, an array of parts, null, or absent
        // (the latter two become empty text).
        let content = match message.get("content") {
            Some(value) => {
                if let Some(value) = value.as_str() {
                    MessageContent::Text(value.to_string())
                } else if value.is_array() {
                    let value = serde_json::from_value(value.clone()).map_err(|_| err())?;
                    MessageContent::Array(value)
                } else if value.is_null() {
                    MessageContent::Text(String::new())
                } else {
                    return Err(err());
                }
            }
            None => MessageContent::Text(String::new()),
        };
        match role {
            "system" | "user" => {
                let role = match role {
                    "system" => MessageRole::System,
                    "user" => MessageRole::User,
                    _ => unreachable!(),
                };
                output.push(Message::new(role, content))
            }
            "assistant" => {
                let role = MessageRole::Assistant;
                match message["tool_calls"].as_array() {
                    Some(tool_calls) => {
                        // A new tool-call group while one is still pending is
                        // an invalid sequence.
                        if tool_results.is_some() {
                            return Err(err());
                        }
                        let mut list = vec![];
                        for tool_call in tool_calls {
                            // id is optional; name and arguments are required
                            // and arguments must be valid JSON.
                            if let (id, Some(name), Some(arguments)) = (
                                tool_call["id"].as_str().map(|v| v.to_string()),
                                tool_call["function"]["name"].as_str(),
                                tool_call["function"]["arguments"].as_str(),
                            ) {
                                let arguments =
                                    serde_json::from_str(arguments).map_err(|_| err())?;
                                list.push((id, name.to_string(), arguments));
                            } else {
                                return Err(err());
                            }
                        }
                        tool_results = Some((content.to_text(), list, vec![]));
                    }
                    None => output.push(Message::new(role, content)),
                }
            }
            "tool" => match tool_results.take() {
                Some((text, tool_calls, mut tool_values)) => {
                    let tool_call_id = message["tool_call_id"].as_str().map(|v| v.to_string());
                    let content = content.to_text();
                    // The tool result is JSON when it parses, raw text otherwise.
                    let value: Value = serde_json::from_str(&content)
                        .ok()
                        .unwrap_or_else(|| content.into());
                    tool_values.push((value, tool_call_id));
                    if tool_calls.len() == tool_values.len() {
                        // Group complete: pair calls with results in order;
                        // ids must match pairwise.
                        let mut list = vec![];
                        for ((id, name, arguments), (value, tool_call_id)) in
                            tool_calls.into_iter().zip(tool_values.into_iter())
                        {
                            if id != tool_call_id {
                                return Err(err());
                            }
                            list.push(ToolResult::new(ToolCall::new(name, arguments, id), value))
                        }
                        output.push(Message::new(
                            MessageRole::Assistant,
                            MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)),
                        ));
                        tool_results = None;
                    } else {
                        // More tool results still expected; keep accumulating.
                        tool_results = Some((text, tool_calls, tool_values));
                    }
                }
                // A "tool" message without a pending group is invalid.
                None => return Err(err()),
            },
            _ => {
                return Err(err());
            }
        }
    }
    // A pending, incomplete tool-call group means the sequence was truncated.
    if tool_results.is_some() {
        bail!("Invalid messages");
    }
    Ok(output)
}
/// Convert OpenAI `tools` entries into `FunctionDeclaration`s.
///
/// Only entries with `"type": "function"` and a deserializable `function`
/// object are accepted; anything else fails with that entry's index.
/// Returns `None` when no tools were supplied.
fn parse_tools(tools: Option<Vec<Value>>) -> Result<Option<Vec<FunctionDeclaration>>> {
    let tools = match tools {
        Some(v) => v,
        None => return Ok(None),
    };
    let mut functions = vec![];
    for (i, tool) in tools.into_iter().enumerate() {
        if let (Some("function"), Some(function)) = (
            tool["type"].as_str(),
            tool["function"]
                .as_object()
                .and_then(|v| serde_json::from_value(json!(v)).ok()),
        ) {
            functions.push(function);
        } else {
            bail!("Failed to parse '.tools[{i}]'")
        }
    }
    Ok(Some(functions))
}
+88
View File
@@ -0,0 +1,88 @@
use anyhow::Result;
use crossterm::event::{self, Event, KeyCode, KeyModifiers};
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
/// Shared, cheaply-clonable handle used across tasks/threads to signal abort.
pub type AbortSignal = Arc<AbortSignalInner>;
/// Abort state, tracking which key combination requested it.
pub struct AbortSignalInner {
    // Set when CTRL+C was pressed.
    ctrlc: AtomicBool,
    // Set when CTRL+D was pressed.
    ctrld: AtomicBool,
}
/// Create a fresh, un-aborted signal handle.
pub fn create_abort_signal() -> AbortSignal {
    AbortSignalInner::new()
}
impl AbortSignalInner {
    /// Create a fresh handle with neither CTRL+C nor CTRL+D flagged.
    pub fn new() -> AbortSignal {
        Arc::new(Self {
            ctrlc: AtomicBool::new(false),
            ctrld: AtomicBool::new(false),
        })
    }
    /// True when an abort was requested via either key combination.
    pub fn aborted(&self) -> bool {
        self.aborted_ctrlc() || self.aborted_ctrld()
    }
    /// True when CTRL+C was recorded.
    pub fn aborted_ctrlc(&self) -> bool {
        self.ctrlc.load(Ordering::SeqCst)
    }
    /// True when CTRL+D was recorded.
    pub fn aborted_ctrld(&self) -> bool {
        self.ctrld.load(Ordering::SeqCst)
    }
    /// Clear both flags so the handle can be reused.
    pub fn reset(&self) {
        self.ctrlc.store(false, Ordering::SeqCst);
        self.ctrld.store(false, Ordering::SeqCst);
    }
    /// Record a CTRL+C abort request.
    pub fn set_ctrlc(&self) {
        self.ctrlc.store(true, Ordering::SeqCst);
    }
    /// Record a CTRL+D abort request.
    pub fn set_ctrld(&self) {
        self.ctrld.store(true, Ordering::SeqCst);
    }
}
/// Resolve once an abort has been requested, polling the flag every 25ms.
pub async fn wait_abort_signal(abort_signal: &AbortSignal) {
    while !abort_signal.aborted() {
        tokio::time::sleep(Duration::from_millis(25)).await;
    }
}
/// Poll the terminal for up to 25ms for CTRL+C / CTRL+D, recording the key on
/// the signal. Returns `Ok(true)` when an abort key was pressed.
///
/// NOTE(review): crossterm only delivers these key events while the terminal
/// is in raw mode — confirm callers have enabled it.
pub fn poll_abort_signal(abort_signal: &AbortSignal) -> Result<bool> {
    if event::poll(Duration::from_millis(25))? {
        if let Event::Key(key) = event::read()? {
            match key.code {
                KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
                    abort_signal.set_ctrlc();
                    return Ok(true);
                }
                KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
                    abort_signal.set_ctrld();
                    return Ok(true);
                }
                _ => {}
            }
        }
    }
    Ok(false)
}
+49
View File
@@ -0,0 +1,49 @@
use anyhow::Context;
// Desktop platforms: use the system clipboard via `arboard`, falling back to
// the OSC52 terminal escape sequence when no clipboard is available.
#[cfg(not(any(target_os = "android", target_os = "emscripten")))]
mod internal {
    use arboard::Clipboard;
    use base64::{engine::general_purpose::STANDARD, Engine as _};
    use std::sync::{LazyLock, Mutex};
    // One shared clipboard handle for the whole process; `None` when a
    // system clipboard could not be opened (then OSC52 is used instead).
    static CLIPBOARD: LazyLock<Mutex<Option<Clipboard>>> =
        LazyLock::new(|| Mutex::new(Clipboard::new().ok()));
    /// Copy `text` to the system clipboard, or via OSC52 when unavailable.
    pub fn set_text(text: &str) -> anyhow::Result<()> {
        let mut clipboard = CLIPBOARD.lock().unwrap();
        match clipboard.as_mut() {
            Some(clipboard) => {
                clipboard.set_text(text)?;
                // Brief pause on Linux after setting the text — presumably to
                // let the selection propagate before continuing; confirm
                // against arboard's platform notes.
                #[cfg(target_os = "linux")]
                std::thread::sleep(std::time::Duration::from_millis(50));
                Ok(())
            }
            None => set_text_osc52(text),
        }
    }
    /// Attempts to set text to clipboard with OSC52 escape sequence.
    /// Works in many modern terminals, including over SSH.
    fn set_text_osc52(text: &str) -> anyhow::Result<()> {
        // OSC52 payload is base64: ESC ] 52 ; c ; <b64> BEL
        let encoded = STANDARD.encode(text);
        let seq = format!("\x1b]52;c;{encoded}\x07");
        if let Err(e) = std::io::Write::write_all(&mut std::io::stdout(), seq.as_bytes()) {
            return Err(anyhow::anyhow!("Failed to send OSC52 sequence").context(e));
        }
        if let Err(e) = std::io::Write::flush(&mut std::io::stdout()) {
            return Err(anyhow::anyhow!("Failed to flush OSC52 sequence").context(e));
        }
        Ok(())
    }
}
// Targets without clipboard support: always report failure so the public
// wrapper's context produces a clear error.
#[cfg(any(target_os = "android", target_os = "emscripten"))]
mod internal {
    /// Stub: no clipboard exists on these targets.
    pub fn set_text(_text: &str) -> anyhow::Result<()> {
        Err(anyhow::anyhow!("No clipboard available"))
    }
}
/// Copy `text` to the clipboard (platform-appropriate mechanism), with a
/// uniform "Failed to copy" error context.
pub fn set_text(text: &str) -> anyhow::Result<()> {
    internal::set_text(text).context("Failed to copy")
}
+242
View File
@@ -0,0 +1,242 @@
use super::*;
use std::{
collections::HashMap,
env,
ffi::OsStr,
fs::OpenOptions,
io::{self, Write},
path::{Path, PathBuf},
process::Command,
};
use anyhow::{anyhow, bail, Context, Result};
use dirs::home_dir;
use std::sync::LazyLock;
/// The user's shell, detected once per process on first use.
pub static SHELL: LazyLock<Shell> = LazyLock::new(detect_shell);
/// A shell invocation recipe: name, executable, and command-string flag.
pub struct Shell {
    // Normalized short name, e.g. "bash", "zsh", "powershell", "nushell".
    pub name: String,
    // Executable name or path used to spawn the shell.
    pub cmd: String,
    // Flag that makes the shell run a command string (e.g. "-c", "/C").
    pub arg: String,
}
impl Shell {
    /// Build a `Shell` from borrowed parts, taking owned copies of each.
    pub fn new(name: &str, cmd: &str, arg: &str) -> Self {
        Self {
            name: name.into(),
            cmd: cmd.into(),
            arg: arg.into(),
        }
    }
}
/// Detect the user's shell.
///
/// Resolution order: the app-specific env var (via `get_env_name("shell")`),
/// then a PowerShell heuristic on Windows / `$SHELL` elsewhere, finally a
/// platform default (`cmd.exe` on Windows, `/bin/sh` otherwise).
pub fn detect_shell() -> Shell {
    let cmd = env::var(get_env_name("shell")).ok().or_else(|| {
        if cfg!(windows) {
            // A `PSModulePath` rooted in the user profile indicates the
            // process was launched from PowerShell; the `\powershell\7\`
            // segment distinguishes PowerShell 7 from Windows PowerShell.
            if let Ok(ps_module_path) = env::var("PSModulePath") {
                let ps_module_path = ps_module_path.to_lowercase();
                if ps_module_path.starts_with(r"c:\users") {
                    return if ps_module_path.contains(r"\powershell\7\") {
                        Some("pwsh.exe".to_string())
                    } else {
                        Some("powershell.exe".to_string())
                    };
                }
            }
            None
        } else {
            env::var("SHELL").ok()
        }
    });
    // Derive the short name from the executable's file stem; "nu" is
    // normalized to "nushell".
    let name = cmd
        .as_ref()
        .and_then(|v| Path::new(v).file_stem().and_then(|v| v.to_str()))
        .map(|v| {
            if v == "nu" {
                "nushell".into()
            } else {
                v.to_lowercase()
            }
        });
    let (cmd, name) = match (cmd.as_deref(), name.as_deref()) {
        (Some(cmd), Some(name)) => (cmd, name),
        _ => {
            // Fall back to the platform default shell.
            if cfg!(windows) {
                ("cmd.exe", "cmd")
            } else {
                ("/bin/sh", "sh")
            }
        }
    };
    // Flag that makes this shell execute a command string.
    let shell_arg = match name {
        "powershell" => "-Command",
        "cmd" => "/C",
        _ => "-c",
    };
    Shell::new(name, cmd, shell_arg)
}
/// Run `cmd` with `args` (plus optional extra environment variables),
/// inheriting stdio, and return the process exit code.
///
/// Returns 0 when no code is available (e.g. terminated by a signal).
pub fn run_command<T: AsRef<OsStr>>(
    cmd: &str,
    args: &[T],
    envs: Option<HashMap<String, String>>,
) -> Result<i32> {
    let mut command = Command::new(cmd);
    command.args(args.iter()).envs(envs.unwrap_or_default());
    let status = command.status()?;
    Ok(status.code().unwrap_or_default())
}
/// Run `cmd` with `args` (plus optional extra environment variables),
/// capturing its output. Returns `(success, stdout, stderr)`.
///
/// Both streams must be valid UTF-8; otherwise an error is returned.
pub fn run_command_with_output<T: AsRef<OsStr>>(
    cmd: &str,
    args: &[T],
    envs: Option<HashMap<String, String>>,
) -> Result<(bool, String, String)> {
    let output = Command::new(cmd)
        .args(args.iter())
        .envs(envs.unwrap_or_default())
        .output()?;
    let status = output.status;
    let stdout = std::str::from_utf8(&output.stdout).context("Invalid UTF-8 in stdout")?;
    let stderr = std::str::from_utf8(&output.stderr).context("Invalid UTF-8 in stderr")?;
    // Fix: the previous version paired the wrong condition with each log
    // message (it logged "exited with non-zero. stderr" when *stdout* was
    // non-empty, and "executed successfully. stdout" when *stderr* was
    // non-empty). Log each fact under its own condition instead.
    if !status.success() {
        debug!("Command `{cmd}` exited with non-zero: {status}");
    }
    if !stdout.is_empty() {
        debug!("Command `{cmd}` stdout: {stdout}");
    }
    if !stderr.is_empty() {
        debug!("Command `{cmd}` stderr: {stderr}");
    }
    Ok((status.success(), stdout.to_string(), stderr.to_string()))
}
/// Run a configured document-loader command for `path`.
///
/// In `loader_command`, `$1` is replaced with the input path and `$2` (when
/// present) with a temporary output file path. With `$2` the result is read
/// from that file; otherwise the loader's stdout is returned.
pub fn run_loader_command(path: &str, extension: &str, loader_command: &str) -> Result<String> {
    let cmd_args = shell_words::split(loader_command)
        .with_context(|| anyhow!("Invalid document loader '{extension}': `{loader_command}`"))?;
    let mut use_stdout = true;
    let outpath = temp_file("-output-", "").display().to_string();
    let cmd_args: Vec<_> = cmd_args
        .into_iter()
        .map(|mut v| {
            if v.contains("$1") {
                v = v.replace("$1", path);
            }
            if v.contains("$2") {
                // The loader writes its result to a file instead of stdout.
                use_stdout = false;
                v = v.replace("$2", &outpath);
            }
            v
        })
        .collect();
    // Re-joined command line, used only in log and error messages.
    let cmd_eval = shell_words::join(&cmd_args);
    debug!("run `{cmd_eval}`");
    // NOTE(review): an empty `loader_command` yields an empty `cmd_args`,
    // making this `split_at(1)` panic — confirm config validation rules that
    // out upstream.
    let (cmd, args) = cmd_args.split_at(1);
    let cmd = &cmd[0];
    if use_stdout {
        let (success, stdout, stderr) =
            run_command_with_output(cmd, args, None).with_context(|| {
                format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
            })?;
        if !success {
            // Prefer the loader's own stderr as the error message.
            let err = if !stderr.is_empty() {
                stderr
            } else {
                format!("The command `{cmd_eval}` exited with non-zero.")
            };
            bail!("{err}")
        }
        Ok(stdout)
    } else {
        let status = run_command(cmd, args, None).with_context(|| {
            format!("Unable to run `{cmd_eval}`, Perhaps '{cmd}' is not installed?")
        })?;
        if status != 0 {
            bail!("The command `{cmd_eval}` exited with non-zero.")
        }
        let contents = std::fs::read_to_string(&outpath)
            .context("Failed to read file generated by the loader")?;
        Ok(contents)
    }
}
/// Open `path` in the given editor command and block until the editor exits.
pub fn edit_file(editor: &str, path: &Path) -> Result<()> {
    Command::new(editor).arg(path).spawn()?.wait()?;
    Ok(())
}
/// Append `command` to the history file of the given shell, using that
/// shell's native history entry format.
///
/// Silently does nothing when the shell's history file cannot be determined.
pub fn append_to_shell_history(shell: &str, command: &str, exit_code: i32) -> io::Result<()> {
    if let Some(history_file) = get_history_file(shell) {
        // History formats are line-oriented; collapse multi-line commands.
        let command = command.replace('\n', " ");
        let now = now_timestamp();
        let history_txt = if shell == "fish" {
            // Fix: fish_history is YAML-like and requires a two-space indent
            // before `when:`; the previous single space does not match the
            // format fish writes/reads.
            format!("- cmd: {command}\n  when: {now}")
        } else if shell == "zsh" {
            // zsh extended history: `: <start>:<elapsed>;<command>`.
            // NOTE(review): the elapsed-seconds field is filled with the exit
            // code here — confirm this is intentional.
            format!(": {now}:{exit_code};{command}",)
        } else {
            command
        };
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&history_file)?;
        writeln!(file, "{history_txt}")?;
    }
    Ok(())
}
/// Resolve the history file path for a known shell, or `None` for an
/// unrecognized shell or when the relevant base directory is unavailable.
///
/// bash/sh and zsh honor `$HISTFILE` before falling back to their defaults.
fn get_history_file(shell: &str) -> Option<PathBuf> {
    match shell {
        "bash" | "sh" => env::var("HISTFILE")
            .ok()
            .map(PathBuf::from)
            .or(Some(home_dir()?.join(".bash_history"))),
        "zsh" => env::var("HISTFILE")
            .ok()
            .map(PathBuf::from)
            .or(Some(home_dir()?.join(".zsh_history"))),
        "nushell" => Some(dirs::config_dir()?.join("nushell").join("history.txt")),
        "fish" => Some(
            home_dir()?
                .join(".local")
                .join("share")
                .join("fish")
                .join("fish_history"),
        ),
        "powershell" | "pwsh" => {
            // PSReadLine stores history under XDG data on Unix and AppData on
            // Windows.
            #[cfg(not(windows))]
            {
                Some(
                    home_dir()?
                        .join(".local")
                        .join("share")
                        .join("powershell")
                        .join("PSReadLine")
                        .join("ConsoleHost_history.txt"),
                )
            }
            #[cfg(windows)]
            {
                Some(
                    dirs::data_dir()?
                        .join("Microsoft")
                        .join("Windows")
                        .join("PowerShell")
                        .join("PSReadLine")
                        .join("ConsoleHost_history.txt"),
                )
            }
        }
        "ksh" => Some(home_dir()?.join(".ksh_history")),
        "tcsh" => Some(home_dir()?.join(".history")),
        _ => None,
    }
}
+35
View File
@@ -0,0 +1,35 @@
use base64::{engine::general_purpose::STANDARD, Engine};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
/// Lowercase hex-encoded SHA-256 digest of `input`.
pub fn sha256(input: &str) -> String {
    format!("{:x}", Sha256::digest(input))
}
/// HMAC-SHA256 of `msg` keyed with `key`, returned as raw bytes.
pub fn hmac_sha256(key: &[u8], msg: &str) -> Vec<u8> {
    // HMAC accepts arbitrary key lengths, so this cannot fail.
    let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
    mac.update(msg.as_bytes());
    mac.finalize().into_bytes().to_vec()
}
/// Encode `bytes` as a lowercase hex string.
///
/// Uses one preallocated `String` instead of the previous per-byte
/// `format!` + `+` fold, which allocated a fresh temporary for every byte.
pub fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    let mut out = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Writing into a String cannot fail.
        let _ = write!(out, "{b:02x}");
    }
    out
}
/// Percent-encode each segment of a `/`-separated URI path while preserving
/// the slashes themselves.
pub fn encode_uri(uri: &str) -> String {
    uri.split('/')
        .map(|v| urlencoding::encode(v))
        .collect::<Vec<_>>()
        .join("/")
}
/// Base64-encode `input` using the standard alphabet with padding.
pub fn base64_encode<T: AsRef<[u8]>>(input: T) -> String {
    STANDARD.encode(input)
}
/// Decode standard-alphabet (padded) base64 `input` into raw bytes.
pub fn base64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
    STANDARD.decode(input)
}
+18
View File
@@ -0,0 +1,18 @@
use std::{cell::RefCell, rc::Rc};
use html_to_markdown::{markdown, TagHandler};
/// Convert an HTML document to markdown, returning the raw HTML unchanged if
/// conversion fails.
pub fn html_to_md(html: &str) -> String {
    // Handlers cover the common markdown constructs; WebpageChromeRemover
    // presumably strips page chrome/boilerplate — see html_to_markdown docs.
    let mut handlers: Vec<TagHandler> = vec![
        Rc::new(RefCell::new(markdown::ParagraphHandler)),
        Rc::new(RefCell::new(markdown::HeadingHandler)),
        Rc::new(RefCell::new(markdown::ListHandler)),
        Rc::new(RefCell::new(markdown::TableHandler::new())),
        Rc::new(RefCell::new(markdown::StyledTextHandler)),
        Rc::new(RefCell::new(markdown::CodeHandler)),
        Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
    ];
    html_to_markdown::convert_html_to_markdown(html.as_bytes(), &mut handlers)
        .unwrap_or_else(|_| html.to_string())
}
+47
View File
@@ -0,0 +1,47 @@
use anyhow::Result;
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
use crossterm::terminal::{disable_raw_mode, enable_raw_mode};
use std::io::{stdout, Write};
/// Reads a single character from stdin without requiring Enter
/// Returns the character if it's one of the valid options, or the default if Enter is pressed
pub fn read_single_key(valid_chars: &[char], default: char, prompt: &str) -> Result<char> {
print!("{prompt}");
stdout().flush()?;
enable_raw_mode()?;
let result = loop {
if let Ok(Event::Key(KeyEvent {
code, modifiers, ..
})) = event::read()
{
match code {
KeyCode::Char('c') if modifiers.contains(KeyModifiers::CONTROL) => {
break Err(anyhow::anyhow!("Interrupted"));
}
KeyCode::Char(c) => {
if valid_chars.contains(&c) {
break Ok(c);
}
// Invalid character, continue loop
}
KeyCode::Enter => {
break Ok(default);
}
_ => {
// Other keys are ignored, continue loop
}
}
}
};
disable_raw_mode()?;
// Print the chosen character and newline for clean output
if let Ok(chosen) = &result {
println!("{chosen}");
}
result
}
+125
View File
@@ -0,0 +1,125 @@
use super::*;
use anyhow::{anyhow, Context, Result};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Metadata key under which a document's effective file extension is stored.
pub const EXTENSION_METADATA: &str = "__extension__";
/// Ordered key/value metadata attached to a loaded document.
pub type DocumentMetadata = IndexMap<String, String>;
/// A document produced by one of the loaders: its source path, raw text
/// contents, and associated metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadedDocument {
    pub path: String,
    pub contents: String,
    // Optional in serialized form; defaults to empty.
    #[serde(default)]
    pub metadata: DocumentMetadata,
}
impl LoadedDocument {
    /// Construct a document from its parts.
    pub fn new(path: String, contents: String, metadata: DocumentMetadata) -> Self {
        Self {
            path,
            contents,
            metadata,
        }
    }
}
/// Crawl a URL recursively into a list of documents.
///
/// When a custom loader is configured for the `RECURSIVE_URL_LOADER`
/// pseudo-extension it is invoked and must print JSON pages; otherwise the
/// built-in crawler is used. Every page is tagged as markdown.
pub async fn load_recursive_url(
    loaders: &HashMap<String, String>,
    path: &str,
) -> Result<Vec<LoadedDocument>> {
    let extension = RECURSIVE_URL_LOADER;
    let pages: Vec<Page> = match loaders.get(extension) {
        Some(loader_command) => {
            let contents = run_loader_command(path, extension, loader_command)?;
            serde_json::from_str(&contents).context(r#"The crawler response is invalid. It should follow the JSON format: `[{"path":"...", "text":"..."}]`."#)?
        }
        None => {
            let options = CrawlOptions::preset(path);
            crawl_website(path, options).await?
        }
    };
    let output = pages
        .into_iter()
        .map(|v| {
            let Page { path, text } = v;
            // Crawled pages are already converted to markdown.
            let mut metadata: DocumentMetadata = Default::default();
            metadata.insert(EXTENSION_METADATA.into(), "md".into());
            LoadedDocument::new(path, text, metadata)
        })
        .collect();
    Ok(output)
}
/// Load a local file, dispatching on its extension: a configured loader
/// command when one exists, otherwise a plain UTF-8 text read.
pub async fn load_file(loaders: &HashMap<String, String>, path: &str) -> Result<LoadedDocument> {
    let extension = get_patch_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into());
    match loaders.get(&extension) {
        Some(loader_command) => load_with_command(path, &extension, loader_command),
        None => load_plain(path, &extension).await,
    }
}
/// Fetch a URL and wrap the response body as a document, recording the
/// detected extension in the metadata.
pub async fn load_url(loaders: &HashMap<String, String>, path: &str) -> Result<LoadedDocument> {
    let (contents, extension) = fetch_with_loaders(loaders, path, false).await?;
    let mut metadata: DocumentMetadata = Default::default();
    metadata.insert(EXTENSION_METADATA.into(), extension);
    Ok(LoadedDocument::new(path.into(), contents, metadata))
}
/// Read a local file as UTF-8 text, tagging the document with its extension.
async fn load_plain(path: &str, extension: &str) -> Result<LoadedDocument> {
    let contents = tokio::fs::read_to_string(path).await?;
    let metadata: DocumentMetadata = [(EXTENSION_METADATA.to_string(), extension.to_string())]
        .into_iter()
        .collect();
    Ok(LoadedDocument::new(path.into(), contents, metadata))
}
/// Run the configured loader command for `path` and wrap its output as a
/// document tagged with the default extension.
fn load_with_command(path: &str, extension: &str, loader_command: &str) -> Result<LoadedDocument> {
    let contents = run_loader_command(path, extension, loader_command)?;
    let metadata: DocumentMetadata =
        [(EXTENSION_METADATA.to_string(), DEFAULT_EXTENSION.to_string())]
            .into_iter()
            .collect();
    Ok(LoadedDocument::new(path.into(), contents, metadata))
}
/// True when `path` has the form `protocol:rest` and a loader is configured
/// for that protocol.
pub fn is_loader_protocol(loaders: &HashMap<String, String>, path: &str) -> bool {
    path.split_once(':')
        .is_some_and(|(protocol, _)| loaders.contains_key(protocol))
}
/// Load documents through a protocol loader (`protocol:rest-of-path`).
///
/// The loader may print either a JSON list of documents or arbitrary text;
/// text becomes a single document at `path`. Returned document paths are
/// normalized so each one is qualified by the original protocol path.
pub fn load_protocol_path(
    loaders: &HashMap<String, String>,
    path: &str,
) -> Result<Vec<LoadedDocument>> {
    let (protocol, loader_command, new_path) = path
        .split_once(':')
        .and_then(|(protocol, path)| {
            let loader_command = loaders.get(protocol)?;
            Some((protocol, loader_command, path))
        })
        .ok_or_else(|| anyhow!("No document loader for '{}'", path))?;
    let contents = run_loader_command(new_path, protocol, loader_command)?;
    let output = if let Ok(list) = serde_json::from_str::<Vec<LoadedDocument>>(&contents) {
        list.into_iter()
            .map(|mut v| {
                if v.path.starts_with(path) {
                    // Already fully qualified (protocol prefix included):
                    // leave unchanged.
                } else if v.path.starts_with(new_path) {
                    // Re-attach the protocol prefix.
                    v.path = format!("{}:{}", protocol, v.path);
                } else {
                    // Treat as relative to the original protocol path.
                    v.path = format!("{}/{}", path, v.path);
                }
                v
            })
            .collect()
    } else {
        // Non-JSON output: the whole text is one document at `path`.
        vec![LoadedDocument::new(
            path.into(),
            contents,
            Default::default(),
        )]
    };
    Ok(output)
}
+63
View File
@@ -0,0 +1,63 @@
use crate::config::Config;
use colored::Colorize;
use fancy_regex::Regex;
use std::fs::File;
use std::io::{BufRead, BufReader, Seek, SeekFrom};
use std::process;
/// Follow the application log file (like `tail -f`), printing each new line,
/// colorized by log level unless `no_color` is set. Never returns.
pub async fn tail_logs(no_color: bool) {
    // Matches lines like:
    // `2024-01-01 10:00:00.123 <opid> [INFO] logger:42 - message`
    let re = Regex::new(r"^(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3})\s+<(?P<opid>[^\s>]+)>\s+\[(?P<level>[A-Z]+)\]\s+(?P<logger>[^:]+):(?P<line>\d+)\s+-\s+(?P<message>.*)$").unwrap();
    let file_path = Config::log_path();
    // Fix: failing to open the log file now exits cleanly with a message,
    // matching the seek-failure path below, instead of panicking.
    let file = match File::open(&file_path) {
        Ok(file) => file,
        Err(e) => {
            eprintln!("Unable to tail log file: {e:?}");
            process::exit(1);
        }
    };
    let mut reader = BufReader::new(file);
    // Start at the end of the file so only newly appended lines are shown.
    if let Err(e) = reader.seek(SeekFrom::End(0)) {
        eprintln!("Unable to tail log file: {e:?}");
        process::exit(1);
    };
    let mut lines = reader.lines();
    loop {
        if let Some(Ok(line)) = lines.next() {
            if no_color {
                println!("{line}");
            } else {
                let colored_line = colorize_log_line(&line, &re);
                println!("{colored_line}");
            }
        } else {
            // Fix: the previous loop spun at 100% CPU whenever no new data
            // was available; back off briefly before polling again. A
            // blocking sleep is acceptable here since this future is the
            // command's sole task and never completes.
            std::thread::sleep(std::time::Duration::from_millis(100));
        }
    }
}
/// Colorize a single log line: the message is tinted by level, and each
/// structured field gets its own color. Lines that do not match the expected
/// format (or trip a regex engine error) are returned unchanged.
fn colorize_log_line(line: &str, re: &Regex) -> String {
    // fancy_regex::captures returns Result; a regex engine error (e.g.
    // backtracking limit) must not panic the tailer — fall back to plain text.
    match re.captures(line) {
        Ok(Some(caps)) => {
            let level = &caps["level"];
            let message = &caps["message"];
            let colored_message = match level {
                "ERROR" => message.red(),
                "WARN" => message.yellow(),
                "INFO" => message.green(),
                "DEBUG" => message.blue(),
                _ => message.normal(),
            };
            let timestamp = &caps["timestamp"];
            let opid = &caps["opid"];
            let logger = &caps["logger"];
            let line_number = &caps["line"];
            format!(
                "{} <{}> [{}] {}:{} - {}",
                timestamp.white(),
                opid.cyan(),
                level.bold(),
                logger.magenta(),
                line_number.bold(),
                colored_message
            )
        }
        _ => line.to_string(),
    }
}
+252
View File
@@ -0,0 +1,252 @@
mod abort_signal;
mod clipboard;
mod command;
mod crypto;
mod html_to_md;
mod input;
mod loader;
mod logs;
pub mod native;
mod path;
mod render_prompt;
mod request;
mod spinner;
mod variables;
pub use self::abort_signal::*;
pub use self::clipboard::set_text;
pub use self::command::*;
pub use self::crypto::*;
pub use self::html_to_md::*;
pub use self::input::*;
pub use self::loader::*;
pub use self::logs::*;
pub use self::path::*;
pub use self::render_prompt::render_prompt;
pub use self::request::*;
pub use self::spinner::*;
pub use self::variables::*;
use anyhow::{Context, Result};
use fancy_regex::Regex;
use fuzzy_matcher::{skim::SkimMatcherV2, FuzzyMatcher};
use is_terminal::IsTerminal;
use std::borrow::Cow;
use std::sync::LazyLock;
use std::{env, path::PathBuf, process};
use unicode_segmentation::UnicodeSegmentation;
// Matches a fenced code block (```lang ... ```); capture group 1 is the body.
pub static CODE_BLOCK_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?ms)```\w*(.*)```").unwrap());
// Matches a leading `<think>...</think>` block emitted by reasoning models.
pub static THINK_TAG_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?s)^\s*<think>.*?</think>(\s*|$)").unwrap());
// True when stdout is attached to a terminal (evaluated once, lazily).
pub static IS_STDOUT_TERMINAL: LazyLock<bool> = LazyLock::new(|| std::io::stdout().is_terminal());
// Colors are disabled when the NO_COLOR env var is truthy or stdout is not a TTY.
pub static NO_COLOR: LazyLock<bool> = LazyLock::new(|| {
    env::var("NO_COLOR")
        .ok()
        .and_then(|v| parse_bool(&v))
        .unwrap_or_default()
        || !*IS_STDOUT_TERMINAL
});
/// Current local time as an RFC 3339 string with seconds precision.
pub fn now() -> String {
    chrono::Local::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, false)
}
/// Current time as a Unix timestamp in seconds.
pub fn now_timestamp() -> i64 {
    chrono::Local::now().timestamp()
}
/// Build the env-var name `<CRATE>_<KEY>` (upper-cased) for a config key.
pub fn get_env_name(key: &str) -> String {
    format!("{}_{key}", env!("CARGO_CRATE_NAME"),).to_ascii_uppercase()
}
/// Convert a config-style name into env-var form: dashes become underscores
/// and the whole name is upper-cased.
pub fn normalize_env_name(value: &str) -> String {
    value.to_ascii_uppercase().replace('-', "_")
}
/// Parse "1"/"true" as `true` and "0"/"false" as `false`; anything else is `None`.
pub fn parse_bool(value: &str) -> Option<bool> {
    if matches!(value, "1" | "true") {
        Some(true)
    } else if matches!(value, "0" | "false") {
        Some(false)
    } else {
        None
    }
}
/// Roughly estimate the LLM token count of `text`.
///
/// Heuristic: each ASCII word costs ~1.3 tokens; a single non-ASCII character
/// costs 1 token; longer non-ASCII words cost ~0.5 token per character.
pub fn estimate_token_length(text: &str) -> usize {
    let total: f32 = text
        .unicode_words()
        .map(|word| {
            if word.is_ascii() {
                1.3
            } else {
                match word.chars().count() {
                    1 => 1.0,
                    n => (n as f32) * 0.5,
                }
            }
        })
        .sum();
    total.ceil() as usize
}
/// Remove a leading `<think>...</think>` block (reasoning-model output) from
/// `text`. Returns a borrowed `Cow` when nothing needed stripping.
pub fn strip_think_tag(text: &str) -> Cow<'_, str> {
    THINK_TAG_RE.replace_all(text, "")
}
/// Return the trimmed contents of the first fenced code block in `text`, or
/// the whole text unchanged when no fenced block is found.
pub fn extract_code_block(text: &str) -> &str {
    CODE_BLOCK_RE
        .captures(text)
        .ok()
        .and_then(|v| v?.get(1).map(|v| v.as_str().trim()))
        .unwrap_or(text)
}
/// Convert a string to `Some(owned)` when non-empty, otherwise `None`.
pub fn convert_option_string(value: &str) -> Option<String> {
    (!value.is_empty()).then(|| value.to_string())
}
/// Filter `values` by fuzzy-matching `pattern` against `get(value)`,
/// returning the survivors ordered from best to worst match score.
pub fn fuzzy_filter<T, F>(values: Vec<T>, get: F, pattern: &str) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let matcher = SkimMatcherV2::default();
    let mut scored: Vec<(T, i64)> = values
        .into_iter()
        .filter_map(|item| {
            matcher
                .fuzzy_match(get(&item), pattern)
                .map(|score| (item, score))
        })
        .collect();
    // Highest score first.
    scored.sort_unstable_by_key(|(_, score)| std::cmp::Reverse(*score));
    scored.into_iter().map(|(item, _)| item).collect()
}
/// Format an error and its cause chain for display:
/// `Error: <msg>` followed by a `Caused by:` section (indexed when there is
/// more than one cause). The exact indentation mirrors anyhow's `{:?}` style.
pub fn pretty_error(err: &anyhow::Error) -> String {
    let mut output = vec![];
    output.push(format!("Error: {err}"));
    let causes: Vec<_> = err.chain().skip(1).collect();
    let causes_len = causes.len();
    if causes_len > 0 {
        output.push("\nCaused by:".to_string());
        if causes_len == 1 {
            // Single cause: no index column.
            output.push(format!(" {}", indent_text(causes[0], 4).trim()));
        } else {
            // Multiple causes: right-aligned index column, continuation lines indented.
            for (i, cause) in causes.into_iter().enumerate() {
                output.push(format!("{i:5}: {}", indent_text(cause, 7).trim()));
            }
        }
    }
    output.join("\n")
}
/// Prefix every line of `s` with `size` spaces.
pub fn indent_text<T: ToString>(s: T, size: usize) -> String {
    let pad = " ".repeat(size);
    let text = s.to_string();
    let indented: Vec<String> = text.split('\n').map(|line| format!("{pad}{line}")).collect();
    indented.join("\n")
}
/// Paint `input` red for error display (plain text when colors are disabled).
pub fn error_text(input: &str) -> String {
    color_text(input, nu_ansi_term::Color::Red)
}
/// Paint `input` yellow for warning display (plain text when colors are disabled).
pub fn warning_text(input: &str) -> String {
    color_text(input, nu_ansi_term::Color::Yellow)
}
/// Paint `input` with `color`, unless NO_COLOR is set or stdout is not a TTY.
pub fn color_text(input: &str, color: nu_ansi_term::Color) -> String {
    if *NO_COLOR {
        return input.to_string();
    }
    nu_ansi_term::Style::new()
        .fg(color)
        .paint(input)
        .to_string()
}
/// Render `input` in the terminal's dimmed style (plain when colors are disabled).
pub fn dimmed_text(input: &str) -> String {
    if *NO_COLOR {
        return input.to_string();
    }
    nu_ansi_term::Style::new().dimmed().paint(input).to_string()
}
/// Render a multi-line value for the REPL: continuation lines are prefixed
/// with ".. " so they read as part of the same entry.
pub fn multiline_text(input: &str) -> String {
    let mut result = String::new();
    for (index, part) in input.split('\n').enumerate() {
        if index > 0 {
            result.push('\n');
            result.push_str(".. ");
        }
        result.push_str(part);
    }
    result
}
/// Build a unique temp-file path: `<crate>-<pid><prefix><uuid><suffix>` inside
/// the system temp directory. The file is not created, only the path is built.
pub fn temp_file(prefix: &str, suffix: &str) -> PathBuf {
    env::temp_dir().join(format!(
        "{}-{}{prefix}{}{suffix}",
        env!("CARGO_CRATE_NAME").to_lowercase(),
        process::id(),
        uuid::Uuid::new_v4()
    ))
}
/// Return true when `path` looks like an http(s) URL.
pub fn is_url(path: &str) -> bool {
    ["http://", "https://"]
        .iter()
        .any(|scheme| path.starts_with(scheme))
}
/// Configure the client's proxy: the env-derived proxy is always cleared, and
/// an explicit proxy URL is applied unless it is empty or "-" (meaning "no proxy").
pub fn set_proxy(
    mut builder: reqwest::ClientBuilder,
    proxy: &str,
) -> Result<reqwest::ClientBuilder> {
    builder = builder.no_proxy();
    if !proxy.is_empty() && proxy != "-" {
        builder = builder
            .proxy(reqwest::Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?);
    };
    Ok(builder)
}
/// Decode a value from its bincode (legacy config) binary representation.
pub fn decode_bin<T: serde::de::DeserializeOwned>(data: &[u8]) -> Result<T> {
    let (v, _) = bincode::serde::decode_from_slice(data, bincode::config::legacy())?;
    Ok(v)
}
#[cfg(test)]
mod tests {
    use super::*;
    // `safe_join_path` must reject absolute sub-paths and parent-dir escapes;
    // path separators differ per OS, hence the platform-specific variants.
    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_safe_join_path() {
        assert_eq!(
            safe_join_path("/home/user/dir1", "files/file1"),
            Some(PathBuf::from("/home/user/dir1/files/file1"))
        );
        assert!(safe_join_path("/home/user/dir1", "/files/file1").is_none());
        assert!(safe_join_path("/home/user/dir1", "../file1").is_none());
    }
    #[test]
    #[cfg(target_os = "windows")]
    fn test_safe_join_path() {
        assert_eq!(
            safe_join_path("C:\\Users\\user\\dir1", "files/file1"),
            Some(PathBuf::from("C:\\Users\\user\\dir1\\files\\file1"))
        );
        assert!(safe_join_path("C:\\Users\\user\\dir1", "/files/file1").is_none());
        assert!(safe_join_path("C:\\Users\\user\\dir1", "../file1").is_none());
    }
}
+46
View File
@@ -0,0 +1,46 @@
// Windows-only helpers for locating a usable bash interpreter (Git Bash).
#[cfg(windows)]
pub mod runtime {
    use std::path::Path;
    /// Find a bash executable, trying in order: the default Git-for-Windows
    /// install path, then `<git-root>/bin/bash.exe` and `<git-bin>/bash.exe`
    /// derived from the `git` binary on PATH.
    pub fn bash_path() -> Option<String> {
        let bash_path = "C:\\Program Files\\Git\\bin\\bash.exe";
        if exist_path(bash_path) {
            return Some(bash_path.into());
        }
        // `git` usually lives at <root>\cmd\git.exe or <root>\bin\git.exe.
        let git_path = which("git")?;
        let git_parent_path = parent_path(&git_path)?;
        let bash_path = join_path(&parent_path(&git_parent_path)?, &["bin", "bash.exe"]);
        if exist_path(&bash_path) {
            return Some(bash_path);
        }
        let bash_path = join_path(&git_parent_path, &["bash.exe"]);
        if exist_path(&bash_path) {
            return Some(bash_path);
        }
        None
    }
    // True when the path exists on disk.
    fn exist_path(path: &str) -> bool {
        Path::new(path).exists()
    }
    /// Resolve an executable name via PATH, returning its full path as a String.
    pub fn which(name: &str) -> Option<String> {
        which::which(name)
            .ok()
            .map(|path| path.to_string_lossy().into())
    }
    // Parent directory of `path` as a String, if any.
    fn parent_path(path: &str) -> Option<String> {
        Path::new(path)
            .parent()
            .map(|path| path.to_string_lossy().into())
    }
    // Join `parts` onto `path` using native separators.
    fn join_path(path: &str, parts: &[&str]) -> String {
        let mut path = Path::new(path).to_path_buf();
        for part in parts {
            path = path.join(part);
        }
        path.to_string_lossy().into()
    }
}
+356
View File
@@ -0,0 +1,356 @@
use std::fs;
use std::path::{Component, Path, PathBuf};
use anyhow::{bail, Result};
use fancy_regex::Regex;
use indexmap::IndexSet;
use path_absolutize::Absolutize;
type ParseGlobResult = (String, Option<Vec<String>>, bool, Option<usize>);
/// Join `sub_path` onto `base_path`, refusing anything that could escape the
/// base directory: absolute sub-paths and any `..` component yield `None`.
pub fn safe_join_path<T1: AsRef<Path>, T2: AsRef<Path>>(
    base_path: T1,
    sub_path: T2,
) -> Option<PathBuf> {
    let base = base_path.as_ref();
    let sub = sub_path.as_ref();
    if sub.is_absolute() {
        return None;
    }
    let mut joined = base.to_path_buf();
    for component in sub.components() {
        // Reject parent-dir traversal outright.
        if component == Component::ParentDir {
            return None;
        }
        joined.push(component);
    }
    // Defense in depth: the result must still live under the base directory.
    joined.starts_with(base).then_some(joined)
}
/// Expand a list of path/glob strings into the set of matching files,
/// preserving first-seen order. When `bail_non_exist` is true, a base path
/// that does not exist is an error; otherwise it is silently skipped.
pub async fn expand_glob_paths<T: AsRef<str>>(
    paths: &[T],
    bail_non_exist: bool,
) -> Result<IndexSet<String>> {
    let mut new_paths = IndexSet::new();
    for path in paths {
        // Split each entry into base dir, suffix filters, and recursion limits.
        let (path_str, suffixes, current_only, depth) = parse_glob(path.as_ref())?;
        list_files(
            &mut new_paths,
            Path::new(&path_str),
            suffixes.as_ref(),
            current_only,
            bail_non_exist,
            depth,
        )
        .await?;
    }
    Ok(new_paths)
}
/// Delete every entry inside `dir` (files and subdirectories alike), keeping
/// the directory itself.
pub fn clear_dir(dir: &Path) -> Result<()> {
    for entry in fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_dir() {
            fs::remove_dir_all(&path)?;
        } else {
            // Covers regular files, symlinks and other non-directory entries.
            fs::remove_file(&path)?;
        }
    }
    Ok(())
}
/// List the names (sorted, with `ext` stripped) of entries in `dir` whose file
/// name ends with `ext`. An unreadable directory yields an empty list.
pub fn list_file_names<T: AsRef<Path>>(dir: T, ext: &str) -> Vec<String> {
    let Ok(read_dir) = fs::read_dir(dir.as_ref()) else {
        return vec![];
    };
    let mut names: Vec<String> = read_dir
        .flatten()
        .filter_map(|entry| {
            entry
                .file_name()
                .to_string_lossy()
                .strip_suffix(ext)
                .map(|name| name.to_string())
        })
        .collect();
    names.sort_unstable();
    names
}
/// Return the lower-cased file extension of `path`, if it has one.
pub fn get_patch_extension(path: &str) -> Option<String> {
    let extension = Path::new(path).extension()?;
    Some(extension.to_string_lossy().to_lowercase())
}
/// Resolve `path` to an absolute path string (without touching the filesystem).
pub fn to_absolute_path(path: &str) -> Result<String> {
    Ok(Path::new(&path).absolutize()?.display().to_string())
}
/// Expand a leading `~/` or `~\` to the user's home directory; other paths
/// (and systems with no detectable home dir) are returned unchanged.
pub fn resolve_home_dir(path: &str) -> String {
    let mut path = path.to_string();
    if path.starts_with("~/") || path.starts_with("~\\") {
        if let Some(home_dir) = dirs::home_dir() {
            // Replace only the `~`, keeping the separator that follows it.
            path.replace_range(..1, &home_dir.display().to_string());
        }
    }
    path
}
fn parse_glob(path_str: &str) -> Result<ParseGlobResult> {
let globbed_single_subdir_regex = Regex::new(r"\*/[^/]+\.[^/]+$").expect("invalid regex");
let globbed_recursive_subdir_regex = Regex::new(r"\*\*/[^/]+\.[^/]+$").expect("invalid regex");
let glob_result =
if let Some(start) = path_str.find("/**/*.").or_else(|| path_str.find(r"\**\*.")) {
Some((start, 6, false, None))
} else if let Some(start) = path_str.find("**/*.").or_else(|| path_str.find(r"**\*.")) {
if start == 0 {
Some((start, 5, false, None))
} else {
None
}
} else if let Some(m) = globbed_recursive_subdir_regex.find(path_str)? {
Some((m.start(), 3, false, None))
} else if let Some(m) = globbed_single_subdir_regex.find(path_str)? {
Some((m.start(), 2, false, Some(1usize)))
} else if let Some(start) = path_str.find("/*.").or_else(|| path_str.find(r"\*.")) {
Some((start, 3, true, None))
} else if let Some(start) = path_str.find("*.") {
if start == 0 {
Some((start, 2, true, None))
} else {
None
}
} else {
None
};
if let Some((start, offset, current_only, depth)) = glob_result {
let mut base_path = path_str[..start].to_string();
if base_path.is_empty() {
base_path = if path_str
.chars()
.next()
.map(|v| v == '/')
.unwrap_or_default()
{
"/"
} else {
"."
}
.into();
}
let extensions = if let Some(curly_brace_end) = path_str[start..].find('}') {
let end = start + curly_brace_end;
let extensions_str = &path_str[start + offset..end + 1];
if extensions_str.starts_with('{') && extensions_str.ends_with('}') {
extensions_str[1..extensions_str.len() - 1]
.split(',')
.map(|s| s.to_string())
.collect::<Vec<String>>()
} else {
bail!("Invalid path '{path_str}'");
}
} else {
let extensions_str = &path_str[start + offset..];
vec![extensions_str.to_string()]
};
let extensions = if extensions.is_empty() {
None
} else {
Some(extensions)
};
Ok((base_path, extensions, current_only, depth))
} else if path_str.ends_with("/**") || path_str.ends_with(r"\**") {
Ok((
path_str[0..path_str.len() - 3].to_string(),
None,
false,
None,
))
} else {
Ok((path_str.to_string(), None, false, None))
}
}
/// Recursively collect files under `entry_path` into `files`, honoring the
/// suffix filter, the `current_only` flag, and an optional recursion `depth`
/// (decremented per directory level; `None` means unlimited).
#[async_recursion::async_recursion]
async fn list_files(
    files: &mut IndexSet<String>,
    entry_path: &Path,
    suffixes: Option<&Vec<String>>,
    current_only: bool,
    bail_non_exist: bool,
    depth: Option<usize>,
) -> Result<()> {
    if !entry_path.exists() {
        if bail_non_exist {
            bail!("Not found '{}'", entry_path.display());
        } else {
            // Missing paths are skipped silently in best-effort mode.
            return Ok(());
        }
    }
    if entry_path.is_dir() {
        let mut reader = tokio::fs::read_dir(entry_path).await?;
        while let Some(entry) = reader.next_entry().await? {
            let path = entry.path();
            if path.is_dir() {
                if !current_only {
                    if let Some(remaining_depth) = depth {
                        // Bounded recursion: stop once the depth budget is spent.
                        if remaining_depth > 0 {
                            list_files(
                                files,
                                &path,
                                suffixes,
                                current_only,
                                bail_non_exist,
                                Some(remaining_depth - 1),
                            )
                            .await?;
                        }
                    } else {
                        // Unbounded recursion.
                        list_files(files, &path, suffixes, current_only, bail_non_exist, None)
                            .await?;
                    }
                }
            } else {
                add_file(files, suffixes, &path);
            }
        }
    } else {
        // A single file path: add it directly (still subject to the filter).
        add_file(files, suffixes, entry_path);
    }
    Ok(())
}
/// Add `path` to the result set when it passes the suffix filter.
/// `IndexSet::insert` is already a no-op for duplicates, so no separate
/// `contains` check is needed.
fn add_file(files: &mut IndexSet<String>, suffixes: Option<&Vec<String>>, path: &Path) {
    if is_valid_extension(suffixes, path) {
        files.insert(path.display().to_string());
    }
}
/// Return true when `path` passes the suffix filter.
///
/// `suffixes` is either `None`/empty (everything passes), a list of full file
/// names (entries containing a dot, e.g. "test.md", produced by globs like
/// `dir/**/test.md`), or a list of bare extensions (e.g. "md" from `*.md`).
///
/// NOTE: the previous implementation tested the joined suffixes against the
/// regex `^.+\.*`, which matches ANY non-empty string (`\.*` is zero-or-more
/// literal dots), so the extension branch was unreachable and bare-extension
/// globs never matched any file. The dot check below restores the intent.
fn is_valid_extension(suffixes: Option<&Vec<String>>, path: &Path) -> bool {
    let Some(suffixes) = suffixes else {
        return true;
    };
    if suffixes.is_empty() {
        return true;
    }
    if suffixes.iter().any(|s| s.contains('.')) {
        // Suffixes are full file names: match the file name exactly.
        match path.file_name().and_then(|v| v.to_str()) {
            Some(file_name) => suffixes.iter().any(|s| s == file_name),
            None => false,
        }
    } else {
        // Suffixes are bare extensions: match the file's extension.
        match path.extension().map(|v| v.to_string_lossy().to_string()) {
            Some(extension) => suffixes.contains(&extension),
            None => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Pins the tuple returned by `parse_glob` for every supported glob form:
    // (base_path, suffixes, current_only, depth).
    #[test]
    fn test_parse_glob() {
        assert_eq!(
            parse_glob("dir").unwrap(),
            ("dir".into(), None, false, None)
        );
        assert_eq!(
            parse_glob("dir/**").unwrap(),
            ("dir".into(), None, false, None)
        );
        assert_eq!(
            parse_glob("dir/file.md").unwrap(),
            ("dir/file.md".into(), None, false, None)
        );
        assert_eq!(
            parse_glob("**/*.md").unwrap(),
            (".".into(), Some(vec!["md".into()]), false, None)
        );
        assert_eq!(
            parse_glob("/**/*.md").unwrap(),
            ("/".into(), Some(vec!["md".into()]), false, None)
        );
        assert_eq!(
            parse_glob("dir/**/*.md").unwrap(),
            ("dir".into(), Some(vec!["md".into()]), false, None)
        );
        assert_eq!(
            parse_glob("dir/**/test.md").unwrap(),
            ("dir/".into(), Some(vec!["test.md".into()]), false, None)
        );
        // Single "*" segment limits recursion to one directory level.
        assert_eq!(
            parse_glob("dir/*/test.md").unwrap(),
            (
                "dir/".into(),
                Some(vec!["test.md".into()]),
                false,
                Some(1usize)
            )
        );
        assert_eq!(
            parse_glob("dir/**/*.{md,txt}").unwrap(),
            (
                "dir".into(),
                Some(vec!["md".into(), "txt".into()]),
                false,
                None
            )
        );
        // Windows-style separators are accepted too.
        assert_eq!(
            parse_glob("C:\\dir\\**\\*.{md,txt}").unwrap(),
            (
                "C:\\dir".into(),
                Some(vec!["md".into(), "txt".into()]),
                false,
                None
            )
        );
        assert_eq!(
            parse_glob("*.md").unwrap(),
            (".".into(), Some(vec!["md".into()]), true, None)
        );
        assert_eq!(
            parse_glob("/*.md").unwrap(),
            ("/".into(), Some(vec!["md".into()]), true, None)
        );
        assert_eq!(
            parse_glob("dir/*.md").unwrap(),
            ("dir".into(), Some(vec!["md".into()]), true, None)
        );
        assert_eq!(
            parse_glob("dir/*.{md,txt}").unwrap(),
            (
                "dir".into(),
                Some(vec!["md".into(), "txt".into()]),
                true,
                None
            )
        );
        assert_eq!(
            parse_glob("C:\\dir\\*.{md,txt}").unwrap(),
            (
                "C:\\dir".into(),
                Some(vec!["md".into(), "txt".into()]),
                true,
                None
            )
        );
    }
}
+155
View File
@@ -0,0 +1,155 @@
use std::collections::HashMap;
/// Render REPL prompt
///
/// The template comprises plain text and `{...}`.
///
/// The syntax of `{...}`:
/// - `{var}` - When `var` has a value, replace `var` with the value and eval `template`
/// - `{?var <template>}` - Eval `template` when `var` is evaluated as true
/// - `{!var <template>}` - Eval `template` when `var` is evaluated as false
pub fn render_prompt(template: &str, variables: &HashMap<&str, String>) -> String {
    // Parse once, then evaluate the expression tree against the variables.
    eval_exprs(&parse_template(template), variables)
}
/// Tokenize a prompt template into a flat list of expressions.
///
/// Plain text accumulates into `Expr::Text`; a balanced `{...}` span (braces
/// may nest for conditional blocks) is handed to `parse_block`. `balances`
/// tracks open braces so inner `{`/`}` stay part of the block body.
fn parse_template(template: &str) -> Vec<Expr> {
    let chars: Vec<char> = template.chars().collect();
    let mut exprs = vec![];
    let mut current = vec![];
    let mut balances = vec![];
    for ch in chars.iter().cloned() {
        if !balances.is_empty() {
            // Inside a `{...}` span: collect until the matching closing brace.
            if ch == '}' {
                balances.pop();
                if balances.is_empty() {
                    // Outermost brace closed: the collected chars form one block.
                    if !current.is_empty() {
                        let block = parse_block(&mut current);
                        exprs.push(block)
                    }
                } else {
                    current.push(ch);
                }
            } else if ch == '{' {
                balances.push(ch);
                current.push(ch);
            } else {
                current.push(ch);
            }
        } else if ch == '{' {
            // Entering a block: flush any accumulated plain text first.
            balances.push(ch);
            add_text(&mut exprs, &mut current);
        } else {
            current.push(ch)
        }
    }
    // Flush trailing plain text.
    add_text(&mut exprs, &mut current);
    exprs
}
/// Interpret the contents of one `{...}` span, draining `current`.
///
/// `{?var tmpl}` / `{!var tmpl}` become conditional blocks whose body is
/// parsed recursively; `{var}` is a variable reference; anything else is
/// restored verbatim as literal text (braces included).
fn parse_block(current: &mut Vec<char>) -> Expr {
    let value: String = current.drain(..).collect();
    let Some((name, tail)) = value.split_once(' ') else {
        return Expr::Variable(value);
    };
    if let Some(cond_name) = name.strip_prefix('?') {
        Expr::Block(BlockType::Yes, cond_name.to_string(), parse_template(tail))
    } else if let Some(cond_name) = name.strip_prefix('!') {
        Expr::Block(BlockType::No, cond_name.to_string(), parse_template(tail))
    } else {
        Expr::Text(format!("{{{value}}}"))
    }
}
/// Evaluate a parsed template against `variables`, concatenating the output.
/// Missing variables render as the empty string (which is falsy for blocks).
fn eval_exprs(exprs: &[Expr], variables: &HashMap<&str, String>) -> String {
    let mut rendered = String::new();
    for expr in exprs {
        match expr {
            Expr::Text(text) => rendered.push_str(text),
            Expr::Variable(name) => {
                if let Some(value) = variables.get(name.as_str()) {
                    rendered.push_str(value);
                }
            }
            Expr::Block(typ, name, inner) => {
                let value = variables
                    .get(name.as_str())
                    .map(|v| v.as_str())
                    .unwrap_or_default();
                // Yes-blocks render when truthy, No-blocks when falsy.
                let should_render = match typ {
                    BlockType::Yes => truly(value),
                    BlockType::No => !truly(value),
                };
                if should_render {
                    rendered.push_str(&eval_exprs(inner, variables));
                }
            }
        }
    }
    rendered
}
/// Flush accumulated plain-text characters into a `Expr::Text`, draining
/// `current`. Does nothing when there is no pending text.
fn add_text(exprs: &mut Vec<Expr>, current: &mut Vec<char>) {
    if !current.is_empty() {
        let text: String = current.drain(..).collect();
        exprs.push(Expr::Text(text));
    }
}
/// Template truthiness: empty, "0" and "false" are falsy; everything else is truthy.
fn truly(value: &str) -> bool {
    !matches!(value, "" | "0" | "false")
}
// One node of a parsed prompt template.
#[derive(Debug)]
enum Expr {
    // Literal text emitted as-is.
    Text(String),
    // `{var}`: substituted with the variable's value (empty when missing).
    Variable(String),
    // `{?var ...}` / `{!var ...}`: body rendered conditionally on the variable.
    Block(BlockType, String, Vec<Expr>),
}
// Condition polarity for `Expr::Block`.
#[derive(Debug)]
enum BlockType {
    // Render when the variable is truthy (`{?var ...}`).
    Yes,
    // Render when the variable is falsy (`{!var ...}`).
    No,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Render `$template` with the given (key, value) pairs and compare.
    macro_rules! assert_render {
        ($template:expr, [$(($key:literal, $value:literal),)*], $expect:literal) => {
            let data = HashMap::from([
                $(($key, $value.into()),)*
            ]);
            assert_eq!(render_prompt($template, &data), $expect);
        };
    }
    // Exercises variables plus `?`/`!` conditional blocks, including nesting.
    #[test]
    fn test_render() {
        let prompt = "{?session {session}{?role /}}{role}{?session )}{!session >}";
        assert_render!(prompt, [], ">");
        assert_render!(prompt, [("role", "coder"),], "coder>");
        assert_render!(prompt, [("session", "temp"),], "temp)");
        assert_render!(
            prompt,
            [("session", "temp"), ("role", "coder"),],
            "temp/coder)"
        );
    }
}
+464
View File
@@ -0,0 +1,464 @@
use super::*;
use anyhow::{anyhow, bail, Context, Result};
use fancy_regex::Regex;
use futures_util::{stream, StreamExt};
use http::header::CONTENT_TYPE;
use reqwest::Url;
use scraper::{Html, Selector};
use serde::Deserialize;
use serde_json::Value;
use std::sync::LazyLock;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Duration,
};
use tokio::io::AsyncWriteExt;
use tokio::sync::Semaphore;
// Loader keys and sentinel extensions used by the document fetch pipeline.
pub const URL_LOADER: &str = "url";
pub const RECURSIVE_URL_LOADER: &str = "recursive_url";
pub const MEDIA_URL_EXTENSION: &str = "media_url";
pub const DEFAULT_EXTENSION: &str = "txt";
// Maximum number of pages fetched concurrently while crawling.
const MAX_CRAWLS: usize = 5;
// When false, individual page failures are logged and crawling continues.
const BREAK_ON_ERROR: bool = false;
const USER_AGENT: &str = "curl/8.6.0";
// Shared HTTP client; building it can fail, so the error is kept and
// surfaced by each caller.
static CLIENT: LazyLock<Result<reqwest::Client>> = LazyLock::new(|| {
    let builder = reqwest::ClientBuilder::new().timeout(Duration::from_secs(16));
    let client = builder.build()?;
    Ok(client)
});
// Per-site crawl presets; the first URL pattern that matches wins.
static PRESET: LazyLock<Vec<(Regex, CrawlOptions)>> = LazyLock::new(|| {
    vec![
        (
            Regex::new(r"github.com/([^/]+)/([^/]+)/tree/([^/]+)").unwrap(),
            CrawlOptions {
                exclude: vec!["changelog".into(), "changes".into(), "license".into()],
                ..Default::default()
            },
        ),
        (
            Regex::new(r"github.com/([^/]+)/([^/]+)/wiki").unwrap(),
            CrawlOptions {
                exclude: vec!["_history".into()],
                extract: Some("#wiki-body".into()),
                ..Default::default()
            },
        ),
    ]
});
// Matches a trailing ".ext" at the end of a name.
static EXTENSION_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\.[^.]+$").unwrap());
// Matches a GitHub "tree" URL (owner/repo/tree/branch[...]).
static GITHUB_REPO_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)").unwrap());
/// Fetch `url` with the shared HTTP client and return the response body as text.
pub async fn fetch(url: &str) -> Result<String> {
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let body = client.get(url).send().await?.text().await?;
    Ok(body)
}
/// Fetch `path` and return `(contents, extension)`.
///
/// Resolution order: a configured "url" loader wins outright; otherwise the
/// response's Content-Type picks an extension, media responses become
/// base64 data URLs (only when `allow_media`), extensions with a configured
/// loader are downloaded to a temp file and run through that loader, HTML is
/// converted to markdown, and everything else is returned as text.
pub async fn fetch_with_loaders(
    loaders: &HashMap<String, String>,
    path: &str,
    allow_media: bool,
) -> Result<(String, String)> {
    // A user-configured URL loader bypasses the built-in fetching entirely.
    if let Some(loader_command) = loaders.get(URL_LOADER) {
        let contents = run_loader_command(path, URL_LOADER, loader_command)?;
        return Ok((contents, DEFAULT_EXTENSION.into()));
    }
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let mut res = client.get(path).send().await?;
    if !res.status().is_success() {
        bail!("Invalid status: {}", res.status());
    }
    // MIME type without charset parameters; fall back to a pseudo type built
    // from the URL's file extension when the header is missing.
    let content_type = res
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|v| match v.split_once(';') {
            Some((mime, _)) => mime.trim(),
            None => v,
        })
        .map(|v| v.to_string())
        .unwrap_or_else(|| {
            format!(
                "_/{}",
                get_patch_extension(path).unwrap_or_else(|| DEFAULT_EXTENSION.into())
            )
        });
    let mut is_media = false;
    // Map known document MIME types to extensions; image/video/audio types
    // are flagged as media, anything else uses the MIME subtype.
    let extension = match content_type.as_str() {
        "application/pdf" => "pdf".into(),
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx".into(),
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx".into(),
        "application/vnd.openxmlformats-officedocument.presentationml.presentation" => {
            "pptx".into()
        }
        "application/vnd.oasis.opendocument.text" => "odt".into(),
        "application/vnd.oasis.opendocument.spreadsheet" => "ods".into(),
        "application/vnd.oasis.opendocument.presentation" => "odp".into(),
        "application/rtf" => "rtf".into(),
        "text/javascript" => "js".into(),
        "text/html" => "html".into(),
        _ => content_type
            .rsplit_once('/')
            .map(|(first, last)| {
                if ["image", "video", "audio"].contains(&first) {
                    is_media = true;
                    MEDIA_URL_EXTENSION.into()
                } else {
                    last.to_lowercase()
                }
            })
            .unwrap_or_else(|| DEFAULT_EXTENSION.into()),
    };
    let result = if is_media {
        if !allow_media {
            bail!("Unexpected media type")
        }
        // Media is embedded inline as a base64 data URL.
        let image_bytes = res.bytes().await?;
        let image_base64 = base64_encode(&image_bytes);
        let contents = format!("data:{content_type};base64,{image_base64}");
        (contents, extension)
    } else {
        match loaders.get(&extension) {
            Some(loader_command) => {
                // Stream the body to a temp file, then run the loader on it.
                let save_path = temp_file("-download-", &format!(".{extension}"))
                    .display()
                    .to_string();
                let mut save_file = tokio::fs::File::create(&save_path).await?;
                let mut size = 0;
                while let Some(chunk) = res.chunk().await? {
                    size += chunk.len();
                    save_file.write_all(&chunk).await?;
                }
                let contents = if size == 0 {
                    println!("{}", warning_text(&format!("No content at '{path}'")));
                    String::new()
                } else {
                    run_loader_command(&save_path, &extension, loader_command)?
                };
                (contents, DEFAULT_EXTENSION.into())
            }
            None => {
                let contents = res.text().await?;
                if extension == "html" {
                    // HTML is converted to markdown for LLM consumption.
                    (html_to_md(&contents), "md".into())
                } else {
                    (contents, extension)
                }
            }
        }
    };
    Ok(result)
}
/// Query an OpenAI-compatible `/models` endpoint and return the sorted list
/// of model ids. Fails when the response contains no usable model entries.
pub async fn fetch_models(api_base: &str, api_key: Option<&str>) -> Result<Vec<String>> {
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let url = format!("{}/models", api_base.trim_end_matches('/'));
    let mut request = client.get(url);
    if let Some(key) = api_key {
        request = request.bearer_auth(key);
    }
    let res_body: Value = request.send().await?.json().await?;
    // Collect every string "id" under the "data" array.
    let mut models: Vec<String> = match res_body.get("data").and_then(|v| v.as_array()) {
        Some(items) => items
            .iter()
            .filter_map(|item| Some(item.get("id")?.as_str()?.to_string()))
            .collect(),
        None => vec![],
    };
    if models.is_empty() {
        bail!("No valid models")
    }
    models.sort_unstable();
    Ok(models)
}
// Options controlling a website crawl.
#[derive(Debug, Clone, Default)]
pub struct CrawlOptions {
    // Optional CSS selector; when set, only matching elements are extracted.
    extract: Option<String>,
    // Link names (with or without extension) to skip while crawling.
    exclude: Vec<String>,
    // Suppress progress/error printing.
    no_log: bool,
}
impl CrawlOptions {
    /// Return the preset crawl options for known site layouts (e.g. GitHub
    /// trees and wikis), falling back to the defaults when no URL pattern matches.
    pub fn preset(start_url: &str) -> CrawlOptions {
        PRESET
            .iter()
            .find(|(re, _)| matches!(re.is_match(start_url), Ok(true)))
            .map(|(_, options)| options.clone())
            .unwrap_or_default()
    }
}
/// Crawl a website breadth-first starting at `start_url`, returning the
/// fetched pages. GitHub "tree" URLs are expanded via the GitHub API into the
/// full list of markdown files instead of being crawled page-by-page.
/// Concurrency is capped at `MAX_CRAWLS`; page failures are logged (or abort
/// the crawl when `BREAK_ON_ERROR` is set).
pub async fn crawl_website(start_url: &str, options: CrawlOptions) -> Result<Vec<Page>> {
    let start_url = Url::parse(start_url)?;
    let mut paths = vec![start_url.path().to_string()];
    let normalized_start_url = normalize_start_url(&start_url);
    if !options.no_log {
        println!(
            "Start crawling url={start_url} exclude={} extract={}",
            options.exclude.join(","),
            options.extract.as_deref().unwrap_or_default()
        );
    }
    // GitHub repo trees: resolve all markdown file URLs up front via the API.
    if let Ok(true) = GITHUB_REPO_RE.is_match(start_url.as_str()) {
        paths = crawl_gh_tree(&start_url, &options.exclude)
            .await
            .with_context(|| "Failed to crawl github repo".to_string())?;
    }
    let semaphore = Arc::new(Semaphore::new(MAX_CRAWLS));
    let mut result_pages = Vec::new();
    let mut index = 0;
    // Process `paths` as a growing work queue, one batch at a time.
    while index < paths.len() {
        let batch = paths[index..std::cmp::min(index + MAX_CRAWLS, paths.len())].to_vec();
        let tasks: Vec<_> = batch
            .iter()
            .map(|path| {
                let options = options.clone();
                let permit = semaphore.clone().acquire_owned(); // acquire a permit for concurrency control
                let normalized_start_url = normalized_start_url.clone();
                let path = path.clone();
                async move {
                    let _permit = permit.await?;
                    let url = normalized_start_url
                        .join(&path)
                        .map_err(|_| anyhow!("Invalid crawl page at {}", path))?;
                    let mut page = crawl_page(&normalized_start_url, &path, options)
                        .await
                        .with_context(|| format!("Failed to crawl {}", url.as_str()))?;
                    // Replace the relative path with the absolute URL.
                    page.0 = url.as_str().to_string();
                    Ok(page)
                }
            })
            .collect();
        let results = stream::iter(tasks)
            .buffer_unordered(MAX_CRAWLS)
            .collect::<Vec<_>>()
            .await;
        let mut new_paths = Vec::new();
        for res in results {
            match res {
                Ok((path, text, links)) => {
                    if !options.no_log {
                        println!("Crawled {path}");
                    }
                    if !text.is_empty() {
                        result_pages.push(Page { path, text });
                    }
                    // Enqueue newly discovered links that were not seen before.
                    for link in links {
                        if !paths.iter().any(|p| match_link(p, &link)) {
                            new_paths.push(link);
                        }
                    }
                }
                Err(err) => {
                    if BREAK_ON_ERROR {
                        return Err(err);
                    } else if !options.no_log {
                        println!("{}", error_text(&pretty_error(&err)));
                    }
                }
            }
        }
        paths.extend(new_paths);
        index += batch.len();
    }
    Ok(result_pages)
}
// One crawled page: its absolute URL and extracted (markdown) text.
#[derive(Debug, Deserialize)]
pub struct Page {
    pub path: String,
    pub text: String,
}
/// Expand a GitHub "tree" URL (`/owner/repo/tree/branch[/dir...]`) into
/// raw.githubusercontent.com URLs of every markdown file under that subtree,
/// using the GitHub REST API (ref lookup + recursive tree listing).
async fn crawl_gh_tree(start_url: &Url, exclude: &[String]) -> Result<Vec<String>> {
    // path_segs for "/owner/repo/tree/branch/dir" is
    // ["", owner, repo, "tree", branch, dir...]; indices 1..=4 are accessed
    // below, so at least 5 segments are required (len < 5 would panic on [4]).
    let path_segs: Vec<&str> = start_url.path().split('/').collect();
    if path_segs.len() < 5 {
        bail!("Invalid gh tree {}", start_url.as_str());
    }
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let owner = path_segs[1];
    let repo = path_segs[2];
    let branch = path_segs[4];
    let root_path = path_segs[5..].join("/");
    // Resolve the branch name to a commit SHA.
    let url = format!("https://api.github.com/repos/{owner}/{repo}/git/ref/heads/{branch}");
    let res_body: Value = client
        .get(&url)
        .header("User-Agent", USER_AGENT)
        .header("Accept", "application/vnd.github+json")
        .header("X-GitHub-Api-Version", "2022-11-28")
        .send()
        .await?
        .json()
        .await?;
    let sha = res_body["object"]["sha"]
        .as_str()
        .ok_or_else(|| anyhow!("Not found branch or tag"))?;
    // List the whole tree recursively at that commit.
    let url = format!("https://api.github.com/repos/{owner}/{repo}/git/trees/{sha}?recursive=true");
    let res_body: Value = client
        .get(&url)
        .header("User-Agent", USER_AGENT)
        .header("Accept", "application/vnd.github+json")
        .header("X-GitHub-Api-Version", "2022-11-28")
        .send()
        .await?
        .json()
        .await?;
    let tree = res_body["tree"]
        .as_array()
        .ok_or_else(|| anyhow!("Invalid github repo tree"))?;
    // Keep markdown blobs under the requested subtree, minus excluded names.
    let paths = tree
        .iter()
        .flat_map(|v| {
            let typ = v["type"].as_str()?;
            let path = v["path"].as_str()?;
            if typ == "blob"
                && (path.ends_with(".md") || path.ends_with(".MD"))
                && path.starts_with(&root_path)
                && !should_exclude_link(path, exclude)
            {
                Some(format!(
                    "https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
                ))
            } else {
                None
            }
        })
        .collect();
    Ok(paths)
}
/// Fetch one page and return `(path, extracted_text, discovered_links)`.
///
/// GitHub raw files are returned verbatim with no link extraction. For HTML,
/// links under the start URL (minus excluded ones) are collected, and the
/// text is either the markdown of the whole page or only of the elements
/// matching the `extract` selector.
async fn crawl_page(
    start_url: &Url,
    path: &str,
    options: CrawlOptions,
) -> Result<(String, String, Vec<String>)> {
    let client = match *CLIENT {
        Ok(ref client) => client,
        Err(ref err) => bail!("{err}"),
    };
    let location = start_url.join(path)?;
    let response = client
        .get(location.as_str())
        .header("User-Agent", USER_AGENT)
        .send()
        .await?;
    let body = response.text().await?;
    // Raw GitHub markdown: no HTML parsing or link discovery needed.
    if let Ok(true) = GITHUB_REPO_RE.is_match(start_url.as_str()) {
        return Ok((path.to_string(), body, vec![]));
    }
    let mut links = HashSet::new();
    let document = Html::parse_document(&body);
    let selector = Selector::parse("a").map_err(|err| anyhow!("Invalid link selector, {}", err))?;
    for element in document.select(&selector) {
        if let Some(href) = element.value().attr("href") {
            // Absolute hrefs parse directly; relative ones resolve against the page.
            let href = Url::parse(href).ok().or_else(|| location.join(href).ok());
            match href {
                None => continue,
                Some(href) => {
                    // Only follow links beneath the current location.
                    if href.as_str().starts_with(location.as_str())
                        && !should_exclude_link(href.path(), &options.exclude)
                    {
                        links.insert(href.path().to_string());
                    }
                }
            }
        }
    }
    let text = if let Some(selector) = &options.extract {
        // Extract only the configured element(s), joined as markdown sections.
        let selector = Selector::parse(selector)
            .map_err(|err| anyhow!("Invalid extract selector, {}", err))?;
        document
            .select(&selector)
            .map(|v| html_to_md(&v.html()))
            .collect::<Vec<String>>()
            .join("\n\n")
    } else {
        html_to_md(&body)
    };
    Ok((path.to_string(), text, links.into_iter().collect()))
}
/// Return true when `link` should be skipped: fragment links are always
/// skipped, and the link's final segment is compared (case-insensitively)
/// against the exclude list — entries with an extension must match the full
/// file name, extension-less entries match the name with its extension stripped.
fn should_exclude_link(link: &str, exclude: &[String]) -> bool {
    // Fragment links point into a page already reachable via its base URL.
    if link.contains('#') {
        return true;
    }
    let parts: Vec<&str> = link.trim_end_matches('/').split('/').collect();
    // `name` is already lower-cased here; compare against lower-cased entries.
    let name = parts.last().unwrap_or(&"").to_lowercase();
    for exclude_name in exclude {
        let exclude_lower = exclude_name.to_lowercase();
        let matched = match EXTENSION_RE.is_match(&exclude_lower) {
            Ok(true) => exclude_lower == name,
            _ => exclude_lower == EXTENSION_RE.replace(&name, ""),
        };
        if matched {
            return true;
        }
    }
    false
}
/// Canonicalize the crawl root: drop query/fragment and trim the path back to
/// its last '/' so relative links resolve against the containing "directory".
fn normalize_start_url(start_url: &Url) -> Url {
    let mut start_url = start_url.clone();
    start_url.set_query(None);
    start_url.set_fragment(None);
    let new_path = match start_url.path().rfind('/') {
        Some(last_slash_index) => start_url.path()[..last_slash_index + 1].to_string(),
        None => start_url.path().to_string(),
    };
    start_url.set_path(&new_path);
    start_url
}
/// True when `link` refers to the same page as `path`, treating
/// "<path>/index.html" and "<path>/index.htm" as equivalent to "<path>".
fn match_link(path: &str, link: &str) -> bool {
    if path == link {
        return true;
    }
    let canonical = link
        .trim_end_matches("/index.html")
        .trim_end_matches("/index.htm");
    path == canonical
}
+217
View File
@@ -0,0 +1,217 @@
use super::{poll_abort_signal, wait_abort_signal, AbortSignal, IS_STDOUT_TERMINAL};
use anyhow::{bail, Result};
use crossterm::{cursor, queue, style, terminal};
use std::{
future::Future,
io::{stdout, Write},
time::Duration,
};
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver},
oneshot,
},
time::interval,
};
// Terminal-drawing state for the spinner: current frame index and the
// message shown next to the animation.
#[derive(Debug, Default)]
pub struct SpinnerInner {
    index: usize,
    message: String,
}
impl SpinnerInner {
    // Animation frames cycled once per tick.
    // NOTE(review): the frame strings appear empty here — the original Unicode
    // spinner glyphs (e.g. braille characters) may have been lost in transit;
    // confirm against the upstream source.
    const DATA: [&'static str; 10] = ["", "", "", "", "", "", "", "", "", ""];
    /// Draw the next animation frame on the current terminal line.
    fn step(&mut self) -> Result<()> {
        // Only animate on a real terminal and while a message is set.
        if !*IS_STDOUT_TERMINAL || self.message.is_empty() {
            return Ok(());
        }
        let mut writer = stdout();
        let frame = Self::DATA[self.index % Self::DATA.len()];
        // Trailing dots advance every 5 frames, cycling through 0-3 dots.
        let dots = ".".repeat((self.index / 5) % 4);
        let line = format!("{frame}{}{:<3}", self.message, dots);
        queue!(writer, cursor::MoveToColumn(0), style::Print(line),)?;
        if self.index == 0 {
            // Hide the cursor for the duration of the animation.
            queue!(writer, cursor::Hide)?;
        }
        writer.flush()?;
        self.index += 1;
        Ok(())
    }
    /// Replace the spinner message (clearing the previous line first).
    fn set_message(&mut self, message: String) -> Result<()> {
        self.clear_message()?;
        if !message.is_empty() {
            // Leading space separates the message from the frame glyph.
            self.message = format!(" {message}");
        }
        Ok(())
    }
    /// Erase the spinner line and restore the cursor.
    fn clear_message(&mut self) -> Result<()> {
        if !*IS_STDOUT_TERMINAL || self.message.is_empty() {
            return Ok(());
        }
        self.message.clear();
        let mut writer = stdout();
        queue!(
            writer,
            cursor::MoveToColumn(0),
            terminal::Clear(terminal::ClearType::FromCursorDown),
            cursor::Show
        )?;
        writer.flush()?;
        Ok(())
    }
}
#[derive(Clone)]
pub struct Spinner(mpsc::UnboundedSender<SpinnerEvent>);
impl Spinner {
    /// Build a spinner handle plus the event receiver a driver task consumes.
    /// The initial `message` is queued on the channel immediately.
    pub fn create(message: &str) -> (Self, UnboundedReceiver<SpinnerEvent>) {
        let (sender, receiver) = mpsc::unbounded_channel();
        let handle = Spinner(sender);
        // Best effort: the receiver side may already be gone, which is fine.
        let _ = handle.set_message(message.to_string());
        (handle, receiver)
    }

    /// Ask the driver task to display `message`.
    pub fn set_message(&self, message: String) -> Result<()> {
        self.0.send(SpinnerEvent::SetMessage(message))?;
        // Brief pause so the driver task has a chance to process the event.
        // NOTE(review): this blocks the calling thread — presumably only used
        // from sync contexts; confirm before calling inside async code.
        std::thread::sleep(Duration::from_millis(10));
        Ok(())
    }

    /// Ask the driver task to stop and clean up the terminal line.
    pub fn stop(&self) {
        let _ = self.0.send(SpinnerEvent::Stop);
        std::thread::sleep(Duration::from_millis(10));
    }
}
/// Messages a `Spinner` handle sends to its driver task.
pub enum SpinnerEvent {
    /// Replace the text shown next to the animation.
    SetMessage(String),
    /// Tear down the spinner and clear its terminal line.
    Stop,
}
/// Spawn a background task that animates a spinner on stdout.
///
/// Returns a handle used to update the message or stop the animation.
/// The task also exits when every handle has been dropped (channel closed),
/// clearing the spinner line. Previously a closed channel made the
/// `select!` loop spin hot, because `recv()` on a closed channel returns
/// `None` immediately on every iteration.
pub fn spawn_spinner(message: &str) -> Spinner {
    let (spinner, mut spinner_rx) = Spinner::create(message);
    tokio::spawn(async move {
        let mut spinner = SpinnerInner::default();
        // Redraw roughly every 50ms between control events.
        let mut interval = interval(Duration::from_millis(50));
        loop {
            tokio::select! {
                evt = spinner_rx.recv() => {
                    match evt {
                        Some(SpinnerEvent::SetMessage(message)) => {
                            spinner.set_message(message)?;
                        }
                        // `None` means all senders are gone; treat it like an
                        // explicit stop instead of busy-looping on `recv()`.
                        Some(SpinnerEvent::Stop) | None => {
                            spinner.clear_message()?;
                            break;
                        }
                    }
                }
                _ = interval.tick() => {
                    // Drawing failures are non-fatal; keep animating.
                    let _ = spinner.step();
                }
            }
        }
        Ok::<(), anyhow::Error>(())
    });
    spinner
}
/// Run `task` with a spinner showing `message`, aborting on Ctrl-C or when
/// `abort_signal` fires.
pub async fn abortable_run_with_spinner<F, T>(
    task: F,
    message: &str,
    abort_signal: AbortSignal,
) -> Result<T>
where
    F: Future<Output = Result<T>>,
{
    // The handle is intentionally dropped right away (`_` binding); only the
    // event receiver — pre-loaded with the initial message — is needed here.
    let (_, rx) = Spinner::create(message);
    abortable_run_with_spinner_rx(task, rx, abort_signal).await
}
/// Drive `task` to completion while a spinner animates, aborting on Ctrl-C
/// or when `abort_signal` fires.
///
/// When stdout is not a terminal the spinner machinery is skipped entirely
/// and the task is awaited directly.
pub async fn abortable_run_with_spinner_rx<F, T>(
    task: F,
    spinner_rx: UnboundedReceiver<SpinnerEvent>,
    abort_signal: AbortSignal,
) -> Result<T>
where
    F: Future<Output = Result<T>>,
{
    if *IS_STDOUT_TERMINAL {
        // One-shot signal telling the spinner loop the task has finished
        // (or aborted) so it can clear its line and exit.
        let (done_tx, done_rx) = oneshot::channel();
        let run_task = async {
            tokio::select! {
                ret = task => {
                    let _ = done_tx.send(());
                    ret
                }
                _ = tokio::signal::ctrl_c() => {
                    abort_signal.set_ctrlc();
                    let _ = done_tx.send(());
                    bail!("Aborted!")
                },
                _ = wait_abort_signal(&abort_signal) => {
                    let _ = done_tx.send(());
                    bail!("Aborted.");
                },
            }
        };
        // Run task and spinner concurrently; join waits for both, so the
        // spinner has always cleaned up the terminal before we return.
        let (task_ret, spinner_ret) = tokio::join!(
            run_task,
            run_abortable_spinner(spinner_rx, done_rx, abort_signal.clone())
        );
        spinner_ret?;
        task_ret
    } else {
        task.await
    }
}
async fn run_abortable_spinner(
mut spinner_rx: UnboundedReceiver<SpinnerEvent>,
mut done_rx: oneshot::Receiver<()>,
abort_signal: AbortSignal,
) -> Result<()> {
let mut spinner = SpinnerInner::default();
loop {
if abort_signal.aborted() {
break;
}
tokio::time::sleep(Duration::from_millis(25)).await;
match done_rx.try_recv() {
Ok(_) | Err(oneshot::error::TryRecvError::Closed) => {
break;
}
_ => {}
}
match spinner_rx.try_recv() {
Ok(SpinnerEvent::SetMessage(message)) => {
spinner.set_message(message)?;
}
Ok(SpinnerEvent::Stop) => {
spinner.clear_message()?;
}
Err(_) => {}
}
if poll_abort_signal(&abort_signal)? {
break;
}
spinner.step()?;
}
spinner.clear_message()?;
Ok(())
}
+32
View File
@@ -0,0 +1,32 @@
use super::*;
use fancy_regex::{Captures, Regex};
use std::sync::LazyLock;
/// Matches `{{name}}` placeholders; capture group 1 is the bare variable name
/// (word characters only).
pub static RE_VARIABLE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{(\w+)\}\}").unwrap());
/// Expand `{{name}}` placeholders in `text` in place.
///
/// Only the built-in `__*__` variables are substituted; any other
/// placeholder is re-emitted verbatim.
pub fn interpolate_variables(text: &mut String) {
    let expanded = RE_VARIABLE
        .replace_all(text, |caps: &Captures<'_>| resolve_variable(&caps[1]))
        .to_string();
    *text = expanded;
}

/// Map a built-in variable name to its current value; unknown names are
/// round-tripped as the original `{{name}}` placeholder.
fn resolve_variable(key: &str) -> String {
    match key {
        "__os__" => env::consts::OS.to_string(),
        "__os_distro__" => {
            let info = os_info::get();
            // On Linux, tag the distro string so the OS family is explicit.
            if env::consts::OS == "linux" {
                format!("{info} (linux)")
            } else {
                info.to_string()
            }
        }
        "__os_family__" => env::consts::FAMILY.to_string(),
        "__arch__" => env::consts::ARCH.to_string(),
        "__shell__" => SHELL.name.clone(),
        "__locale__" => sys_locale::get_locale().unwrap_or_default(),
        "__now__" => now(),
        "__cwd__" => env::current_dir()
            .map(|v| v.display().to_string())
            .unwrap_or_default(),
        _ => format!("{{{{{key}}}}}"),
    }
}