refactor: Updated to the most recent Rust version with 2024 syntax
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{Result, anyhow};
|
||||
use chrono::Utc;
|
||||
use indexmap::IndexMap;
|
||||
use parking_lot::RwLock;
|
||||
|
||||
+31
-29
@@ -2,7 +2,7 @@ use super::*;
|
||||
|
||||
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256, strip_think_tag};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
||||
use aws_smithy_eventstream::smithy::parse_response_headers;
|
||||
use bytes::BytesMut;
|
||||
@@ -11,7 +11,7 @@ use futures_util::StreamExt;
|
||||
use indexmap::IndexMap;
|
||||
use reqwest::{Client as ReqwestClient, Method, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BedrockConfig {
|
||||
@@ -222,29 +222,29 @@ async fn chat_completions_streaming(
|
||||
debug!("stream-data: {smithy_type} {data}");
|
||||
match smithy_type {
|
||||
"contentBlockStart" => {
|
||||
if let Some(tool_use) = data["start"]["toolUse"].as_object() {
|
||||
if let (Some(id), Some(name)) = (
|
||||
if let Some(tool_use) = data["start"]["toolUse"].as_object()
|
||||
&& let (Some(id), Some(name)) = (
|
||||
json_str_from_map(tool_use, "toolUseId"),
|
||||
json_str_from_map(tool_use, "name"),
|
||||
) {
|
||||
if !function_name.is_empty() {
|
||||
if function_arguments.is_empty() {
|
||||
function_arguments = String::from("{}");
|
||||
}
|
||||
let arguments: Value =
|
||||
)
|
||||
{
|
||||
if !function_name.is_empty() {
|
||||
if function_arguments.is_empty() {
|
||||
function_arguments = String::from("{}");
|
||||
}
|
||||
let arguments: Value =
|
||||
function_arguments.parse().with_context(|| {
|
||||
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
|
||||
})?;
|
||||
handler.tool_call(ToolCall::new(
|
||||
function_name.clone(),
|
||||
arguments,
|
||||
Some(function_id.clone()),
|
||||
))?;
|
||||
}
|
||||
function_arguments.clear();
|
||||
function_name = name.into();
|
||||
function_id = id.into();
|
||||
handler.tool_call(ToolCall::new(
|
||||
function_name.clone(),
|
||||
arguments,
|
||||
Some(function_id.clone()),
|
||||
))?;
|
||||
}
|
||||
function_arguments.clear();
|
||||
function_name = name.into();
|
||||
function_id = id.into();
|
||||
}
|
||||
}
|
||||
"contentBlockDelta" => {
|
||||
@@ -291,7 +291,9 @@ async fn chat_completions_streaming(
|
||||
bail!("Invalid response data: {data} (smithy_type: {smithy_type})")
|
||||
}
|
||||
_ => {
|
||||
bail!("Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",);
|
||||
bail!(
|
||||
"Unrecognized message, message_type: {message_type}, smithy_type: {smithy_type}",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -494,18 +496,18 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
if let Some(text) = json_str_from_map(reasoning_text, "text") {
|
||||
reasoning = Some(text.to_string());
|
||||
}
|
||||
} else if let Some(tool_use) = item["toolUse"].as_object() {
|
||||
if let (Some(id), Some(name), Some(input)) = (
|
||||
} else if let Some(tool_use) = item["toolUse"].as_object()
|
||||
&& let (Some(id), Some(name), Some(input)) = (
|
||||
json_str_from_map(tool_use, "toolUseId"),
|
||||
json_str_from_map(tool_use, "name"),
|
||||
tool_use.get("input"),
|
||||
) {
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
input.clone(),
|
||||
Some(id.to_string()),
|
||||
))
|
||||
}
|
||||
)
|
||||
{
|
||||
tool_calls.push(ToolCall::new(
|
||||
name.to_string(),
|
||||
input.clone(),
|
||||
Some(id.to_string()),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+10
-10
@@ -2,10 +2,10 @@ use super::openai::*;
|
||||
use super::openai_compatible::*;
|
||||
use super::*;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
const API_BASE: &str = "https://api.cohere.ai/v2";
|
||||
|
||||
@@ -49,10 +49,10 @@ fn prepare_chat_completions(
|
||||
|
||||
let url = format!("{}/chat", api_base.trim_end_matches('/'));
|
||||
let mut body = openai_build_chat_completions_body(data, &self_.model);
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
if let Some(top_p) = obj.remove("top_p") {
|
||||
obj.insert("p".to_string(), top_p);
|
||||
}
|
||||
if let Some(obj) = body.as_object_mut()
|
||||
&& let Some(top_p) = obj.remove("top_p")
|
||||
{
|
||||
obj.insert("p".to_string(), top_p);
|
||||
}
|
||||
|
||||
let mut request_data = RequestData::new(url, body);
|
||||
@@ -218,10 +218,10 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
||||
|
||||
let mut tool_calls = vec![];
|
||||
if let Some(calls) = data["message"]["tool_calls"].as_array() {
|
||||
if text.is_empty() {
|
||||
if let Some(tool_plain) = data["message"]["tool_plan"].as_str() {
|
||||
text = tool_plain.to_string();
|
||||
}
|
||||
if text.is_empty()
|
||||
&& let Some(tool_plain) = data["message"]["tool_plan"].as_str()
|
||||
{
|
||||
text = tool_plain.to_string();
|
||||
}
|
||||
for call in calls {
|
||||
if let (Some(name), Some(arguments), Some(id)) = (
|
||||
|
||||
@@ -2,21 +2,21 @@ use super::*;
|
||||
|
||||
use crate::{
|
||||
config::{Config, GlobalConfig, Input},
|
||||
function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult},
|
||||
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
|
||||
render::render_stream,
|
||||
utils::*,
|
||||
};
|
||||
|
||||
use crate::vault::Vault;
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use fancy_regex::Regex;
|
||||
use indexmap::IndexMap;
|
||||
use inquire::{
|
||||
list_option::ListOption, required, validator::Validation, MultiSelect, Select, Text,
|
||||
MultiSelect, Select, Text, list_option::ListOption, required, validator::Validation,
|
||||
};
|
||||
use reqwest::{Client as ReqwestClient, RequestBuilder};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::unbounded_channel;
|
||||
@@ -180,11 +180,11 @@ pub trait Client: Sync + Send {
|
||||
};
|
||||
for (key, patch) in patch_map {
|
||||
let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
|
||||
if let Ok(regex) = Regex::new(&format!("^({key})$")) {
|
||||
if let Ok(true) = regex.is_match(self.model().name()) {
|
||||
request_data.apply_patch(patch);
|
||||
return;
|
||||
}
|
||||
if let Ok(regex) = Regex::new(&format!("^({key})$"))
|
||||
&& let Ok(true) = regex.is_match(self.model().name())
|
||||
{
|
||||
request_data.apply_patch(patch);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,10 +119,10 @@ impl MessageContent {
|
||||
}
|
||||
for tool_result in tool_results {
|
||||
let mut parts = vec!["Call".to_string()];
|
||||
if let Some((agent_name, functions)) = agent_info {
|
||||
if functions.contains(&tool_result.call.name) {
|
||||
parts.push(agent_name.clone())
|
||||
}
|
||||
if let Some((agent_name, functions)) = agent_info
|
||||
&& functions.contains(&tool_result.call.name)
|
||||
{
|
||||
parts.push(agent_name.clone())
|
||||
}
|
||||
parts.push(tool_result.call.name.clone());
|
||||
parts.push(tool_result.call.arguments.to_string());
|
||||
|
||||
+6
-7
@@ -1,13 +1,12 @@
|
||||
use super::{
|
||||
list_all_models, list_client_names,
|
||||
ApiPatch, MessageContentToolCalls, RequestPatch, list_all_models, list_client_names,
|
||||
message::{Message, MessageContent, MessageContentPart},
|
||||
ApiPatch, MessageContentToolCalls, RequestPatch,
|
||||
};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::utils::{estimate_token_length, strip_think_tag};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use anyhow::{Result, bail};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::fmt::Display;
|
||||
@@ -275,10 +274,10 @@ impl Model {
|
||||
|
||||
pub fn guard_max_input_tokens(&self, messages: &[Message]) -> Result<()> {
|
||||
let total_tokens = self.total_tokens(messages) + BASIS_TOKENS;
|
||||
if let Some(max_input_tokens) = self.data.max_input_tokens {
|
||||
if total_tokens >= max_input_tokens {
|
||||
bail!("Exceed max_input_tokens limit")
|
||||
}
|
||||
if let Some(max_input_tokens) = self.data.max_input_tokens
|
||||
&& total_tokens >= max_input_tokens
|
||||
{
|
||||
bail!("Exceed max_input_tokens limit")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::*;
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAICompatibleConfig {
|
||||
@@ -124,12 +124,12 @@ pub async fn generic_rerank(builder: RequestBuilder, _model: &Model) -> Result<R
|
||||
if !status.is_success() {
|
||||
catch_error(&data, status.as_u16())?;
|
||||
}
|
||||
if data.get("results").is_none() && data.get("data").is_some() {
|
||||
if let Some(data_obj) = data.as_object_mut() {
|
||||
if let Some(value) = data_obj.remove("data") {
|
||||
data_obj.insert("results".to_string(), value);
|
||||
}
|
||||
}
|
||||
if data.get("results").is_none()
|
||||
&& data.get("data").is_some()
|
||||
&& let Some(data_obj) = data.as_object_mut()
|
||||
&& let Some(value) = data_obj.remove("data")
|
||||
{
|
||||
data_obj.insert("results".to_string(), value);
|
||||
}
|
||||
let res_body: GenericRerankResBody =
|
||||
serde_json::from_value(data).context("Invalid rerank data")?;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::{catch_error, ToolCall};
|
||||
use super::{ToolCall, catch_error};
|
||||
use crate::utils::AbortSignal;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use futures_util::{Stream, StreamExt};
|
||||
use reqwest::RequestBuilder;
|
||||
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
|
||||
@@ -214,11 +214,11 @@ impl JsonStreamParser {
|
||||
}
|
||||
'}' => {
|
||||
self.balances.pop();
|
||||
if self.balances.is_empty() {
|
||||
if let Some(start) = self.start.take() {
|
||||
let value: String = self.buffer[start..=i].iter().collect();
|
||||
handle(&value)?;
|
||||
}
|
||||
if self.balances.is_empty()
|
||||
&& let Some(start) = self.start.take()
|
||||
{
|
||||
let value: String = self.buffer[start..=i].iter().collect();
|
||||
handle(&value)?;
|
||||
}
|
||||
}
|
||||
']' => {
|
||||
|
||||
Reference in New Issue
Block a user