refactor: python tools now use tree-sitter queries instead of AST

This commit is contained in:
2026-04-09 10:20:49 -06:00
parent ab2b927fcb
commit ebeb9c9b7d
3 changed files with 749 additions and 507 deletions
Generated
+42 -313
View File
@@ -859,7 +859,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"rustc-hash 2.1.2",
"rustc-hash",
"shlex",
"syn",
]
@@ -1388,12 +1388,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "crunchy"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
[[package]]
name = "crypto-common"
version = "0.1.7"
@@ -1518,34 +1512,13 @@ dependencies = [
"syn",
]
[[package]]
name = "derive_more"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
dependencies = [
"derive_more-impl 1.0.0",
]
[[package]]
name = "derive_more"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134"
dependencies = [
"derive_more-impl 2.1.1",
]
[[package]]
name = "derive_more-impl"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
dependencies = [
"proc-macro2",
"quote",
"syn",
"unicode-xid",
"derive_more-impl",
]
[[package]]
@@ -2105,15 +2078,6 @@ dependencies = [
"windows-link",
]
[[package]]
name = "getopts"
version = "0.2.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df"
dependencies = [
"unicode-width",
]
[[package]]
name = "getrandom"
version = "0.2.17"
@@ -2258,15 +2222,6 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
@@ -2822,18 +2777,6 @@ dependencies = [
"once_cell",
]
[[package]]
name = "is-macro"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d57a3e447e24c22647738e4607f1df1e0ec6f72e16182c4cd199f647cdfb0e4"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "is-terminal"
version = "0.4.17"
@@ -2870,15 +2813,6 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
@@ -2987,12 +2921,6 @@ dependencies = [
"simple_asn1",
]
[[package]]
name = "lalrpop-util"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553"
[[package]]
name = "lazy_static"
version = "1.5.0"
@@ -3021,12 +2949,6 @@ dependencies = [
"windows-link",
]
[[package]]
name = "libm"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981"
[[package]]
name = "libredox"
version = "0.1.15"
@@ -3171,8 +3093,6 @@ dependencies = [
"reqwest-eventsource",
"rmcp",
"rust-embed",
"rustpython-ast",
"rustpython-parser",
"scraper",
"serde",
"serde_json",
@@ -3188,6 +3108,8 @@ dependencies = [
"tokio",
"tokio-graceful",
"tokio-stream",
"tree-sitter",
"tree-sitter-python",
"unicode-segmentation",
"unicode-width",
"url",
@@ -3233,64 +3155,6 @@ dependencies = [
"libc",
]
[[package]]
name = "malachite"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fbdf9cb251732db30a7200ebb6ae5d22fe8e11397364416617d2c2cf0c51cb5"
dependencies = [
"malachite-base",
"malachite-nz",
"malachite-q",
]
[[package]]
name = "malachite-base"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ea0ed76adf7defc1a92240b5c36d5368cfe9251640dcce5bd2d0b7c1fd87aeb"
dependencies = [
"hashbrown 0.14.5",
"itertools 0.11.0",
"libm",
"ryu",
]
[[package]]
name = "malachite-bigint"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d149aaa2965d70381709d9df4c7ee1fc0de1c614a4efc2ee356f5e43d68749f8"
dependencies = [
"derive_more 1.0.0",
"malachite",
"num-integer",
"num-traits",
"paste",
]
[[package]]
name = "malachite-nz"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34a79feebb2bc9aa7762047c8e5495269a367da6b5a90a99882a0aeeac1841f7"
dependencies = [
"itertools 0.11.0",
"libm",
"malachite-base",
]
[[package]]
name = "malachite-q"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f235d5747b1256b47620f5640c2a17a88c7569eebdf27cd9cb130e1a619191"
dependencies = [
"itertools 0.11.0",
"malachite-base",
"malachite-nz",
]
[[package]]
name = "markup5ever"
version = "0.12.1"
@@ -3945,12 +3809,6 @@ dependencies = [
"subtle",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "pastey"
version = "0.2.1"
@@ -4293,7 +4151,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.2",
"rustc-hash",
"rustls 0.23.37",
"socket2 0.6.3",
"thiserror 2.0.18",
@@ -4313,7 +4171,7 @@ dependencies = [
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash 2.1.2",
"rustc-hash",
"rustls 0.23.37",
"rustls-pki-types",
"slab",
@@ -4364,8 +4222,6 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
]
@@ -4375,20 +4231,10 @@ version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha 0.9.0",
"rand_chacha",
"rand_core 0.9.5",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
@@ -4732,12 +4578,6 @@ version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.2"
@@ -4851,63 +4691,6 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustpython-ast"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cdaf8ee5c1473b993b398c174641d3aa9da847af36e8d5eb8291930b72f31a5"
dependencies = [
"is-macro",
"malachite-bigint",
"rustpython-parser-core",
"static_assertions",
]
[[package]]
name = "rustpython-parser"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "868f724daac0caf9bd36d38caf45819905193a901e8f1c983345a68e18fb2abb"
dependencies = [
"anyhow",
"is-macro",
"itertools 0.11.0",
"lalrpop-util",
"log",
"malachite-bigint",
"num-traits",
"phf",
"phf_codegen",
"rustc-hash 1.1.0",
"rustpython-ast",
"rustpython-parser-core",
"tiny-keccak",
"unic-emoji-char",
"unic-ucd-ident",
"unicode_names2",
]
[[package]]
name = "rustpython-parser-core"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4b6c12fa273825edc7bccd9a734f0ad5ba4b8a2f4da5ff7efe946f066d0f4ad"
dependencies = [
"is-macro",
"memchr",
"rustpython-parser-vendored",
]
[[package]]
name = "rustpython-parser-vendored"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fcea49a4630a3a5d940f4d514dc4f575ed63c14c3e3ed07146634aed7f67a6"
dependencies = [
"memchr",
"once_cell",
]
[[package]]
name = "rustversion"
version = "1.0.22"
@@ -5388,12 +5171,6 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "stop-words"
version = "0.9.0"
@@ -5403,6 +5180,12 @@ dependencies = [
"serde_json",
]
[[package]]
name = "streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520"
[[package]]
name = "string_cache"
version = "0.8.9"
@@ -5739,15 +5522,6 @@ dependencies = [
"time-core",
]
[[package]]
name = "tiny-keccak"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
dependencies = [
"crunchy",
]
[[package]]
name = "tinystr"
version = "0.8.3"
@@ -6053,6 +5827,35 @@ dependencies = [
"tracing-log",
]
[[package]]
name = "tree-sitter"
version = "0.24.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5387dffa7ffc7d2dae12b50c6f7aab8ff79d6210147c6613561fc3d474c6f75"
dependencies = [
"cc",
"regex",
"regex-syntax",
"streaming-iterator",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-language"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782"
[[package]]
name = "tree-sitter-python"
version = "0.23.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d065aaa27f3aaceaf60c1f0e0ac09e1cb9eb8ed28e7bcdaa52129cffc7f4b04"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree_magic_mini"
version = "3.2.2"
@@ -6136,58 +5939,6 @@ dependencies = [
"syn",
]
[[package]]
name = "unic-char-property"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8c57a407d9b6fa02b4795eb81c5b6652060a15a7903ea981f3d723e6c0be221"
dependencies = [
"unic-char-range",
]
[[package]]
name = "unic-char-range"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0398022d5f700414f6b899e10b8348231abf9173fa93144cbc1a43b9793c1fbc"
[[package]]
name = "unic-common"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d7ff825a6a654ee85a63e80f92f054f904f21e7d12da4e22f9834a4aaa35bc"
[[package]]
name = "unic-emoji-char"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b07221e68897210270a38bde4babb655869637af0f69407f96053a34f76494d"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-ident"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e230a37c0381caa9219d67cf063aa3a375ffed5bf541a452db16e744bdab6987"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-version"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96bd2f2237fe450fcd0a1d2f5f4e91711124f7857ba2e964247776ebeeb7b0c4"
dependencies = [
"unic-common",
]
[[package]]
name = "unicase"
version = "2.9.0"
@@ -6224,28 +5975,6 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd"
dependencies = [
"phf",
"unicode_names2_generator",
]
[[package]]
name = "unicode_names2_generator"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e"
dependencies = [
"getopts",
"log",
"phf_codegen",
"rand 0.8.5",
]
[[package]]
name = "universal-hash"
version = "0.5.1"
+2 -2
View File
@@ -91,8 +91,8 @@ strum_macros = "0.27.2"
indoc = "2.0.6"
rmcp = { version = "0.16.0", features = ["client", "transport-child-process"] }
num_cpus = "1.17.0"
rustpython-parser = "0.4.0"
rustpython-ast = "0.4.0"
tree-sitter = "0.24"
tree-sitter-python = "0.23"
colored = "3.0.0"
clap_complete = { version = "4.5.58", features = ["unstable-dynamic"] }
gman = "0.3.0"
+681 -168
View File
@@ -1,13 +1,11 @@
use crate::function::{FunctionDeclaration, JsonSchema};
use anyhow::{Context, Result, bail};
use ast::{Stmt, StmtFunctionDef};
use anyhow::{Context, Result, anyhow, bail};
use indexmap::IndexMap;
use rustpython_ast::{Constant, Expr, UnaryOp};
use rustpython_parser::{Mode, ast};
use serde_json::Value;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use tree_sitter::{Node, Parser, Tree};
#[derive(Debug)]
struct Param {
@@ -28,46 +26,49 @@ pub fn generate_python_declarations(
tool_file
.read_to_string(&mut src)
.with_context(|| format!("Failed to load script at '{tool_file:?}'"))?;
let suite = parse_suite(&src, file_name)?;
let tree = parse_tree(&src, file_name)?;
let is_tool = parent
.and_then(|p| p.file_name())
.is_some_and(|n| n == "tools");
let mut declarations = python_to_function_declarations(file_name, &suite, is_tool)?;
if is_tool {
for d in &mut declarations {
d.agent = true;
}
}
Ok(declarations)
python_to_function_declarations(file_name, &src, &tree, is_tool)
}
fn parse_suite(src: &str, filename: &str) -> Result<ast::Suite> {
let mod_ast =
rustpython_parser::parse(src, Mode::Module, filename).context("failed to parse python")?;
fn parse_tree(src: &str, filename: &str) -> Result<Tree> {
let mut parser = Parser::new();
let language = tree_sitter_python::LANGUAGE.into();
parser
.set_language(&language)
.context("failed to initialize python tree-sitter parser")?;
let suite = match mod_ast {
ast::Mod::Module(m) => m.body,
ast::Mod::Interactive(m) => m.body,
ast::Mod::Expression(_) => bail!("expected a module; got a single expression"),
_ => bail!("unexpected parse mode/AST variant"),
};
let tree = parser
.parse(src.as_bytes(), None)
.ok_or_else(|| anyhow!("failed to parse python: {filename}"))?;
Ok(suite)
if tree.root_node().has_error() {
bail!("failed to parse python: syntax error in {filename}");
}
Ok(tree)
}
fn python_to_function_declarations(
file_name: &str,
module: &ast::Suite,
src: &str,
tree: &Tree,
is_tool: bool,
) -> Result<Vec<FunctionDeclaration>> {
let mut out = Vec::new();
let root = tree.root_node();
let mut cursor = root.walk();
for stmt in module {
if let Stmt::FunctionDef(fd) = stmt {
let func_name = fd.name.to_string();
for stmt in root.named_children(&mut cursor) {
let Some(fd) = unwrap_function_definition(stmt) else {
continue;
};
let func_name = function_name(fd, src)?.to_string();
if func_name.starts_with('_') && func_name != "_instructions" {
continue;
@@ -77,8 +78,8 @@ fn python_to_function_declarations(
continue;
}
let description = get_docstring_from_body(&fd.body).unwrap_or_default();
let params = collect_params(fd);
let description = get_docstring_from_function(fd, src).unwrap_or_default();
let params = collect_params(fd, src)?;
let schema = build_parameters_schema(&params, &description);
let name = if is_tool && func_name == "run" {
underscore(file_name)
@@ -97,109 +98,125 @@ fn python_to_function_declarations(
agent: !is_tool,
});
}
}
Ok(out)
}
fn get_docstring_from_body(body: &[Stmt]) -> Option<String> {
let first = body.first()?;
if let Stmt::Expr(expr_stmt) = first
&& let Expr::Constant(constant) = &*expr_stmt.value
&& let Constant::Str(s) = &constant.value
{
return Some(s.clone());
fn unwrap_function_definition(node: Node<'_>) -> Option<Node<'_>> {
match node.kind() {
"function_definition" => Some(node),
"decorated_definition" => {
let mut cursor = node.walk();
node.named_children(&mut cursor)
.find(|child| child.kind() == "function_definition")
}
_ => None,
}
None
}
fn collect_params(fd: &StmtFunctionDef) -> Vec<Param> {
/// Returns the source text of a function definition's `name` field.
///
/// # Errors
/// Fails when `node` has no `name` field, or when the name's byte range
/// cannot be read from `src` as UTF-8.
fn function_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> {
    match node.child_by_field_name("name") {
        Some(identifier) => node_text(identifier, src),
        None => Err(anyhow!("function_definition missing name")),
    }
}
fn get_docstring_from_function(node: Node<'_>, src: &str) -> Option<String> {
let body = node.child_by_field_name("body")?;
let mut cursor = body.walk();
let first = body.named_children(&mut cursor).next()?;
if first.kind() != "expression_statement" {
return None;
}
let mut expr_cursor = first.walk();
let expr = first.named_children(&mut expr_cursor).next()?;
if expr.kind() != "string" {
return None;
}
let text = node_text(expr, src).ok()?;
strip_string_quotes(text)
}
/// Extracts the value of a Python string literal from its source text.
///
/// `text` is the literal exactly as written in the source, including any
/// prefix letters and the surrounding quotes (single, double, or either
/// triple-quoted form).
///
/// Returns `None` when:
/// - no quote character is present,
/// - the characters before the first quote are not all ASCII letters,
/// - the prefix marks an f-string or a bytes literal (Python treats neither
///   as a docstring),
/// - the literal is not closed by the same quote sequence that opened it.
///
/// For non-raw literals, backslash escape sequences are decoded so the result
/// is the string's value rather than its spelling; raw literals (`r"..."`)
/// are returned verbatim.
fn strip_string_quotes(text: &str) -> Option<String> {
    let quote_offset = text.find(|ch: char| ch == '\'' || ch == '"')?;
    let (prefix, literal) = text.split_at(quote_offset);

    if !prefix.chars().all(|ch| ch.is_ascii_alphabetic()) {
        return None;
    }
    // f-strings are evaluated at runtime and bytes literals are not `str`,
    // so neither can serve as a docstring.
    if prefix
        .chars()
        .any(|ch| matches!(ch, 'f' | 'F' | 'b' | 'B'))
    {
        return None;
    }
    let is_raw = prefix.chars().any(|ch| matches!(ch, 'r' | 'R'));

    // Triple quotes must be tested before their single-character forms.
    let quote = ["\"\"\"", "'''", "\"", "'"]
        .iter()
        .copied()
        .find(|candidate| literal.starts_with(*candidate))?;

    if literal.len() < quote.len() * 2 || !literal.ends_with(quote) {
        return None;
    }
    let inner = &literal[quote.len()..literal.len() - quote.len()];

    Some(if is_raw {
        inner.to_string()
    } else {
        decode_python_escapes(inner)
    })
}

/// Decodes Python backslash escapes in the body of a non-raw string literal.
///
/// Handles the simple escapes, octal (`\ooo`), hex (`\xhh`), `\uXXXX`,
/// `\UXXXXXXXX`, and backslash-newline line continuations. Anything else
/// (including `\N{...}` and malformed numeric escapes) is kept verbatim with
/// its backslash, which is how Python treats unrecognised escapes.
fn decode_python_escapes(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '\\' {
            out.push(ch);
            continue;
        }
        let Some(escape) = chars.next() else {
            // A dangling backslash cannot occur in a well-formed literal; keep it.
            out.push('\\');
            break;
        };
        match escape {
            // Backslash-newline joins lines and contributes no characters.
            '\n' => {}
            '\r' => {
                if chars.peek() == Some(&'\n') {
                    chars.next();
                }
            }
            '\\' | '\'' | '"' => out.push(escape),
            'a' => out.push('\u{07}'),
            'b' => out.push('\u{08}'),
            'f' => out.push('\u{0C}'),
            'n' => out.push('\n'),
            'r' => out.push('\r'),
            't' => out.push('\t'),
            'v' => out.push('\u{0B}'),
            '0'..='7' => {
                // Up to three octal digits in total, the first already consumed.
                let mut value = escape.to_digit(8).unwrap_or(0);
                for _ in 0..2 {
                    match chars.peek().and_then(|next| next.to_digit(8)) {
                        Some(digit) => {
                            value = value * 8 + digit;
                            chars.next();
                        }
                        None => break,
                    }
                }
                out.push(char::from_u32(value).unwrap_or('\u{FFFD}'));
            }
            'x' | 'u' | 'U' => {
                let width = match escape {
                    'x' => 2,
                    'u' => 4,
                    _ => 8,
                };
                // Look ahead on a clone so a malformed escape consumes nothing.
                let digits: String = chars.clone().take(width).collect();
                let decoded = if digits.len() == width
                    && digits.chars().all(|digit| digit.is_ascii_hexdigit())
                {
                    u32::from_str_radix(&digits, 16)
                        .ok()
                        .and_then(char::from_u32)
                } else {
                    None
                };
                match decoded {
                    Some(value) => {
                        out.push(value);
                        for _ in 0..width {
                            chars.next();
                        }
                    }
                    None => {
                        out.push('\\');
                        out.push(escape);
                    }
                }
            }
            other => {
                out.push('\\');
                out.push(other);
            }
        }
    }

    out
}
fn collect_params(node: Node<'_>, src: &str) -> Result<Vec<Param>> {
let parameters = node
.child_by_field_name("parameters")
.ok_or_else(|| anyhow!("function_definition missing parameters"))?;
let mut out = Vec::new();
let mut cursor = parameters.walk();
for a in fd.args.posonlyargs.iter().chain(fd.args.args.iter()) {
let name = a.def.arg.to_string();
let mut ty = get_arg_type(a.def.annotation.as_deref());
let mut required = a.default.is_none();
if ty.ends_with('?') {
ty.pop();
required = false;
}
let default = if a.default.is_some() {
Some(Value::Null)
} else {
None
};
out.push(Param {
name,
ty_hint: ty,
required,
default,
doc_type: None,
doc_desc: None,
});
}
for a in &fd.args.kwonlyargs {
let name = a.def.arg.to_string();
let mut ty = get_arg_type(a.def.annotation.as_deref());
let mut required = a.default.is_none();
if ty.ends_with('?') {
ty.pop();
required = false;
}
let default = if a.default.is_some() {
Some(Value::Null)
} else {
None
};
out.push(Param {
name,
ty_hint: ty,
required,
default,
doc_type: None,
doc_desc: None,
});
}
if let Some(vararg) = &fd.args.vararg {
let name = vararg.arg.to_string();
let inner = get_arg_type(vararg.annotation.as_deref());
let ty = if inner.is_empty() {
"list[str]".into()
} else {
format!("list[{inner}]")
};
out.push(Param {
name,
ty_hint: ty,
required: false,
for param in parameters.named_children(&mut cursor) {
match param.kind() {
"identifier" => out.push(Param {
name: node_text(param, src)?.to_string(),
ty_hint: String::new(),
required: true,
default: None,
doc_type: None,
doc_desc: None,
});
}),
"typed_parameter" => out.push(build_param(
parameter_name(param, src)?,
get_arg_type(param.child_by_field_name("type"), src)?,
true,
None,
)),
"default_parameter" => out.push(build_param(
parameter_name(param, src)?,
String::new(),
false,
Some(Value::Null),
)),
"typed_default_parameter" => out.push(build_param(
parameter_name(param, src)?,
get_arg_type(param.child_by_field_name("type"), src)?,
false,
Some(Value::Null),
)),
"list_splat_pattern" | "dictionary_splat_pattern" | "positional_separator" => {
bail!(
"Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions"
)
}
"keyword_separator" => continue,
other => bail!("Unsupported parameter type: {other}"),
}
}
if let Some(kwarg) = &fd.args.kwarg {
let name = kwarg.arg.to_string();
out.push(Param {
name,
ty_hint: "object".into(),
required: false,
default: None,
doc_type: None,
doc_desc: None,
});
}
if let Some(doc) = get_docstring_from_body(&fd.body) {
if let Some(doc) = get_docstring_from_function(node, src) {
let meta = parse_docstring_args(&doc);
for p in &mut out {
if let Some((t, d)) = meta.get(&p.name) {
@@ -218,69 +235,155 @@ fn collect_params(fd: &StmtFunctionDef) -> Vec<Param> {
}
}
out
Ok(out)
}
fn get_arg_type(annotation: Option<&Expr>) -> String {
match annotation {
None => "".to_string(),
Some(Expr::Name(n)) => n.id.to_string(),
Some(Expr::Subscript(sub)) => match &*sub.value {
Expr::Name(name) if &name.id == "Optional" => {
let inner = get_arg_type(Some(&sub.slice));
format!("{inner}?")
fn build_param(name: &str, mut ty: String, mut required: bool, default: Option<Value>) -> Param {
if ty.ends_with('?') {
ty.pop();
required = false;
}
Expr::Name(name) if &name.id == "List" => {
let inner = get_arg_type(Some(&sub.slice));
format!("list[{inner}]")
}
Expr::Name(name) if &name.id == "Literal" => {
let vals = literal_members(&sub.slice);
format!("literal:{}", vals.join("|"))
}
_ => "any".to_string(),
},
_ => "any".to_string(),
Param {
name: name.to_string(),
ty_hint: ty,
required,
default,
doc_type: None,
doc_desc: None,
}
}
fn expr_to_str(e: &Expr) -> String {
match e {
Expr::Constant(c) => match &c.value {
Constant::Str(s) => s.clone(),
Constant::Int(i) => i.to_string(),
Constant::Float(f) => f.to_string(),
Constant::Bool(b) => b.to_string(),
Constant::None => "None".to_string(),
Constant::Ellipsis => "...".to_string(),
Constant::Bytes(b) => String::from_utf8_lossy(b).into_owned(),
Constant::Complex { real, imag } => format!("{real}+{imag}j"),
_ => "any".to_string(),
},
Expr::Name(n) => n.id.to_string(),
Expr::UnaryOp(u) => {
if matches!(u.op, UnaryOp::USub) {
let inner = expr_to_str(&u.operand);
if inner.parse::<f64>().is_ok() || inner.chars().all(|c| c.is_ascii_digit()) {
return format!("-{inner}");
}
}
"any".to_string()
fn parameter_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> {
if let Some(name) = node.child_by_field_name("name") {
return node_text(name, src);
}
Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect::<Vec<_>>().join(","),
let mut cursor = node.walk();
node.named_children(&mut cursor)
.find(|child| child.kind() == "identifier")
.ok_or_else(|| anyhow!("parameter missing name"))
.and_then(|name| node_text(name, src))
}
_ => "any".to_string(),
fn get_arg_type(annotation: Option<Node<'_>>, src: &str) -> Result<String> {
let Some(annotation) = annotation else {
return Ok(String::new());
};
match annotation.kind() {
"type" => get_arg_type(named_child(annotation, 0), src),
"generic_type" => {
let value = annotation
.child_by_field_name("type")
.or_else(|| named_child(annotation, 0))
.ok_or_else(|| anyhow!("generic_type missing value"))?;
let value_name = if value.kind() == "identifier" {
node_text(value, src)?
} else {
return Ok("any".to_string());
};
let inner = annotation
.child_by_field_name("type_parameter")
.or_else(|| annotation.child_by_field_name("parameters"))
.or_else(|| named_child(annotation, 1))
.ok_or_else(|| anyhow!("generic_type missing inner type"))?;
match value_name {
"Optional" => Ok(format!("{}?", generic_inner_type(inner, src)?)),
"List" => Ok(format!("list[{}]", generic_inner_type(inner, src)?)),
"Literal" => Ok(format!(
"literal:{}",
literal_members(inner, src)?.join("|")
)),
_ => Ok("any".to_string()),
}
}
"identifier" => Ok(node_text(annotation, src)?.to_string()),
"subscript" => {
let value = annotation
.child_by_field_name("value")
.or_else(|| named_child(annotation, 0))
.ok_or_else(|| anyhow!("subscript missing value"))?;
let value_name = if value.kind() == "identifier" {
node_text(value, src)?
} else {
return Ok("any".to_string());
};
let inner = annotation
.child_by_field_name("subscript")
.or_else(|| annotation.child_by_field_name("slice"))
.or_else(|| named_child(annotation, 1))
.ok_or_else(|| anyhow!("subscript missing inner type"))?;
match value_name {
"Optional" => Ok(format!("{}?", get_arg_type(Some(inner), src)?)),
"List" => Ok(format!("list[{}]", get_arg_type(Some(inner), src)?)),
"Literal" => Ok(format!(
"literal:{}",
literal_members(inner, src)?.join("|")
)),
_ => Ok("any".to_string()),
}
}
_ => Ok("any".to_string()),
}
}
fn literal_members(e: &Expr) -> Vec<String> {
match e {
Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect(),
_ => vec![expr_to_str(e)],
/// Resolves the type hint for the argument of a generic annotation.
///
/// A `type_parameter` wrapper is unwrapped to its first named child; any
/// other node is passed to `get_arg_type` unchanged.
fn generic_inner_type(node: Node<'_>, src: &str) -> Result<String> {
    let target = match node.kind() {
        "type_parameter" => named_child(node, 0),
        _ => Some(node),
    };
    get_arg_type(target, src)
}
/// Collects the members of a `Literal[...]` annotation as strings.
///
/// - A `type` wrapper is unwrapped to its first named child.
/// - A `tuple` or `type_parameter` yields one entry per named child, or the
///   single entry `"any"` when it has no named children.
/// - Any other node yields a single entry.
///
/// # Errors
/// Fails when a `type` node has no named child, or when a member's text
/// cannot be read from `src`.
fn literal_members(node: Node<'_>, src: &str) -> Result<Vec<String>> {
    match node.kind() {
        "type" => {
            let wrapped =
                named_child(node, 0).ok_or_else(|| anyhow!("type missing inner literal"))?;
            literal_members(wrapped, src)
        }
        "tuple" | "type_parameter" => {
            let mut cursor = node.walk();
            let mut members = Vec::new();
            for child in node.named_children(&mut cursor) {
                members.push(expr_to_str(child, src)?);
            }
            if members.is_empty() {
                members.push("any".to_string());
            }
            Ok(members)
        }
        _ => expr_to_str(node, src).map(|member| vec![member]),
    }
}
/// Renders a literal-like expression node as a string.
///
/// A `type` wrapper is unwrapped to its first named child. String, numeric,
/// boolean, `None`, identifier and unary-operator nodes are rendered as their
/// trimmed source text; every other node kind becomes `"any"`.
///
/// # Errors
/// Fails when a `type` node has no named child, or when the node's text
/// cannot be read from `src`.
fn expr_to_str(node: Node<'_>, src: &str) -> Result<String> {
    let kind = node.kind();

    if kind == "type" {
        let inner = named_child(node, 0).ok_or_else(|| anyhow!("type missing expression"))?;
        return expr_to_str(inner, src);
    }

    let is_textual = matches!(
        kind,
        "string"
            | "integer"
            | "float"
            | "true"
            | "false"
            | "none"
            | "identifier"
            | "unary_operator"
    );
    if !is_textual {
        return Ok("any".to_string());
    }

    let text = node_text(node, src)?;
    Ok(text.trim().to_string())
}
/// Returns the `index`-th named child of `node` (zero-based), or `None` when
/// the node has fewer named children than that.
fn named_child(node: Node<'_>, index: usize) -> Option<Node<'_>> {
    let mut walker = node.walk();
    let mut named = node.named_children(&mut walker);
    named.nth(index)
}
/// Returns the slice of `src` that `node` spans.
///
/// # Errors
/// Fails when the node's byte range is not valid UTF-8 within `src`.
fn node_text<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> {
    match node.utf8_text(src.as_bytes()) {
        Ok(text) => Ok(text),
        Err(err) => Err(anyhow!("invalid utf-8 in python source: {err}")),
    }
}
fn parse_docstring_args(doc: &str) -> IndexMap<String, (String, String)> {
@@ -417,3 +520,413 @@ fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) {
.into(),
);
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use std::time::{SystemTime, UNIX_EPOCH};
/// Test helper: writes `source` to a uniquely named file in the system temp
/// directory, runs `generate_python_declarations` on it, and returns that
/// result.
///
/// The temp file is removed afterwards on a best-effort basis. Panics if the
/// clock is before the Unix epoch or the temp file cannot be written/opened.
fn parse_source(
    source: &str,
    file_name: &str,
    parent: &Path,
) -> Result<Vec<FunctionDeclaration>> {
    // Nanosecond timestamp keeps concurrently running tests from sharing a file.
    let stamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("time went backwards")
        .as_nanos();
    let temp_name = format!("loki_python_parser_{file_name}_{stamp}.py");
    let temp_path = std::env::temp_dir().join(temp_name);

    fs::write(&temp_path, source).expect("failed to write temp python source");
    let handle = File::open(&temp_path).expect("failed to open temp python source");

    let declarations = generate_python_declarations(handle, file_name, Some(parent));

    let _ = fs::remove_file(&temp_path);
    declarations
}
/// Test helper: returns the schema's property map, panicking when the schema
/// has none.
fn properties(schema: &JsonSchema) -> &IndexMap<String, JsonSchema> {
    match schema.properties.as_ref() {
        Some(props) => props,
        None => panic!("missing schema properties"),
    }
}
/// Test helper: looks up the property called `name` in `schema`, panicking
/// when it is absent.
fn property<'a>(schema: &'a JsonSchema, name: &str) -> &'a JsonSchema {
    let Some(found) = properties(schema).get(name) else {
        panic!("missing property: {name}");
    };
    found
}
// Parses a demo tool script (placed under a `tools` parent directory) whose
// `run` function covers every supported annotation form: str, Literal, bool,
// int, float, List, Optional and Optional[List]. Checks the resulting
// declaration's name, description, required list and per-property schema.
#[test]
fn test_tool_demo_py() {
// Python fixture; its contents are input data and must stay exactly as written.
let source = r#"
import os
from typing import List, Literal, Optional
def run(
string: str,
string_enum: Literal["foo", "bar"],
boolean: bool,
integer: int,
number: float,
array: List[str],
string_optional: Optional[str] = None,
array_optional: Optional[List[str]] = None,
):
"""Demonstrates how to create a tool using Python and how to use comments.
Args:
string: Define a required string property
string_enum: Define a required string property with enum
boolean: Define a required boolean property
integer: Define a required integer property
number: Define a required number property
array: Define a required string array property
string_optional: Define an optional string property
array_optional: Define an optional string array property
"""
output = f"""string: {string}
string_enum: {string_enum}
string_optional: {string_optional}
boolean: {boolean}
integer: {integer}
number: {number}
array: {array}
array_optional: {array_optional}"""
for key, value in os.environ.items():
if key.startswith("LLM_"):
output = f"{output}\n{key}: {value}"
return output
"#;
let declarations = parse_source(source, "demo_py", Path::new("tools")).unwrap();
assert_eq!(declarations.len(), 1);
// For a script under `tools`, `run` is expected to be exposed under the
// file name and not flagged as an agent function.
let decl = &declarations[0];
assert_eq!(decl.name, "demo_py");
assert!(!decl.agent);
assert!(decl.description.starts_with("Demonstrates how to create"));
let params = &decl.parameters;
assert_eq!(params.type_value.as_deref(), Some("object"));
// Only the parameters without a default / Optional annotation are required,
// in declaration order.
assert_eq!(
params.required.as_ref().unwrap(),
&vec![
"string".to_string(),
"string_enum".to_string(),
"boolean".to_string(),
"integer".to_string(),
"number".to_string(),
"array".to_string(),
]
);
assert_eq!(
property(params, "string").type_value.as_deref(),
Some("string")
);
// Literal[...] becomes a string property with an enum of its members.
let string_enum = property(params, "string_enum");
assert_eq!(string_enum.type_value.as_deref(), Some("string"));
assert_eq!(
string_enum.enum_value.as_ref().unwrap(),
&vec!["foo".to_string(), "bar".to_string()]
);
assert_eq!(
property(params, "boolean").type_value.as_deref(),
Some("boolean")
);
assert_eq!(
property(params, "integer").type_value.as_deref(),
Some("integer")
);
assert_eq!(
property(params, "number").type_value.as_deref(),
Some("number")
);
// List[str] becomes an array whose items are strings.
let array = property(params, "array");
assert_eq!(array.type_value.as_deref(), Some("array"));
assert_eq!(
array.items.as_ref().unwrap().type_value.as_deref(),
Some("string")
);
// Optional[...] keeps the inner type but drops the parameter from `required`.
let string_optional = property(params, "string_optional");
assert_eq!(string_optional.type_value.as_deref(), Some("string"));
assert!(
!params
.required
.as_ref()
.unwrap()
.contains(&"string_optional".to_string())
);
let array_optional = property(params, "array_optional");
assert_eq!(array_optional.type_value.as_deref(), Some("array"));
assert_eq!(
array_optional.items.as_ref().unwrap().type_value.as_deref(),
Some("string")
);
assert!(
!params
.required
.as_ref()
.unwrap()
.contains(&"array_optional".to_string())
);
}
// Parses a weather tool script whose `run` function has a return annotation
// and a docstring using the `name (type): description` argument form. Checks
// that the description text is picked up for `location` and that the
// Optional parameter is not required.
#[test]
fn test_tool_weather() {
// Python fixture; its contents are input data and must stay exactly as written.
let source = r#"
import os
from pathlib import Path
from typing import Optional
from urllib.parse import quote_plus
from urllib.request import urlopen
def run(
location: str,
llm_output: Optional[str] = None,
) -> str:
"""Get the current weather in a given location
Args:
location (str): The city and optionally the state or country (e.g., "London", "San Francisco, CA").
Returns:
str: A single-line formatted weather string from wttr.in (``format=4`` with metric units).
"""
url = f"https://wttr.in/{quote_plus(location)}?format=4&M"
with urlopen(url, timeout=10) as resp:
weather = resp.read().decode("utf-8", errors="replace")
dest = llm_output if llm_output is not None else os.environ.get("LLM_OUTPUT", "/dev/stdout")
if dest not in {"-", "/dev/stdout"}:
path = Path(dest)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as fh:
fh.write(weather)
else:
pass
return weather
"#;
let declarations = parse_source(source, "get_current_weather", Path::new("tools")).unwrap();
assert_eq!(declarations.len(), 1);
// `run` is expected to be exposed under the file name for a `tools` script.
let decl = &declarations[0];
assert_eq!(decl.name, "get_current_weather");
assert!(!decl.agent);
assert!(
decl.description
.starts_with("Get the current weather in a given location")
);
let params = &decl.parameters;
assert_eq!(
params.required.as_ref().unwrap(),
&vec!["location".to_string()]
);
// The docstring's quoted examples must survive into the description unchanged.
let location = property(params, "location");
assert_eq!(location.type_value.as_deref(), Some("string"));
assert_eq!(
location.description.as_deref(),
Some(
"The city and optionally the state or country (e.g., \"London\", \"San Francisco, CA\")."
)
);
let llm_output = property(params, "llm_output");
assert_eq!(llm_output.type_value.as_deref(), Some("string"));
assert!(
!params
.required
.as_ref()
.unwrap()
.contains(&"llm_output".to_string())
);
}
/// Parses a script that defines `run(code: str)` next to a second top-level
/// function, and checks that exactly one declaration comes back with a single
/// documented `code` string parameter.
#[test]
fn test_tool_execute_py_code() {
    let source = r#"
import ast
import io
from contextlib import redirect_stdout
def run(code: str):
"""Execute the given Python code.
Args:
code: The Python code to execute, such as `print("hello world")`
"""
output = io.StringIO()
with redirect_stdout(output):
value = exec_with_return(code, {}, {})
if value is not None:
output.write(str(value))
return output.getvalue()
def exec_with_return(code: str, globals: dict, locals: dict):
a = ast.parse(code)
last_expression = None
if a.body:
if isinstance(a_last := a.body[-1], ast.Expr):
last_expression = ast.unparse(a.body.pop())
elif isinstance(a_last, ast.Assign):
last_expression = ast.unparse(a_last.targets[0])
elif isinstance(a_last, (ast.AnnAssign, ast.AugAssign)):
last_expression = ast.unparse(a_last.target)
exec(ast.unparse(a), globals, locals)
if last_expression:
return eval(last_expression, globals, locals)
"#;
    let parsed = parse_source(source, "execute_py_code", Path::new("tools")).unwrap();
    assert_eq!(parsed.len(), 1);

    let exec_tool = &parsed[0];
    assert_eq!(exec_tool.name, "execute_py_code");
    assert!(!exec_tool.agent);

    // `code` is the only property in the schema.
    let schema = &exec_tool.parameters;
    assert_eq!(properties(schema).len(), 1);

    let code = property(schema, "code");
    assert_eq!(code.type_value.as_deref(), Some("string"));
    let code_doc = "The Python code to execute, such as `print(\"hello world\")`";
    assert_eq!(code.description.as_deref(), Some(code_doc));
}
/// Parses a script with one argument-less public function and checks that the
/// declaration is named after the function, is flagged as an agent tool, has the
/// docstring text as its description and exposes no properties.
#[test]
fn test_agent_tools() {
    let source = r#"
import urllib.request
def get_ipinfo():
"""
Get the ip info
"""
with urllib.request.urlopen("https://httpbin.org/ip") as response:
data = response.read()
return data.decode('utf-8')
"#;
    let parsed = parse_source(source, "tools", Path::new("demo")).unwrap();
    assert_eq!(parsed.len(), 1);

    let ipinfo = &parsed[0];
    assert_eq!(ipinfo.name, "get_ipinfo");
    assert!(ipinfo.agent);
    assert_eq!(ipinfo.description, "Get the ip info");

    let schema = &ipinfo.parameters;
    assert!(properties(schema).is_empty());
}
/// A `run` function taking `*args` must be rejected with the unsupported-parameter error.
#[test]
fn test_reject_varargs() {
    let source = r#"
def run(*args):
"""Has docstring"""
return args
"#;
    let expected = "Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions";
    let message = parse_source(source, "reject_varargs", Path::new("tools"))
        .unwrap_err()
        .to_string();
    assert!(message.contains(expected));
}
/// A `run` function taking `**kwargs` must be rejected with the unsupported-parameter error.
#[test]
fn test_reject_kwargs() {
    let source = r#"
def run(**kwargs):
"""Has docstring"""
return kwargs
"#;
    let expected = "Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions";
    let message = parse_source(source, "reject_kwargs", Path::new("tools"))
        .unwrap_err()
        .to_string();
    assert!(message.contains(expected));
}
/// A `run` function using the positional-only marker (`/`) must be rejected with
/// the unsupported-parameter error.
#[test]
fn test_reject_positional_only() {
    let source = r#"
def run(x, /, y):
"""Has docstring"""
return x + y
"#;
    let expected = "Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions";
    let message = parse_source(source, "reject_positional_only", Path::new("tools"))
        .unwrap_err()
        .to_string();
    assert!(message.contains(expected));
}
/// A `run` function without a docstring must fail with the missing-description error.
#[test]
fn test_missing_docstring() {
    let source = r#"
def run(x: str):
pass
"#;
    let expected = "Missing or empty description on function: run";
    let message = parse_source(source, "missing_docstring", Path::new("tools"))
        .unwrap_err()
        .to_string();
    assert!(message.contains(expected));
}
/// Malformed Python input must surface an error mentioning the parse failure.
#[test]
fn test_syntax_error() {
    let broken = "def run(: broken";
    let message = parse_source(broken, "syntax_error", Path::new("tools"))
        .unwrap_err()
        .to_string();
    assert!(message.contains("failed to parse python"));
}
/// Of a `_private` and a `public` function, only `public` yields a declaration.
#[test]
fn test_underscore_functions_skipped() {
    let source = r#"
def _private():
"""Private"""
return None
def public():
"""Public"""
return None
"#;
    let parsed = parse_source(source, "tools", Path::new("demo")).unwrap();
    assert_eq!(parsed.len(), 1);

    let survivor = &parsed[0];
    assert_eq!(survivor.name, "public");
}
/// `_instructions` is kept despite its leading underscore: it yields a declaration
/// named `instructions`, carrying the docstring as description and the agent flag set.
#[test]
fn test_instructions_not_skipped() {
    let source = r#"
def _instructions():
"""Help text"""
return None
"#;
    let parsed = parse_source(source, "tools", Path::new("demo")).unwrap();
    assert_eq!(parsed.len(), 1);

    let entry = &parsed[0];
    assert_eq!(entry.name, "instructions");
    assert_eq!(entry.description, "Help text");
    assert!(entry.agent);
}
}