diff --git a/Cargo.lock b/Cargo.lock index e2741b3..162b10d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -859,7 +859,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash 2.1.2", + "rustc-hash", "shlex", "syn", ] @@ -1388,12 +1388,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "crunchy" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" - [[package]] name = "crypto-common" version = "0.1.7" @@ -1518,34 +1512,13 @@ dependencies = [ "syn", ] -[[package]] -name = "derive_more" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" -dependencies = [ - "derive_more-impl 1.0.0", -] - [[package]] name = "derive_more" version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" dependencies = [ - "derive_more-impl 2.1.1", -] - -[[package]] -name = "derive_more-impl" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "unicode-xid", + "derive_more-impl", ] [[package]] @@ -2105,15 +2078,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "getopts" -version = "0.2.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df" -dependencies = [ - "unicode-width", -] - [[package]] name = "getrandom" version = "0.2.17" @@ -2258,15 +2222,6 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", -] - [[package]] name = "hashbrown" version = "0.15.5" @@ -2822,18 +2777,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "is-macro" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d57a3e447e24c22647738e4607f1df1e0ec6f72e16182c4cd199f647cdfb0e4" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "is-terminal" version = "0.4.17" @@ -2870,15 +2813,6 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -2987,12 +2921,6 @@ dependencies = [ "simple_asn1", ] -[[package]] -name = "lalrpop-util" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553" - [[package]] name = "lazy_static" version = "1.5.0" @@ -3021,12 +2949,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "libm" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" - [[package]] name = "libredox" version = "0.1.15" @@ -3171,8 +3093,6 @@ dependencies = [ "reqwest-eventsource", "rmcp", "rust-embed", - "rustpython-ast", - "rustpython-parser", "scraper", "serde", "serde_json", @@ -3188,6 +3108,8 @@ dependencies = [ "tokio", "tokio-graceful", "tokio-stream", + "tree-sitter", + "tree-sitter-python", "unicode-segmentation", "unicode-width", "url", @@ -3233,64 +3155,6 @@ dependencies = [ "libc", ] -[[package]] -name = "malachite" -version = "0.4.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fbdf9cb251732db30a7200ebb6ae5d22fe8e11397364416617d2c2cf0c51cb5" -dependencies = [ - "malachite-base", - "malachite-nz", - "malachite-q", -] - -[[package]] -name = "malachite-base" -version = "0.4.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ea0ed76adf7defc1a92240b5c36d5368cfe9251640dcce5bd2d0b7c1fd87aeb" -dependencies = [ - "hashbrown 0.14.5", - "itertools 0.11.0", - "libm", - "ryu", -] - -[[package]] -name = "malachite-bigint" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d149aaa2965d70381709d9df4c7ee1fc0de1c614a4efc2ee356f5e43d68749f8" -dependencies = [ - "derive_more 1.0.0", - "malachite", - "num-integer", - "num-traits", - "paste", -] - -[[package]] -name = "malachite-nz" -version = "0.4.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34a79feebb2bc9aa7762047c8e5495269a367da6b5a90a99882a0aeeac1841f7" -dependencies = [ - "itertools 0.11.0", - "libm", - "malachite-base", -] - -[[package]] -name = "malachite-q" -version = "0.4.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f235d5747b1256b47620f5640c2a17a88c7569eebdf27cd9cb130e1a619191" -dependencies = [ - "itertools 0.11.0", - "malachite-base", - "malachite-nz", -] - [[package]] name = "markup5ever" version = "0.12.1" @@ -3945,12 +3809,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - [[package]] name = "pastey" version = "0.2.1" @@ -4293,7 +4151,7 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.2", + "rustc-hash", "rustls 0.23.37", "socket2 0.6.3", "thiserror 2.0.18", @@ -4313,7 +4171,7 @@ dependencies = [ "lru-slab", "rand 0.9.2", "ring", - "rustc-hash 2.1.2", + "rustc-hash", "rustls 0.23.37", "rustls-pki-types", "slab", @@ -4364,8 +4222,6 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", - "rand_chacha 0.3.1", "rand_core 0.6.4", ] @@ -4375,20 +4231,10 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha 0.9.0", + "rand_chacha", "rand_core 0.9.5", ] -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.4", -] - [[package]] name = "rand_chacha" version = "0.9.0" @@ -4732,12 +4578,6 @@ version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.2" @@ -4851,63 +4691,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "rustpython-ast" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cdaf8ee5c1473b993b398c174641d3aa9da847af36e8d5eb8291930b72f31a5" -dependencies = [ - "is-macro", - "malachite-bigint", - "rustpython-parser-core", - "static_assertions", -] - -[[package]] -name = "rustpython-parser" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "868f724daac0caf9bd36d38caf45819905193a901e8f1c983345a68e18fb2abb" -dependencies = [ - "anyhow", - "is-macro", - "itertools 0.11.0", - "lalrpop-util", - "log", - "malachite-bigint", - "num-traits", - "phf", - "phf_codegen", - "rustc-hash 1.1.0", - "rustpython-ast", - "rustpython-parser-core", - "tiny-keccak", - "unic-emoji-char", - "unic-ucd-ident", - "unicode_names2", -] - -[[package]] -name = "rustpython-parser-core" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4b6c12fa273825edc7bccd9a734f0ad5ba4b8a2f4da5ff7efe946f066d0f4ad" -dependencies = [ - "is-macro", - "memchr", - "rustpython-parser-vendored", -] - -[[package]] -name = "rustpython-parser-vendored" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04fcea49a4630a3a5d940f4d514dc4f575ed63c14c3e3ed07146634aed7f67a6" -dependencies = [ - "memchr", - "once_cell", -] - [[package]] name = "rustversion" version = "1.0.22" @@ -5388,12 +5171,6 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "stop-words" version = "0.9.0" @@ -5403,6 +5180,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + [[package]] name = "string_cache" version = "0.8.9" @@ -5739,15 +5522,6 @@ dependencies = [ "time-core", ] -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "tinystr" version = "0.8.3" @@ -6053,6 +5827,35 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tree-sitter" +version = "0.24.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5387dffa7ffc7d2dae12b50c6f7aab8ff79d6210147c6613561fc3d474c6f75" +dependencies = [ + "cc", + "regex", + "regex-syntax", + "streaming-iterator", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-language" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" + +[[package]] +name = "tree-sitter-python" +version = "0.23.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d065aaa27f3aaceaf60c1f0e0ac09e1cb9eb8ed28e7bcdaa52129cffc7f4b04" +dependencies = [ + "cc", + "tree-sitter-language", +] + [[package]] name = "tree_magic_mini" version = "3.2.2" @@ -6136,58 +5939,6 @@ dependencies = [ "syn", ] -[[package]] -name = "unic-char-property" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8c57a407d9b6fa02b4795eb81c5b6652060a15a7903ea981f3d723e6c0be221" -dependencies = [ - "unic-char-range", -] - -[[package]] -name = "unic-char-range" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0398022d5f700414f6b899e10b8348231abf9173fa93144cbc1a43b9793c1fbc" - -[[package]] -name = "unic-common" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d7ff825a6a654ee85a63e80f92f054f904f21e7d12da4e22f9834a4aaa35bc" - -[[package]] -name = "unic-emoji-char" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b07221e68897210270a38bde4babb655869637af0f69407f96053a34f76494d" -dependencies = [ - "unic-char-property", - "unic-char-range", - "unic-ucd-version", -] - -[[package]] -name = "unic-ucd-ident" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e230a37c0381caa9219d67cf063aa3a375ffed5bf541a452db16e744bdab6987" -dependencies = [ - "unic-char-property", - "unic-char-range", - "unic-ucd-version", -] - -[[package]] -name = "unic-ucd-version" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96bd2f2237fe450fcd0a1d2f5f4e91711124f7857ba2e964247776ebeeb7b0c4" -dependencies = [ - "unic-common", -] - [[package]] name = "unicase" version = "2.9.0" @@ -6224,28 +5975,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" -[[package]] -name = "unicode_names2" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd" -dependencies = [ - "phf", - "unicode_names2_generator", -] - -[[package]] -name = "unicode_names2_generator" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e" -dependencies = [ - "getopts", - "log", - "phf_codegen", - "rand 0.8.5", -] - [[package]] name = "universal-hash" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 837f646..8b2177b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,8 +91,8 @@ strum_macros = "0.27.2" indoc = "2.0.6" rmcp = { version = "0.16.0", features = ["client", "transport-child-process"] } num_cpus = "1.17.0" -rustpython-parser = "0.4.0" -rustpython-ast = "0.4.0" +tree-sitter = "0.24" +tree-sitter-python = "0.23" colored = "3.0.0" clap_complete = { version = "4.5.58", features = ["unstable-dynamic"] } gman = "0.3.0" diff --git a/src/parsers/python.rs b/src/parsers/python.rs index 147ae9b..ff411b9 100644 --- a/src/parsers/python.rs +++ b/src/parsers/python.rs @@ -1,13 +1,11 @@ use crate::function::{FunctionDeclaration, JsonSchema}; -use anyhow::{Context, Result, bail}; -use ast::{Stmt, StmtFunctionDef}; +use anyhow::{Context, Result, anyhow, bail}; use indexmap::IndexMap; -use rustpython_ast::{Constant, Expr, UnaryOp}; -use rustpython_parser::{Mode, ast}; use serde_json::Value; use std::fs::File; use std::io::Read; use std::path::Path; +use tree_sitter::{Node, Parser, Tree}; #[derive(Debug)] struct Param { @@ -28,178 +26,197 @@ pub fn generate_python_declarations( tool_file .read_to_string(&mut src) .with_context(|| format!("Failed to load script at '{tool_file:?}'"))?; - let suite = parse_suite(&src, file_name)?; + let tree = parse_tree(&src, file_name)?; let is_tool = parent .and_then(|p| p.file_name()) .is_some_and(|n| n == "tools"); - let mut declarations = python_to_function_declarations(file_name, &suite, is_tool)?; - if is_tool { - for d in &mut declarations { - d.agent = true; - } - } - - Ok(declarations) + python_to_function_declarations(file_name, &src, &tree, is_tool) } -fn parse_suite(src: &str, filename: &str) -> Result { - let mod_ast = - rustpython_parser::parse(src, Mode::Module, filename).context("failed to parse python")?; +fn parse_tree(src: &str, filename: &str) -> Result { + let mut parser = Parser::new(); + let language = tree_sitter_python::LANGUAGE.into(); + parser + .set_language(&language) + .context("failed to initialize python tree-sitter parser")?; - let suite = match mod_ast { - ast::Mod::Module(m) => m.body, - ast::Mod::Interactive(m) => m.body, - ast::Mod::Expression(_) => bail!("expected a module; got a single expression"), - _ => bail!("unexpected parse mode/AST variant"), - }; + let tree = parser + .parse(src.as_bytes(), None) + .ok_or_else(|| anyhow!("failed to parse python: {filename}"))?; - Ok(suite) + if tree.root_node().has_error() { + bail!("failed to parse python: syntax error in {filename}"); + } + + Ok(tree) } fn python_to_function_declarations( file_name: &str, - module: &ast::Suite, + src: &str, + tree: &Tree, is_tool: bool, ) -> Result> { let mut out = Vec::new(); + let root = tree.root_node(); + let mut cursor = root.walk(); - for stmt in module { - if let Stmt::FunctionDef(fd) = stmt { - let func_name = fd.name.to_string(); + for stmt in root.named_children(&mut cursor) { + let Some(fd) = unwrap_function_definition(stmt) else { + continue; + }; - if func_name.starts_with('_') && func_name != "_instructions" { - continue; - } + let func_name = function_name(fd, src)?.to_string(); - if is_tool && func_name != "run" { - continue; - } - - let description = get_docstring_from_body(&fd.body).unwrap_or_default(); - let params = collect_params(fd); - let schema = build_parameters_schema(¶ms, &description); - let name = if is_tool && func_name == "run" { - underscore(file_name) - } else { - underscore(&func_name) - }; - let desc_trim = description.trim().to_string(); - if desc_trim.is_empty() { - bail!("Missing or empty description on function: {func_name}"); - } - - out.push(FunctionDeclaration { - name, - description: desc_trim, - parameters: schema, - agent: !is_tool, - }); + if func_name.starts_with('_') && func_name != "_instructions" { + continue; } + + if is_tool && func_name != "run" { + continue; + } + + let description = get_docstring_from_function(fd, src).unwrap_or_default(); + let params = collect_params(fd, src)?; + let schema = build_parameters_schema(¶ms, &description); + let name = if is_tool && func_name == "run" { + underscore(file_name) + } else { + underscore(&func_name) + }; + let desc_trim = description.trim().to_string(); + if desc_trim.is_empty() { + bail!("Missing or empty description on function: {func_name}"); + } + + out.push(FunctionDeclaration { + name, + description: desc_trim, + parameters: schema, + agent: !is_tool, + }); } Ok(out) } -fn get_docstring_from_body(body: &[Stmt]) -> Option { - let first = body.first()?; - if let Stmt::Expr(expr_stmt) = first - && let Expr::Constant(constant) = &*expr_stmt.value - && let Constant::Str(s) = &constant.value - { - return Some(s.clone()); +fn unwrap_function_definition(node: Node<'_>) -> Option> { + match node.kind() { + "function_definition" => Some(node), + "decorated_definition" => { + let mut cursor = node.walk(); + node.named_children(&mut cursor) + .find(|child| child.kind() == "function_definition") + } + _ => None, } - None } -fn collect_params(fd: &StmtFunctionDef) -> Vec { +fn function_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> { + let name_node = node + .child_by_field_name("name") + .ok_or_else(|| anyhow!("function_definition missing name"))?; + node_text(name_node, src) +} + +fn get_docstring_from_function(node: Node<'_>, src: &str) -> Option { + let body = node.child_by_field_name("body")?; + let mut cursor = body.walk(); + let first = body.named_children(&mut cursor).next()?; + if first.kind() != "expression_statement" { + return None; + } + + let mut expr_cursor = first.walk(); + let expr = first.named_children(&mut expr_cursor).next()?; + if expr.kind() != "string" { + return None; + } + + let text = node_text(expr, src).ok()?; + strip_string_quotes(text) +} + +fn strip_string_quotes(text: &str) -> Option { + let quote_offset = text + .char_indices() + .find_map(|(idx, ch)| (ch == '\'' || ch == '"').then_some(idx))?; + let prefix = &text[..quote_offset]; + if !prefix.chars().all(|ch| ch.is_ascii_alphabetic()) { + return None; + } + if prefix.chars().any(|ch| ch == 'f' || ch == 'F') { + return None; + } + + let literal = &text[quote_offset..]; + let quote = if literal.starts_with("\"\"\"") { + "\"\"\"" + } else if literal.starts_with("'''") { + "'''" + } else if literal.starts_with('"') { + "\"" + } else if literal.starts_with('\'') { + "'" + } else { + return None; + }; + + if literal.len() < quote.len() * 2 || !literal.ends_with(quote) { + return None; + } + + Some(literal[quote.len()..literal.len() - quote.len()].to_string()) +} + +fn collect_params(node: Node<'_>, src: &str) -> Result> { + let parameters = node + .child_by_field_name("parameters") + .ok_or_else(|| anyhow!("function_definition missing parameters"))?; let mut out = Vec::new(); + let mut cursor = parameters.walk(); - for a in fd.args.posonlyargs.iter().chain(fd.args.args.iter()) { - let name = a.def.arg.to_string(); - let mut ty = get_arg_type(a.def.annotation.as_deref()); - let mut required = a.default.is_none(); - - if ty.ends_with('?') { - ty.pop(); - required = false; + for param in parameters.named_children(&mut cursor) { + match param.kind() { + "identifier" => out.push(Param { + name: node_text(param, src)?.to_string(), + ty_hint: String::new(), + required: true, + default: None, + doc_type: None, + doc_desc: None, + }), + "typed_parameter" => out.push(build_param( + parameter_name(param, src)?, + get_arg_type(param.child_by_field_name("type"), src)?, + true, + None, + )), + "default_parameter" => out.push(build_param( + parameter_name(param, src)?, + String::new(), + false, + Some(Value::Null), + )), + "typed_default_parameter" => out.push(build_param( + parameter_name(param, src)?, + get_arg_type(param.child_by_field_name("type"), src)?, + false, + Some(Value::Null), + )), + "list_splat_pattern" | "dictionary_splat_pattern" | "positional_separator" => { + bail!( + "Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions" + ) + } + "keyword_separator" => continue, + other => bail!("Unsupported parameter type: {other}"), } - - let default = if a.default.is_some() { - Some(Value::Null) - } else { - None - }; - - out.push(Param { - name, - ty_hint: ty, - required, - default, - doc_type: None, - doc_desc: None, - }); } - for a in &fd.args.kwonlyargs { - let name = a.def.arg.to_string(); - let mut ty = get_arg_type(a.def.annotation.as_deref()); - let mut required = a.default.is_none(); - - if ty.ends_with('?') { - ty.pop(); - required = false; - } - - let default = if a.default.is_some() { - Some(Value::Null) - } else { - None - }; - - out.push(Param { - name, - ty_hint: ty, - required, - default, - doc_type: None, - doc_desc: None, - }); - } - - if let Some(vararg) = &fd.args.vararg { - let name = vararg.arg.to_string(); - let inner = get_arg_type(vararg.annotation.as_deref()); - let ty = if inner.is_empty() { - "list[str]".into() - } else { - format!("list[{inner}]") - }; - - out.push(Param { - name, - ty_hint: ty, - required: false, - default: None, - doc_type: None, - doc_desc: None, - }); - } - - if let Some(kwarg) = &fd.args.kwarg { - let name = kwarg.arg.to_string(); - out.push(Param { - name, - ty_hint: "object".into(), - required: false, - default: None, - doc_type: None, - doc_desc: None, - }); - } - - if let Some(doc) = get_docstring_from_body(&fd.body) { + if let Some(doc) = get_docstring_from_function(node, src) { let meta = parse_docstring_args(&doc); for p in &mut out { if let Some((t, d)) = meta.get(&p.name) { @@ -218,69 +235,155 @@ fn collect_params(fd: &StmtFunctionDef) -> Vec { } } - out + Ok(out) } -fn get_arg_type(annotation: Option<&Expr>) -> String { - match annotation { - None => "".to_string(), - Some(Expr::Name(n)) => n.id.to_string(), - Some(Expr::Subscript(sub)) => match &*sub.value { - Expr::Name(name) if &name.id == "Optional" => { - let inner = get_arg_type(Some(&sub.slice)); - format!("{inner}?") - } - Expr::Name(name) if &name.id == "List" => { - let inner = get_arg_type(Some(&sub.slice)); - format!("list[{inner}]") - } - Expr::Name(name) if &name.id == "Literal" => { - let vals = literal_members(&sub.slice); - format!("literal:{}", vals.join("|")) - } - _ => "any".to_string(), - }, - _ => "any".to_string(), +fn build_param(name: &str, mut ty: String, mut required: bool, default: Option) -> Param { + if ty.ends_with('?') { + ty.pop(); + required = false; + } + + Param { + name: name.to_string(), + ty_hint: ty, + required, + default, + doc_type: None, + doc_desc: None, } } -fn expr_to_str(e: &Expr) -> String { - match e { - Expr::Constant(c) => match &c.value { - Constant::Str(s) => s.clone(), - Constant::Int(i) => i.to_string(), - Constant::Float(f) => f.to_string(), - Constant::Bool(b) => b.to_string(), - Constant::None => "None".to_string(), - Constant::Ellipsis => "...".to_string(), - Constant::Bytes(b) => String::from_utf8_lossy(b).into_owned(), - Constant::Complex { real, imag } => format!("{real}+{imag}j"), - _ => "any".to_string(), - }, +fn parameter_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> { + if let Some(name) = node.child_by_field_name("name") { + return node_text(name, src); + } - Expr::Name(n) => n.id.to_string(), + let mut cursor = node.walk(); + node.named_children(&mut cursor) + .find(|child| child.kind() == "identifier") + .ok_or_else(|| anyhow!("parameter missing name")) + .and_then(|name| node_text(name, src)) +} - Expr::UnaryOp(u) => { - if matches!(u.op, UnaryOp::USub) { - let inner = expr_to_str(&u.operand); - if inner.parse::().is_ok() || inner.chars().all(|c| c.is_ascii_digit()) { - return format!("-{inner}"); - } +fn get_arg_type(annotation: Option>, src: &str) -> Result { + let Some(annotation) = annotation else { + return Ok(String::new()); + }; + + match annotation.kind() { + "type" => get_arg_type(named_child(annotation, 0), src), + "generic_type" => { + let value = annotation + .child_by_field_name("type") + .or_else(|| named_child(annotation, 0)) + .ok_or_else(|| anyhow!("generic_type missing value"))?; + let value_name = if value.kind() == "identifier" { + node_text(value, src)? + } else { + return Ok("any".to_string()); + }; + + let inner = annotation + .child_by_field_name("type_parameter") + .or_else(|| annotation.child_by_field_name("parameters")) + .or_else(|| named_child(annotation, 1)) + .ok_or_else(|| anyhow!("generic_type missing inner type"))?; + + match value_name { + "Optional" => Ok(format!("{}?", generic_inner_type(inner, src)?)), + "List" => Ok(format!("list[{}]", generic_inner_type(inner, src)?)), + "Literal" => Ok(format!( + "literal:{}", + literal_members(inner, src)?.join("|") + )), + _ => Ok("any".to_string()), } - "any".to_string() } + "identifier" => Ok(node_text(annotation, src)?.to_string()), + "subscript" => { + let value = annotation + .child_by_field_name("value") + .or_else(|| named_child(annotation, 0)) + .ok_or_else(|| anyhow!("subscript missing value"))?; + let value_name = if value.kind() == "identifier" { + node_text(value, src)? + } else { + return Ok("any".to_string()); + }; - Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect::>().join(","), - - _ => "any".to_string(), + let inner = annotation + .child_by_field_name("subscript") + .or_else(|| annotation.child_by_field_name("slice")) + .or_else(|| named_child(annotation, 1)) + .ok_or_else(|| anyhow!("subscript missing inner type"))?; + match value_name { + "Optional" => Ok(format!("{}?", get_arg_type(Some(inner), src)?)), + "List" => Ok(format!("list[{}]", get_arg_type(Some(inner), src)?)), + "Literal" => Ok(format!( + "literal:{}", + literal_members(inner, src)?.join("|") + )), + _ => Ok("any".to_string()), + } + } + _ => Ok("any".to_string()), } } -fn literal_members(e: &Expr) -> Vec { - match e { - Expr::Tuple(t) => t.elts.iter().map(expr_to_str).collect(), - _ => vec![expr_to_str(e)], +fn generic_inner_type(node: Node<'_>, src: &str) -> Result { + if node.kind() == "type_parameter" { + return get_arg_type(named_child(node, 0), src); } + + get_arg_type(Some(node), src) +} + +fn literal_members(node: Node<'_>, src: &str) -> Result> { + if node.kind() == "type" { + return literal_members( + named_child(node, 0).ok_or_else(|| anyhow!("type missing inner literal"))?, + src, + ); + } + + if node.kind() == "tuple" || node.kind() == "type_parameter" { + let mut cursor = node.walk(); + let members = node + .named_children(&mut cursor) + .map(|child| expr_to_str(child, src)) + .collect::>>()?; + + return Ok(if members.is_empty() { + vec!["any".to_string()] + } else { + members + }); + } + + Ok(vec![expr_to_str(node, src)?]) +} + +fn expr_to_str(node: Node<'_>, src: &str) -> Result { + match node.kind() { + "type" => expr_to_str( + named_child(node, 0).ok_or_else(|| anyhow!("type missing expression"))?, + src, + ), + "string" | "integer" | "float" | "true" | "false" | "none" | "identifier" + | "unary_operator" => Ok(node_text(node, src)?.trim().to_string()), + _ => Ok("any".to_string()), + } +} + +fn named_child(node: Node<'_>, index: usize) -> Option> { + let mut cursor = node.walk(); + node.named_children(&mut cursor).nth(index) +} + +fn node_text<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> { + node.utf8_text(src.as_bytes()) + .map_err(|err| anyhow!("invalid utf-8 in python source: {err}")) } fn parse_docstring_args(doc: &str) -> IndexMap { @@ -417,3 +520,413 @@ fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) { .into(), ); } + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn parse_source( + source: &str, + file_name: &str, + parent: &Path, + ) -> Result> { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time went backwards") + .as_nanos(); + let path = std::env::temp_dir().join(format!("loki_python_parser_{file_name}_{unique}.py")); + fs::write(&path, source).expect("failed to write temp python source"); + let file = File::open(&path).expect("failed to open temp python source"); + let result = generate_python_declarations(file, file_name, Some(parent)); + let _ = fs::remove_file(&path); + result + } + + fn properties(schema: &JsonSchema) -> &IndexMap { + schema + .properties + .as_ref() + .expect("missing schema properties") + } + + fn property<'a>(schema: &'a JsonSchema, name: &str) -> &'a JsonSchema { + properties(schema) + .get(name) + .unwrap_or_else(|| panic!("missing property: {name}")) + } + + #[test] + fn test_tool_demo_py() { + let source = r#" +import os +from typing import List, Literal, Optional + +def run( + string: str, + string_enum: Literal["foo", "bar"], + boolean: bool, + integer: int, + number: float, + array: List[str], + string_optional: Optional[str] = None, + array_optional: Optional[List[str]] = None, +): + """Demonstrates how to create a tool using Python and how to use comments. + Args: + string: Define a required string property + string_enum: Define a required string property with enum + boolean: Define a required boolean property + integer: Define a required integer property + number: Define a required number property + array: Define a required string array property + string_optional: Define an optional string property + array_optional: Define an optional string array property + """ + output = f"""string: {string} +string_enum: {string_enum} +string_optional: {string_optional} +boolean: {boolean} +integer: {integer} +number: {number} +array: {array} +array_optional: {array_optional}""" + + for key, value in os.environ.items(): + if key.startswith("LLM_"): + output = f"{output}\n{key}: {value}" + + return output +"#; + + let declarations = parse_source(source, "demo_py", Path::new("tools")).unwrap(); + assert_eq!(declarations.len(), 1); + + let decl = &declarations[0]; + assert_eq!(decl.name, "demo_py"); + assert!(!decl.agent); + assert!(decl.description.starts_with("Demonstrates how to create")); + + let params = &decl.parameters; + assert_eq!(params.type_value.as_deref(), Some("object")); + assert_eq!( + params.required.as_ref().unwrap(), + &vec![ + "string".to_string(), + "string_enum".to_string(), + "boolean".to_string(), + "integer".to_string(), + "number".to_string(), + "array".to_string(), + ] + ); + + assert_eq!( + property(params, "string").type_value.as_deref(), + Some("string") + ); + + let string_enum = property(params, "string_enum"); + assert_eq!(string_enum.type_value.as_deref(), Some("string")); + assert_eq!( + string_enum.enum_value.as_ref().unwrap(), + &vec!["foo".to_string(), "bar".to_string()] + ); + + assert_eq!( + property(params, "boolean").type_value.as_deref(), + Some("boolean") + ); + assert_eq!( + property(params, "integer").type_value.as_deref(), + Some("integer") + ); + assert_eq!( + property(params, "number").type_value.as_deref(), + Some("number") + ); + + let array = property(params, "array"); + assert_eq!(array.type_value.as_deref(), Some("array")); + assert_eq!( + array.items.as_ref().unwrap().type_value.as_deref(), + Some("string") + ); + + let string_optional = property(params, "string_optional"); + assert_eq!(string_optional.type_value.as_deref(), Some("string")); + assert!( + !params + .required + .as_ref() + .unwrap() + .contains(&"string_optional".to_string()) + ); + + let array_optional = property(params, "array_optional"); + assert_eq!(array_optional.type_value.as_deref(), Some("array")); + assert_eq!( + array_optional.items.as_ref().unwrap().type_value.as_deref(), + Some("string") + ); + assert!( + !params + .required + .as_ref() + .unwrap() + .contains(&"array_optional".to_string()) + ); + } + + #[test] + fn test_tool_weather() { + let source = r#" +import os +from pathlib import Path +from typing import Optional +from urllib.parse import quote_plus +from urllib.request import urlopen + + +def run( + location: str, + llm_output: Optional[str] = None, +) -> str: + """Get the current weather in a given location + + Args: + location (str): The city and optionally the state or country (e.g., "London", "San Francisco, CA"). + + Returns: + str: A single-line formatted weather string from wttr.in (``format=4`` with metric units). + """ + url = f"https://wttr.in/{quote_plus(location)}?format=4&M" + + with urlopen(url, timeout=10) as resp: + weather = resp.read().decode("utf-8", errors="replace") + + dest = llm_output if llm_output is not None else os.environ.get("LLM_OUTPUT", "/dev/stdout") + + if dest not in {"-", "/dev/stdout"}: + path = Path(dest) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as fh: + fh.write(weather) + else: + pass + + return weather +"#; + + let declarations = parse_source(source, "get_current_weather", Path::new("tools")).unwrap(); + assert_eq!(declarations.len(), 1); + + let decl = &declarations[0]; + assert_eq!(decl.name, "get_current_weather"); + assert!(!decl.agent); + assert!( + decl.description + .starts_with("Get the current weather in a given location") + ); + + let params = &decl.parameters; + assert_eq!( + params.required.as_ref().unwrap(), + &vec!["location".to_string()] + ); + + let location = property(params, "location"); + assert_eq!(location.type_value.as_deref(), Some("string")); + assert_eq!( + location.description.as_deref(), + Some( + "The city and optionally the state or country (e.g., \"London\", \"San Francisco, CA\")." + ) + ); + + let llm_output = property(params, "llm_output"); + assert_eq!(llm_output.type_value.as_deref(), Some("string")); + assert!( + !params + .required + .as_ref() + .unwrap() + .contains(&"llm_output".to_string()) + ); + } + + #[test] + fn test_tool_execute_py_code() { + let source = r#" +import ast +import io +from contextlib import redirect_stdout + + +def run(code: str): + """Execute the given Python code. + Args: + code: The Python code to execute, such as `print("hello world")` + """ + output = io.StringIO() + with redirect_stdout(output): + value = exec_with_return(code, {}, {}) + + if value is not None: + output.write(str(value)) + + return output.getvalue() + + +def exec_with_return(code: str, globals: dict, locals: dict): + a = ast.parse(code) + last_expression = None + if a.body: + if isinstance(a_last := a.body[-1], ast.Expr): + last_expression = ast.unparse(a.body.pop()) + elif isinstance(a_last, ast.Assign): + last_expression = ast.unparse(a_last.targets[0]) + elif isinstance(a_last, (ast.AnnAssign, ast.AugAssign)): + last_expression = ast.unparse(a_last.target) + exec(ast.unparse(a), globals, locals) + if last_expression: + return eval(last_expression, globals, locals) +"#; + + let declarations = parse_source(source, "execute_py_code", Path::new("tools")).unwrap(); + assert_eq!(declarations.len(), 1); + + let decl = &declarations[0]; + assert_eq!(decl.name, "execute_py_code"); + assert!(!decl.agent); + + let params = &decl.parameters; + assert_eq!(properties(params).len(), 1); + let code = property(params, "code"); + assert_eq!(code.type_value.as_deref(), Some("string")); + assert_eq!( + code.description.as_deref(), + Some("The Python code to execute, such as `print(\"hello world\")`") + ); + } + + #[test] + fn test_agent_tools() { + let source = r#" +import urllib.request + +def get_ipinfo(): + """ + Get the ip info + """ + with urllib.request.urlopen("https://httpbin.org/ip") as response: + data = response.read() + return data.decode('utf-8') +"#; + + let declarations = parse_source(source, "tools", Path::new("demo")).unwrap(); + assert_eq!(declarations.len(), 1); + + let decl = &declarations[0]; + assert_eq!(decl.name, "get_ipinfo"); + assert!(decl.agent); + assert_eq!(decl.description, "Get the ip info"); + assert!(properties(&decl.parameters).is_empty()); + } + + #[test] + fn test_reject_varargs() { + let source = r#" +def run(*args): + """Has docstring""" + return args +"#; + + let err = parse_source(source, "reject_varargs", Path::new("tools")).unwrap_err(); + assert!(err + .to_string() + .contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions")); + } + + #[test] + fn test_reject_kwargs() { + let source = r#" +def run(**kwargs): + """Has docstring""" + return kwargs +"#; + + let err = parse_source(source, "reject_kwargs", Path::new("tools")).unwrap_err(); + assert!(err + .to_string() + .contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions")); + } + + #[test] + fn test_reject_positional_only() { + let source = r#" +def run(x, /, y): + """Has docstring""" + return x + y +"#; + + let err = parse_source(source, "reject_positional_only", Path::new("tools")).unwrap_err(); + assert!(err + .to_string() + .contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions")); + } + + #[test] + fn test_missing_docstring() { + let source = r#" +def run(x: str): + pass +"#; + + let err = parse_source(source, "missing_docstring", Path::new("tools")).unwrap_err(); + assert!( + err.to_string() + .contains("Missing or empty description on function: run") + ); + } + + #[test] + fn test_syntax_error() { + let source = "def run(: broken"; + let err = parse_source(source, "syntax_error", Path::new("tools")).unwrap_err(); + assert!(err.to_string().contains("failed to parse python")); + } + + #[test] + fn test_underscore_functions_skipped() { + let source = r#" +def _private(): + """Private""" + return None + +def public(): + """Public""" + return None +"#; + + let declarations = parse_source(source, "tools", Path::new("demo")).unwrap(); + assert_eq!(declarations.len(), 1); + assert_eq!(declarations[0].name, "public"); + } + + #[test] + fn test_instructions_not_skipped() { + let source = r#" +def _instructions(): + """Help text""" + return None +"#; + + let declarations = parse_source(source, "tools", Path::new("demo")).unwrap(); + assert_eq!(declarations.len(), 1); + assert_eq!(declarations[0].name, "instructions"); + assert_eq!(declarations[0].description, "Help text"); + assert!(declarations[0].agent); + } +}