refactor: python tools now use tree-sitter queries instead of AST
This commit is contained in:
+705
-192
@@ -1,13 +1,11 @@
|
||||
use crate::function::{FunctionDeclaration, JsonSchema};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use ast::{Stmt, StmtFunctionDef};
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use indexmap::IndexMap;
|
||||
use rustpython_ast::{Constant, Expr, UnaryOp};
|
||||
use rustpython_parser::{Mode, ast};
|
||||
use serde_json::Value;
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use tree_sitter::{Node, Parser, Tree};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Param {
|
||||
@@ -28,178 +26,197 @@ pub fn generate_python_declarations(
|
||||
tool_file
|
||||
.read_to_string(&mut src)
|
||||
.with_context(|| format!("Failed to load script at '{tool_file:?}'"))?;
|
||||
let suite = parse_suite(&src, file_name)?;
|
||||
let tree = parse_tree(&src, file_name)?;
|
||||
|
||||
let is_tool = parent
|
||||
.and_then(|p| p.file_name())
|
||||
.is_some_and(|n| n == "tools");
|
||||
let mut declarations = python_to_function_declarations(file_name, &suite, is_tool)?;
|
||||
|
||||
if is_tool {
|
||||
for d in &mut declarations {
|
||||
d.agent = true;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(declarations)
|
||||
python_to_function_declarations(file_name, &src, &tree, is_tool)
|
||||
}
|
||||
|
||||
fn parse_suite(src: &str, filename: &str) -> Result<ast::Suite> {
|
||||
let mod_ast =
|
||||
rustpython_parser::parse(src, Mode::Module, filename).context("failed to parse python")?;
|
||||
fn parse_tree(src: &str, filename: &str) -> Result<Tree> {
|
||||
let mut parser = Parser::new();
|
||||
let language = tree_sitter_python::LANGUAGE.into();
|
||||
parser
|
||||
.set_language(&language)
|
||||
.context("failed to initialize python tree-sitter parser")?;
|
||||
|
||||
let suite = match mod_ast {
|
||||
ast::Mod::Module(m) => m.body,
|
||||
ast::Mod::Interactive(m) => m.body,
|
||||
ast::Mod::Expression(_) => bail!("expected a module; got a single expression"),
|
||||
_ => bail!("unexpected parse mode/AST variant"),
|
||||
};
|
||||
let tree = parser
|
||||
.parse(src.as_bytes(), None)
|
||||
.ok_or_else(|| anyhow!("failed to parse python: {filename}"))?;
|
||||
|
||||
Ok(suite)
|
||||
if tree.root_node().has_error() {
|
||||
bail!("failed to parse python: syntax error in {filename}");
|
||||
}
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
fn python_to_function_declarations(
|
||||
file_name: &str,
|
||||
module: &ast::Suite,
|
||||
src: &str,
|
||||
tree: &Tree,
|
||||
is_tool: bool,
|
||||
) -> Result<Vec<FunctionDeclaration>> {
|
||||
let mut out = Vec::new();
|
||||
let root = tree.root_node();
|
||||
let mut cursor = root.walk();
|
||||
|
||||
for stmt in module {
|
||||
if let Stmt::FunctionDef(fd) = stmt {
|
||||
let func_name = fd.name.to_string();
|
||||
for stmt in root.named_children(&mut cursor) {
|
||||
let Some(fd) = unwrap_function_definition(stmt) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if func_name.starts_with('_') && func_name != "_instructions" {
|
||||
continue;
|
||||
}
|
||||
let func_name = function_name(fd, src)?.to_string();
|
||||
|
||||
if is_tool && func_name != "run" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let description = get_docstring_from_body(&fd.body).unwrap_or_default();
|
||||
let params = collect_params(fd);
|
||||
let schema = build_parameters_schema(¶ms, &description);
|
||||
let name = if is_tool && func_name == "run" {
|
||||
underscore(file_name)
|
||||
} else {
|
||||
underscore(&func_name)
|
||||
};
|
||||
let desc_trim = description.trim().to_string();
|
||||
if desc_trim.is_empty() {
|
||||
bail!("Missing or empty description on function: {func_name}");
|
||||
}
|
||||
|
||||
out.push(FunctionDeclaration {
|
||||
name,
|
||||
description: desc_trim,
|
||||
parameters: schema,
|
||||
agent: !is_tool,
|
||||
});
|
||||
if func_name.starts_with('_') && func_name != "_instructions" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_tool && func_name != "run" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let description = get_docstring_from_function(fd, src).unwrap_or_default();
|
||||
let params = collect_params(fd, src)?;
|
||||
let schema = build_parameters_schema(¶ms, &description);
|
||||
let name = if is_tool && func_name == "run" {
|
||||
underscore(file_name)
|
||||
} else {
|
||||
underscore(&func_name)
|
||||
};
|
||||
let desc_trim = description.trim().to_string();
|
||||
if desc_trim.is_empty() {
|
||||
bail!("Missing or empty description on function: {func_name}");
|
||||
}
|
||||
|
||||
out.push(FunctionDeclaration {
|
||||
name,
|
||||
description: desc_trim,
|
||||
parameters: schema,
|
||||
agent: !is_tool,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn get_docstring_from_body(body: &[Stmt]) -> Option<String> {
|
||||
let first = body.first()?;
|
||||
if let Stmt::Expr(expr_stmt) = first
|
||||
&& let Expr::Constant(constant) = &*expr_stmt.value
|
||||
&& let Constant::Str(s) = &constant.value
|
||||
{
|
||||
return Some(s.clone());
|
||||
fn unwrap_function_definition(node: Node<'_>) -> Option<Node<'_>> {
|
||||
match node.kind() {
|
||||
"function_definition" => Some(node),
|
||||
"decorated_definition" => {
|
||||
let mut cursor = node.walk();
|
||||
node.named_children(&mut cursor)
|
||||
.find(|child| child.kind() == "function_definition")
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn collect_params(fd: &StmtFunctionDef) -> Vec<Param> {
|
||||
fn function_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> {
|
||||
let name_node = node
|
||||
.child_by_field_name("name")
|
||||
.ok_or_else(|| anyhow!("function_definition missing name"))?;
|
||||
node_text(name_node, src)
|
||||
}
|
||||
|
||||
fn get_docstring_from_function(node: Node<'_>, src: &str) -> Option<String> {
|
||||
let body = node.child_by_field_name("body")?;
|
||||
let mut cursor = body.walk();
|
||||
let first = body.named_children(&mut cursor).next()?;
|
||||
if first.kind() != "expression_statement" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut expr_cursor = first.walk();
|
||||
let expr = first.named_children(&mut expr_cursor).next()?;
|
||||
if expr.kind() != "string" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let text = node_text(expr, src).ok()?;
|
||||
strip_string_quotes(text)
|
||||
}
|
||||
|
||||
fn strip_string_quotes(text: &str) -> Option<String> {
|
||||
let quote_offset = text
|
||||
.char_indices()
|
||||
.find_map(|(idx, ch)| (ch == '\'' || ch == '"').then_some(idx))?;
|
||||
let prefix = &text[..quote_offset];
|
||||
if !prefix.chars().all(|ch| ch.is_ascii_alphabetic()) {
|
||||
return None;
|
||||
}
|
||||
if prefix.chars().any(|ch| ch == 'f' || ch == 'F') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let literal = &text[quote_offset..];
|
||||
let quote = if literal.starts_with("\"\"\"") {
|
||||
"\"\"\""
|
||||
} else if literal.starts_with("'''") {
|
||||
"'''"
|
||||
} else if literal.starts_with('"') {
|
||||
"\""
|
||||
} else if literal.starts_with('\'') {
|
||||
"'"
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
if literal.len() < quote.len() * 2 || !literal.ends_with(quote) {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(literal[quote.len()..literal.len() - quote.len()].to_string())
|
||||
}
|
||||
|
||||
fn collect_params(node: Node<'_>, src: &str) -> Result<Vec<Param>> {
|
||||
let parameters = node
|
||||
.child_by_field_name("parameters")
|
||||
.ok_or_else(|| anyhow!("function_definition missing parameters"))?;
|
||||
let mut out = Vec::new();
|
||||
let mut cursor = parameters.walk();
|
||||
|
||||
for a in fd.args.posonlyargs.iter().chain(fd.args.args.iter()) {
|
||||
let name = a.def.arg.to_string();
|
||||
let mut ty = get_arg_type(a.def.annotation.as_deref());
|
||||
let mut required = a.default.is_none();
|
||||
|
||||
if ty.ends_with('?') {
|
||||
ty.pop();
|
||||
required = false;
|
||||
for param in parameters.named_children(&mut cursor) {
|
||||
match param.kind() {
|
||||
"identifier" => out.push(Param {
|
||||
name: node_text(param, src)?.to_string(),
|
||||
ty_hint: String::new(),
|
||||
required: true,
|
||||
default: None,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
}),
|
||||
"typed_parameter" => out.push(build_param(
|
||||
parameter_name(param, src)?,
|
||||
get_arg_type(param.child_by_field_name("type"), src)?,
|
||||
true,
|
||||
None,
|
||||
)),
|
||||
"default_parameter" => out.push(build_param(
|
||||
parameter_name(param, src)?,
|
||||
String::new(),
|
||||
false,
|
||||
Some(Value::Null),
|
||||
)),
|
||||
"typed_default_parameter" => out.push(build_param(
|
||||
parameter_name(param, src)?,
|
||||
get_arg_type(param.child_by_field_name("type"), src)?,
|
||||
false,
|
||||
Some(Value::Null),
|
||||
)),
|
||||
"list_splat_pattern" | "dictionary_splat_pattern" | "positional_separator" => {
|
||||
bail!(
|
||||
"Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions"
|
||||
)
|
||||
}
|
||||
"keyword_separator" => continue,
|
||||
other => bail!("Unsupported parameter type: {other}"),
|
||||
}
|
||||
|
||||
let default = if a.default.is_some() {
|
||||
Some(Value::Null)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: ty,
|
||||
required,
|
||||
default,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
for a in &fd.args.kwonlyargs {
|
||||
let name = a.def.arg.to_string();
|
||||
let mut ty = get_arg_type(a.def.annotation.as_deref());
|
||||
let mut required = a.default.is_none();
|
||||
|
||||
if ty.ends_with('?') {
|
||||
ty.pop();
|
||||
required = false;
|
||||
}
|
||||
|
||||
let default = if a.default.is_some() {
|
||||
Some(Value::Null)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: ty,
|
||||
required,
|
||||
default,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(vararg) = &fd.args.vararg {
|
||||
let name = vararg.arg.to_string();
|
||||
let inner = get_arg_type(vararg.annotation.as_deref());
|
||||
let ty = if inner.is_empty() {
|
||||
"list[str]".into()
|
||||
} else {
|
||||
format!("list[{inner}]")
|
||||
};
|
||||
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: ty,
|
||||
required: false,
|
||||
default: None,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(kwarg) = &fd.args.kwarg {
|
||||
let name = kwarg.arg.to_string();
|
||||
out.push(Param {
|
||||
name,
|
||||
ty_hint: "object".into(),
|
||||
required: false,
|
||||
default: None,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(doc) = get_docstring_from_body(&fd.body) {
|
||||
if let Some(doc) = get_docstring_from_function(node, src) {
|
||||
let meta = parse_docstring_args(&doc);
|
||||
for p in &mut out {
|
||||
if let Some((t, d)) = meta.get(&p.name) {
|
||||
@@ -218,69 +235,155 @@ fn collect_params(fd: &StmtFunctionDef) -> Vec<Param> {
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn get_arg_type(annotation: Option<&Expr>) -> String {
|
||||
match annotation {
|
||||
None => "".to_string(),
|
||||
Some(Expr::Name(n)) => n.id.to_string(),
|
||||
Some(Expr::Subscript(sub)) => match &*sub.value {
|
||||
Expr::Name(name) if &name.id == "Optional" => {
|
||||
let inner = get_arg_type(Some(&sub.slice));
|
||||
format!("{inner}?")
|
||||
}
|
||||
Expr::Name(name) if &name.id == "List" => {
|
||||
let inner = get_arg_type(Some(&sub.slice));
|
||||
format!("list[{inner}]")
|
||||
}
|
||||
Expr::Name(name) if &name.id == "Literal" => {
|
||||
let vals = literal_members(&sub.slice);
|
||||
format!("literal:{}", vals.join("|"))
|
||||
}
|
||||
_ => "any".to_string(),
|
||||
},
|
||||
_ => "any".to_string(),
|
||||
fn build_param(name: &str, mut ty: String, mut required: bool, default: Option<Value>) -> Param {
|
||||
if ty.ends_with('?') {
|
||||
ty.pop();
|
||||
required = false;
|
||||
}
|
||||
|
||||
Param {
|
||||
name: name.to_string(),
|
||||
ty_hint: ty,
|
||||
required,
|
||||
default,
|
||||
doc_type: None,
|
||||
doc_desc: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn expr_to_str(e: &Expr) -> String {
|
||||
match e {
|
||||
Expr::Constant(c) => match &c.value {
|
||||
Constant::Str(s) => s.clone(),
|
||||
Constant::Int(i) => i.to_string(),
|
||||
Constant::Float(f) => f.to_string(),
|
||||
Constant::Bool(b) => b.to_string(),
|
||||
Constant::None => "None".to_string(),
|
||||
Constant::Ellipsis => "...".to_string(),
|
||||
Constant::Bytes(b) => String::from_utf8_lossy(b).into_owned(),
|
||||
Constant::Complex { real, imag } => format!("{real}+{imag}j"),
|
||||
_ => "any".to_string(),
|
||||
},
|
||||
fn parameter_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> {
|
||||
if let Some(name) = node.child_by_field_name("name") {
|
||||
return node_text(name, src);
|
||||
}
|
||||
|
||||
Expr::Name(n) => n.id.to_string(),
|
||||
let mut cursor = node.walk();
|
||||
node.named_children(&mut cursor)
|
||||
.find(|child| child.kind() == "identifier")
|
||||
.ok_or_else(|| anyhow!("parameter missing name"))
|
||||
.and_then(|name| node_text(name, src))
|
||||
}
|
||||
|
||||
Expr::UnaryOp(u) => {
|
||||
if matches!(u.op, UnaryOp::USub) {
|
||||
let inner = expr_to_str(&u.operand);
|
||||
if inner.parse::<f64>().is_ok() || inner.chars().all(|c| c.is_ascii_digit()) {
|
||||
return format!("-{inner}");
|
||||
}
|
||||
fn get_arg_type(annotation: Option<Node<'_>>, src: &str) -> Result<String> {
|
||||
let Some(annotation) = annotation else {
|
||||
return Ok(String::new());
|
||||
};
|
||||
|
||||
match annotation.kind() {
|
||||
"type" => get_arg_type(named_child(annotation, 0), src),
|
||||
"generic_type" => {
|
||||
let value = annotation
|
||||
.child_by_field_name("type")
|
||||
.or_else(|| named_child(annotation, 0))
|
||||
.ok_or_else(|| anyhow!("generic_type missing value"))?;
|
||||
let value_name = if value.kind() == "identifier" {
|
||||
node_text(value, src)?
|
||||
} else {
|
||||
return Ok("any".to_string());
|
||||
};
|
||||
|
||||
let inner = annotation
|
||||
.child_by_field_name("type_parameter")
|
||||
.or_else(|| annotation.child_by_field_name("parameters"))
|
||||
.or_else(|| named_child(annotation, 1))
|
||||
.ok_or_else(|| anyhow!("generic_type missing inner type"))?;
|
||||
|
||||
match value_name {
|
||||
"Optional" => Ok(format!("{}?", generic_inner_type(inner, src)?)),
|
||||
"List" => Ok(format!("list[{}]", generic_inner_type(inner, src)?)),
|
||||
"Literal" => Ok(format!(
|
||||
"literal:{}",
|
||||
literal_members(inner, src)?.join("|")
|
||||
)),
|
||||
_ => Ok("any".to_string()),
|
||||
}
|
||||
"any".to_string()
|
||||
}
|
||||
"identifier" => Ok(node_text(annotation, src)?.to_string()),
|
||||
"subscript" => {
|
||||
let value = annotation
|
||||
.child_by_field_name("value")
|
||||
.or_else(|| named_child(annotation, 0))
|
||||
.ok_or_else(|| anyhow!("subscript missing value"))?;
|
||||
let value_name = if value.kind() == "identifier" {
|
||||
node_text(value, src)?
|
||||
} else {
|
||||
return Ok("any".to_string());
|
||||
};
|
||||
|
||||
Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect::<Vec<_>>().join(","),
|
||||
|
||||
_ => "any".to_string(),
|
||||
let inner = annotation
|
||||
.child_by_field_name("subscript")
|
||||
.or_else(|| annotation.child_by_field_name("slice"))
|
||||
.or_else(|| named_child(annotation, 1))
|
||||
.ok_or_else(|| anyhow!("subscript missing inner type"))?;
|
||||
match value_name {
|
||||
"Optional" => Ok(format!("{}?", get_arg_type(Some(inner), src)?)),
|
||||
"List" => Ok(format!("list[{}]", get_arg_type(Some(inner), src)?)),
|
||||
"Literal" => Ok(format!(
|
||||
"literal:{}",
|
||||
literal_members(inner, src)?.join("|")
|
||||
)),
|
||||
_ => Ok("any".to_string()),
|
||||
}
|
||||
}
|
||||
_ => Ok("any".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn literal_members(e: &Expr) -> Vec<String> {
|
||||
match e {
|
||||
Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect(),
|
||||
_ => vec![expr_to_str(e)],
|
||||
fn generic_inner_type(node: Node<'_>, src: &str) -> Result<String> {
|
||||
if node.kind() == "type_parameter" {
|
||||
return get_arg_type(named_child(node, 0), src);
|
||||
}
|
||||
|
||||
get_arg_type(Some(node), src)
|
||||
}
|
||||
|
||||
fn literal_members(node: Node<'_>, src: &str) -> Result<Vec<String>> {
|
||||
if node.kind() == "type" {
|
||||
return literal_members(
|
||||
named_child(node, 0).ok_or_else(|| anyhow!("type missing inner literal"))?,
|
||||
src,
|
||||
);
|
||||
}
|
||||
|
||||
if node.kind() == "tuple" || node.kind() == "type_parameter" {
|
||||
let mut cursor = node.walk();
|
||||
let members = node
|
||||
.named_children(&mut cursor)
|
||||
.map(|child| expr_to_str(child, src))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
return Ok(if members.is_empty() {
|
||||
vec!["any".to_string()]
|
||||
} else {
|
||||
members
|
||||
});
|
||||
}
|
||||
|
||||
Ok(vec![expr_to_str(node, src)?])
|
||||
}
|
||||
|
||||
fn expr_to_str(node: Node<'_>, src: &str) -> Result<String> {
|
||||
match node.kind() {
|
||||
"type" => expr_to_str(
|
||||
named_child(node, 0).ok_or_else(|| anyhow!("type missing expression"))?,
|
||||
src,
|
||||
),
|
||||
"string" | "integer" | "float" | "true" | "false" | "none" | "identifier"
|
||||
| "unary_operator" => Ok(node_text(node, src)?.trim().to_string()),
|
||||
_ => Ok("any".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn named_child(node: Node<'_>, index: usize) -> Option<Node<'_>> {
|
||||
let mut cursor = node.walk();
|
||||
node.named_children(&mut cursor).nth(index)
|
||||
}
|
||||
|
||||
fn node_text<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> {
|
||||
node.utf8_text(src.as_bytes())
|
||||
.map_err(|err| anyhow!("invalid utf-8 in python source: {err}"))
|
||||
}
|
||||
|
||||
fn parse_docstring_args(doc: &str) -> IndexMap<String, (String, String)> {
|
||||
@@ -417,3 +520,413 @@ fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) {
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn parse_source(
|
||||
source: &str,
|
||||
file_name: &str,
|
||||
parent: &Path,
|
||||
) -> Result<Vec<FunctionDeclaration>> {
|
||||
let unique = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time went backwards")
|
||||
.as_nanos();
|
||||
let path = std::env::temp_dir().join(format!("loki_python_parser_{file_name}_{unique}.py"));
|
||||
fs::write(&path, source).expect("failed to write temp python source");
|
||||
let file = File::open(&path).expect("failed to open temp python source");
|
||||
let result = generate_python_declarations(file, file_name, Some(parent));
|
||||
let _ = fs::remove_file(&path);
|
||||
result
|
||||
}
|
||||
|
||||
fn properties(schema: &JsonSchema) -> &IndexMap<String, JsonSchema> {
|
||||
schema
|
||||
.properties
|
||||
.as_ref()
|
||||
.expect("missing schema properties")
|
||||
}
|
||||
|
||||
fn property<'a>(schema: &'a JsonSchema, name: &str) -> &'a JsonSchema {
|
||||
properties(schema)
|
||||
.get(name)
|
||||
.unwrap_or_else(|| panic!("missing property: {name}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_demo_py() {
|
||||
let source = r#"
|
||||
import os
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
def run(
|
||||
string: str,
|
||||
string_enum: Literal["foo", "bar"],
|
||||
boolean: bool,
|
||||
integer: int,
|
||||
number: float,
|
||||
array: List[str],
|
||||
string_optional: Optional[str] = None,
|
||||
array_optional: Optional[List[str]] = None,
|
||||
):
|
||||
"""Demonstrates how to create a tool using Python and how to use comments.
|
||||
Args:
|
||||
string: Define a required string property
|
||||
string_enum: Define a required string property with enum
|
||||
boolean: Define a required boolean property
|
||||
integer: Define a required integer property
|
||||
number: Define a required number property
|
||||
array: Define a required string array property
|
||||
string_optional: Define an optional string property
|
||||
array_optional: Define an optional string array property
|
||||
"""
|
||||
output = f"""string: {string}
|
||||
string_enum: {string_enum}
|
||||
string_optional: {string_optional}
|
||||
boolean: {boolean}
|
||||
integer: {integer}
|
||||
number: {number}
|
||||
array: {array}
|
||||
array_optional: {array_optional}"""
|
||||
|
||||
for key, value in os.environ.items():
|
||||
if key.startswith("LLM_"):
|
||||
output = f"{output}\n{key}: {value}"
|
||||
|
||||
return output
|
||||
"#;
|
||||
|
||||
let declarations = parse_source(source, "demo_py", Path::new("tools")).unwrap();
|
||||
assert_eq!(declarations.len(), 1);
|
||||
|
||||
let decl = &declarations[0];
|
||||
assert_eq!(decl.name, "demo_py");
|
||||
assert!(!decl.agent);
|
||||
assert!(decl.description.starts_with("Demonstrates how to create"));
|
||||
|
||||
let params = &decl.parameters;
|
||||
assert_eq!(params.type_value.as_deref(), Some("object"));
|
||||
assert_eq!(
|
||||
params.required.as_ref().unwrap(),
|
||||
&vec![
|
||||
"string".to_string(),
|
||||
"string_enum".to_string(),
|
||||
"boolean".to_string(),
|
||||
"integer".to_string(),
|
||||
"number".to_string(),
|
||||
"array".to_string(),
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
property(params, "string").type_value.as_deref(),
|
||||
Some("string")
|
||||
);
|
||||
|
||||
let string_enum = property(params, "string_enum");
|
||||
assert_eq!(string_enum.type_value.as_deref(), Some("string"));
|
||||
assert_eq!(
|
||||
string_enum.enum_value.as_ref().unwrap(),
|
||||
&vec!["foo".to_string(), "bar".to_string()]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
property(params, "boolean").type_value.as_deref(),
|
||||
Some("boolean")
|
||||
);
|
||||
assert_eq!(
|
||||
property(params, "integer").type_value.as_deref(),
|
||||
Some("integer")
|
||||
);
|
||||
assert_eq!(
|
||||
property(params, "number").type_value.as_deref(),
|
||||
Some("number")
|
||||
);
|
||||
|
||||
let array = property(params, "array");
|
||||
assert_eq!(array.type_value.as_deref(), Some("array"));
|
||||
assert_eq!(
|
||||
array.items.as_ref().unwrap().type_value.as_deref(),
|
||||
Some("string")
|
||||
);
|
||||
|
||||
let string_optional = property(params, "string_optional");
|
||||
assert_eq!(string_optional.type_value.as_deref(), Some("string"));
|
||||
assert!(
|
||||
!params
|
||||
.required
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains(&"string_optional".to_string())
|
||||
);
|
||||
|
||||
let array_optional = property(params, "array_optional");
|
||||
assert_eq!(array_optional.type_value.as_deref(), Some("array"));
|
||||
assert_eq!(
|
||||
array_optional.items.as_ref().unwrap().type_value.as_deref(),
|
||||
Some("string")
|
||||
);
|
||||
assert!(
|
||||
!params
|
||||
.required
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains(&"array_optional".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_weather() {
|
||||
let source = r#"
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import quote_plus
|
||||
from urllib.request import urlopen
|
||||
|
||||
|
||||
def run(
|
||||
location: str,
|
||||
llm_output: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Get the current weather in a given location
|
||||
|
||||
Args:
|
||||
location (str): The city and optionally the state or country (e.g., "London", "San Francisco, CA").
|
||||
|
||||
Returns:
|
||||
str: A single-line formatted weather string from wttr.in (``format=4`` with metric units).
|
||||
"""
|
||||
url = f"https://wttr.in/{quote_plus(location)}?format=4&M"
|
||||
|
||||
with urlopen(url, timeout=10) as resp:
|
||||
weather = resp.read().decode("utf-8", errors="replace")
|
||||
|
||||
dest = llm_output if llm_output is not None else os.environ.get("LLM_OUTPUT", "/dev/stdout")
|
||||
|
||||
if dest not in {"-", "/dev/stdout"}:
|
||||
path = Path(dest)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as fh:
|
||||
fh.write(weather)
|
||||
else:
|
||||
pass
|
||||
|
||||
return weather
|
||||
"#;
|
||||
|
||||
let declarations = parse_source(source, "get_current_weather", Path::new("tools")).unwrap();
|
||||
assert_eq!(declarations.len(), 1);
|
||||
|
||||
let decl = &declarations[0];
|
||||
assert_eq!(decl.name, "get_current_weather");
|
||||
assert!(!decl.agent);
|
||||
assert!(
|
||||
decl.description
|
||||
.starts_with("Get the current weather in a given location")
|
||||
);
|
||||
|
||||
let params = &decl.parameters;
|
||||
assert_eq!(
|
||||
params.required.as_ref().unwrap(),
|
||||
&vec!["location".to_string()]
|
||||
);
|
||||
|
||||
let location = property(params, "location");
|
||||
assert_eq!(location.type_value.as_deref(), Some("string"));
|
||||
assert_eq!(
|
||||
location.description.as_deref(),
|
||||
Some(
|
||||
"The city and optionally the state or country (e.g., \"London\", \"San Francisco, CA\")."
|
||||
)
|
||||
);
|
||||
|
||||
let llm_output = property(params, "llm_output");
|
||||
assert_eq!(llm_output.type_value.as_deref(), Some("string"));
|
||||
assert!(
|
||||
!params
|
||||
.required
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains(&"llm_output".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_execute_py_code() {
|
||||
let source = r#"
|
||||
import ast
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
|
||||
def run(code: str):
|
||||
"""Execute the given Python code.
|
||||
Args:
|
||||
code: The Python code to execute, such as `print("hello world")`
|
||||
"""
|
||||
output = io.StringIO()
|
||||
with redirect_stdout(output):
|
||||
value = exec_with_return(code, {}, {})
|
||||
|
||||
if value is not None:
|
||||
output.write(str(value))
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def exec_with_return(code: str, globals: dict, locals: dict):
|
||||
a = ast.parse(code)
|
||||
last_expression = None
|
||||
if a.body:
|
||||
if isinstance(a_last := a.body[-1], ast.Expr):
|
||||
last_expression = ast.unparse(a.body.pop())
|
||||
elif isinstance(a_last, ast.Assign):
|
||||
last_expression = ast.unparse(a_last.targets[0])
|
||||
elif isinstance(a_last, (ast.AnnAssign, ast.AugAssign)):
|
||||
last_expression = ast.unparse(a_last.target)
|
||||
exec(ast.unparse(a), globals, locals)
|
||||
if last_expression:
|
||||
return eval(last_expression, globals, locals)
|
||||
"#;
|
||||
|
||||
let declarations = parse_source(source, "execute_py_code", Path::new("tools")).unwrap();
|
||||
assert_eq!(declarations.len(), 1);
|
||||
|
||||
let decl = &declarations[0];
|
||||
assert_eq!(decl.name, "execute_py_code");
|
||||
assert!(!decl.agent);
|
||||
|
||||
let params = &decl.parameters;
|
||||
assert_eq!(properties(params).len(), 1);
|
||||
let code = property(params, "code");
|
||||
assert_eq!(code.type_value.as_deref(), Some("string"));
|
||||
assert_eq!(
|
||||
code.description.as_deref(),
|
||||
Some("The Python code to execute, such as `print(\"hello world\")`")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_tools() {
|
||||
let source = r#"
|
||||
import urllib.request
|
||||
|
||||
def get_ipinfo():
|
||||
"""
|
||||
Get the ip info
|
||||
"""
|
||||
with urllib.request.urlopen("https://httpbin.org/ip") as response:
|
||||
data = response.read()
|
||||
return data.decode('utf-8')
|
||||
"#;
|
||||
|
||||
let declarations = parse_source(source, "tools", Path::new("demo")).unwrap();
|
||||
assert_eq!(declarations.len(), 1);
|
||||
|
||||
let decl = &declarations[0];
|
||||
assert_eq!(decl.name, "get_ipinfo");
|
||||
assert!(decl.agent);
|
||||
assert_eq!(decl.description, "Get the ip info");
|
||||
assert!(properties(&decl.parameters).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reject_varargs() {
|
||||
let source = r#"
|
||||
def run(*args):
|
||||
"""Has docstring"""
|
||||
return args
|
||||
"#;
|
||||
|
||||
let err = parse_source(source, "reject_varargs", Path::new("tools")).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reject_kwargs() {
|
||||
let source = r#"
|
||||
def run(**kwargs):
|
||||
"""Has docstring"""
|
||||
return kwargs
|
||||
"#;
|
||||
|
||||
let err = parse_source(source, "reject_kwargs", Path::new("tools")).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reject_positional_only() {
|
||||
let source = r#"
|
||||
def run(x, /, y):
|
||||
"""Has docstring"""
|
||||
return x + y
|
||||
"#;
|
||||
|
||||
let err = parse_source(source, "reject_positional_only", Path::new("tools")).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_docstring() {
|
||||
let source = r#"
|
||||
def run(x: str):
|
||||
pass
|
||||
"#;
|
||||
|
||||
let err = parse_source(source, "missing_docstring", Path::new("tools")).unwrap_err();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("Missing or empty description on function: run")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_syntax_error() {
|
||||
let source = "def run(: broken";
|
||||
let err = parse_source(source, "syntax_error", Path::new("tools")).unwrap_err();
|
||||
assert!(err.to_string().contains("failed to parse python"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_underscore_functions_skipped() {
|
||||
let source = r#"
|
||||
def _private():
|
||||
"""Private"""
|
||||
return None
|
||||
|
||||
def public():
|
||||
"""Public"""
|
||||
return None
|
||||
"#;
|
||||
|
||||
let declarations = parse_source(source, "tools", Path::new("demo")).unwrap();
|
||||
assert_eq!(declarations.len(), 1);
|
||||
assert_eq!(declarations[0].name, "public");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_instructions_not_skipped() {
|
||||
let source = r#"
|
||||
def _instructions():
|
||||
"""Help text"""
|
||||
return None
|
||||
"#;
|
||||
|
||||
let declarations = parse_source(source, "tools", Path::new("demo")).unwrap();
|
||||
assert_eq!(declarations.len(), 1);
|
||||
assert_eq!(declarations[0].name, "instructions");
|
||||
assert_eq!(declarations[0].description, "Help text");
|
||||
assert!(declarations[0].agent);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user