refactor: Updated to the most recent Rust version with 2024 syntax

This commit is contained in:
2025-11-07 15:50:55 -07:00
parent 667c843fc0
commit 14549afd52
44 changed files with 377 additions and 371 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow};
use chrono::Utc;
use indexmap::IndexMap;
use parking_lot::RwLock;
+31 -29
View File
@@ -2,7 +2,7 @@ use super::*;
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256, strip_think_tag};
use anyhow::{bail, Context, Result};
use anyhow::{Context, Result, bail};
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use bytes::BytesMut;
@@ -11,7 +11,7 @@ use futures_util::StreamExt;
use indexmap::IndexMap;
use reqwest::{Client as ReqwestClient, Method, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
#[derive(Debug, Clone, Deserialize)]
pub struct BedrockConfig {
@@ -222,29 +222,29 @@ async fn chat_completions_streaming(
debug!("stream-data: {smithy_type} {data}");
match smithy_type {
"contentBlockStart" => {
if let Some(tool_use) = data["start"]["toolUse"].as_object() {
if let (Some(id), Some(name)) = (
if let Some(tool_use) = data["start"]["toolUse"].as_object()
&& let (Some(id), Some(name)) = (
json_str_from_map(tool_use, "toolUseId"),
json_str_from_map(tool_use, "name"),
) {
if !function_name.is_empty() {
if function_arguments.is_empty() {
function_arguments = String::from("{}");
}
let arguments: Value =
)
{
if !function_name.is_empty() {
if function_arguments.is_empty() {
function_arguments = String::from("{}");
}
let arguments: Value =
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_arguments.clear();
function_name = name.into();
function_id = id.into();
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
Some(function_id.clone()),
))?;
}
function_arguments.clear();
function_name = name.into();
function_id = id.into();
}
}
"contentBlockDelta" => {
@@ -291,7 +291,9 @@ async fn chat_completions_streaming(
bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
}
_ => {
bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
bail!(
"Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",
);
}
}
}
@@ -494,18 +496,18 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
if let Some(text) = json_str_from_map(reasoning_text, "text") {
reasoning = Some(text.to_string());
}
} else if let Some(tool_use) = item["toolUse"].as_object() {
if let (Some(id), Some(name), Some(input)) = (
} else if let Some(tool_use) = item["toolUse"].as_object()
&& let (Some(id), Some(name), Some(input)) = (
json_str_from_map(tool_use, "toolUseId"),
json_str_from_map(tool_use, "name"),
tool_use.get("input"),
) {
tool_calls.push(ToolCall::new(
name.to_string(),
input.clone(),
Some(id.to_string()),
))
}
)
{
tool_calls.push(ToolCall::new(
name.to_string(),
input.clone(),
Some(id.to_string()),
))
}
}
}
+10 -10
View File
@@ -2,10 +2,10 @@ use super::openai::*;
use super::openai_compatible::*;
use super::*;
use anyhow::{bail, Context, Result};
use anyhow::{Context, Result, bail};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
const API_BASE: &str = "https://api.cohere.ai/v2";
@@ -49,10 +49,10 @@ fn prepare_chat_completions(
let url = format!("{}/chat", api_base.trim_end_matches('/'));
let mut body = openai_build_chat_completions_body(data, &self_.model);
if let Some(obj) = body.as_object_mut() {
if let Some(top_p) = obj.remove("top_p") {
obj.insert("p".to_string(), top_p);
}
if let Some(obj) = body.as_object_mut()
&& let Some(top_p) = obj.remove("top_p")
{
obj.insert("p".to_string(), top_p);
}
let mut request_data = RequestData::new(url, body);
@@ -218,10 +218,10 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let mut tool_calls = vec![];
if let Some(calls) = data["message"]["tool_calls"].as_array() {
if text.is_empty() {
if let Some(tool_plain) = data["message"]["tool_plan"].as_str() {
text = tool_plain.to_string();
}
if text.is_empty()
&& let Some(tool_plain) = data["message"]["tool_plan"].as_str()
{
text = tool_plain.to_string();
}
for call in calls {
if let (Some(name), Some(arguments), Some(id)) = (
+9 -9
View File
@@ -2,21 +2,21 @@ use super::*;
use crate::{
config::{Config, GlobalConfig, Input},
function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult},
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
render::render_stream,
utils::*,
};
use crate::vault::Vault;
use anyhow::{bail, Context, Result};
use anyhow::{Context, Result, bail};
use fancy_regex::Regex;
use indexmap::IndexMap;
use inquire::{
list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text,
MultiSelect, Select, Text, list_option::ListOption, required, validator::Validation,
};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::mpsc::unbounded_channel;
@@ -180,11 +180,11 @@ pub trait Client: Sync + Send {
};
for (key, patch) in patch_map {
let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
if let Ok(regex) = Regex::new(&format!("^({key})$")) {
if let Ok(true) = regex.is_match(self.model().name()) {
request_data.apply_patch(patch);
return;
}
if let Ok(regex) = Regex::new(&format!("^({key})$"))
&& let Ok(true) = regex.is_match(self.model().name())
{
request_data.apply_patch(patch);
return;
}
}
}
+4 -4
View File
@@ -119,10 +119,10 @@ impl MessageContent {
}
for tool_result in tool_results {
let mut parts = vec!["Call".to_string()];
if let Some((agent_name, functions)) = agent_info {
if functions.contains(&tool_result.call.name) {
parts.push(agent_name.clone())
}
if let Some((agent_name, functions)) = agent_info
&& functions.contains(&tool_result.call.name)
{
parts.push(agent_name.clone())
}
parts.push(tool_result.call.name.clone());
parts.push(tool_result.call.arguments.to_string());
+6 -7
View File
@@ -1,13 +1,12 @@
use super::{
list_all_models, list_client_names,
ApiPatch, MessageContentToolCalls, RequestPatch, list_all_models, list_client_names,
message::{Message, MessageContent, MessageContentPart},
ApiPatch, MessageContentToolCalls, RequestPatch,
};
use crate::config::Config;
use crate::utils::{estimate_token_length, strip_think_tag};
use anyhow::{bail, Result};
use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Display;
@@ -275,10 +274,10 @@ impl Model {
pub fn guard_max_input_tokens(&self, messages: &[Message]) -> Result<()> {
let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
if let Some(max_input_tokens) = self.data.max_input_tokens {
if total_tokens >= max_input_tokens {
bail!("Exceed max_input_tokens limit")
}
if let Some(max_input_tokens) = self.data.max_input_tokens
&& total_tokens >= max_input_tokens
{
bail!("Exceed max_input_tokens limit")
}
Ok(())
}
+7 -7
View File
@@ -4,7 +4,7 @@ use super::*;
use anyhow::{Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAICompatibleConfig {
@@ -124,12 +124,12 @@ pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<R
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
if data.get("results").is_none() && data.get("data").is_some() {
if let Some(data_obj) = data.as_object_mut() {
if let Some(value) = data_obj.remove("data") {
data_obj.insert("results".to_string(), value);
}
}
if data.get("results").is_none()
&& data.get("data").is_some()
&& let Some(data_obj) = data.as_object_mut()
&& let Some(value) = data_obj.remove("data")
{
data_obj.insert("results".to_string(), value);
}
let res_body: GenericRerankResBody =
serde_json::from_value(data).context("Invalid rerank data")?;
+7 -7
View File
@@ -1,7 +1,7 @@
use super::{catch_error, ToolCall};
use super::{ToolCall, catch_error};
use crate::utils::AbortSignal;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{Context, Result, anyhow, bail};
use futures_util::{Stream, StreamExt};
use reqwest::RequestBuilder;
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
@@ -214,11 +214,11 @@ impl JsonStreamParser {
}
'}' => {
self.balances.pop();
if self.balances.is_empty() {
if let Some(start) = self.start.take() {
let value: String = self.buffer[start..=i].iter().collect();
handle(&value)?;
}
if self.balances.is_empty()
&& let Some(start) = self.start.take()
{
let value: String = self.buffer[start..=i].iter().collect();
handle(&value)?;
}
}
']' => {