From f865892c2831d08ccddc5c42182550432158b757 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Thu, 9 Apr 2026 13:16:35 -0600 Subject: [PATCH] refactor: Extracted common Python parser logic into a common.rs module --- src/parsers/common.rs | 269 +++++++++++++++++++++++++ src/parsers/mod.rs | 1 + src/parsers/python.rs | 443 +++++++++++++----------------------------- 3 files changed, 410 insertions(+), 303 deletions(-) create mode 100644 src/parsers/common.rs diff --git a/src/parsers/common.rs b/src/parsers/common.rs new file mode 100644 index 0000000..b0d80d6 --- /dev/null +++ b/src/parsers/common.rs @@ -0,0 +1,269 @@ +use crate::function::{FunctionDeclaration, JsonSchema}; +use anyhow::{Context, Result, anyhow, bail}; +use indexmap::IndexMap; +use serde_json::Value; +use tree_sitter::Node; + +#[derive(Debug)] +pub(crate) struct Param { + pub name: String, + pub ty_hint: String, + pub required: bool, + pub default: Option, + pub doc_type: Option, + pub doc_desc: Option, +} + +pub(crate) trait ScriptedLanguage { + fn ts_language(&self) -> tree_sitter::Language; + + fn default_runtime(&self) -> &str; + + fn lang_name(&self) -> &str; + + fn find_functions<'a>( + &self, + root: Node<'a>, + src: &str, + ) -> Vec<(Node<'a>, Node<'a>)>; + + fn function_name<'a>(&self, func_node: Node<'a>, src: &'a str) -> Result<&'a str>; + + fn extract_description( + &self, + wrapper_node: Node<'_>, + func_node: Node<'_>, + src: &str, + ) -> Option; + + fn extract_params( + &self, + func_node: Node<'_>, + src: &str, + description: &str, + ) -> Result>; +} + +pub(crate) fn build_param( + name: &str, + mut ty: String, + mut required: bool, + default: Option, +) -> Param { + if ty.ends_with('?') { + ty.pop(); + required = false; + } + + Param { + name: name.to_string(), + ty_hint: ty, + required, + default, + doc_type: None, + doc_desc: None, + } +} + +pub(crate) fn build_parameters_schema(params: &[Param], _description: &str) -> JsonSchema { + let mut props: IndexMap = IndexMap::new(); + let mut req: Vec = Vec::new(); + + for p in params { + let name = p.name.replace('-', "_"); + let mut schema = JsonSchema::default(); + + let ty = if !p.ty_hint.is_empty() { + p.ty_hint.as_str() + } else if let Some(t) = &p.doc_type { + t.as_str() + } else { + "str" + }; + + if let Some(d) = &p.doc_desc + && !d.is_empty() + { + schema.description = Some(d.clone()); + } + + apply_type_to_schema(ty, &mut schema); + + if p.default.is_none() && p.required { + req.push(name.clone()); + } + + props.insert(name, schema); + } + + JsonSchema { + type_value: Some("object".into()), + description: None, + properties: Some(props), + items: None, + any_of: None, + enum_value: None, + default: None, + required: if req.is_empty() { None } else { Some(req) }, + } +} + +pub(crate) fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) { + let t = ty.trim_end_matches('?'); + if let Some(rest) = t.strip_prefix("list[") { + s.type_value = Some("array".into()); + let inner = rest.trim_end_matches(']'); + let mut item = JsonSchema::default(); + + apply_type_to_schema(inner, &mut item); + + if item.type_value.is_none() { + item.type_value = Some("string".into()); + } + s.items = Some(Box::new(item)); + + return; + } + + if let Some(rest) = t.strip_prefix("literal:") { + s.type_value = Some("string".into()); + let vals = rest + .split('|') + .map(|x| x.trim().trim_matches('"').trim_matches('\'').to_string()) + .collect::>(); + if !vals.is_empty() { + s.enum_value = Some(vals); + } + return; + } + + s.type_value = Some( + match t { + "bool" => "boolean", + "int" => "integer", + "float" => "number", + "str" | "any" | "" => "string", + _ => "string", + } + .into(), + ); +} + +pub(crate) fn underscore(s: &str) -> String { + s.chars() + .map(|c| { + if c.is_ascii_alphanumeric() { + c.to_ascii_lowercase() + } else { + '_' + } + }) + .collect::() + .split('_') + .filter(|t| !t.is_empty()) + .collect::>() + .join("_") +} + +pub(crate) fn node_text<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> { + node.utf8_text(src.as_bytes()) + .map_err(|err| anyhow!("invalid utf-8 in source: {err}")) +} + +pub(crate) fn named_child(node: Node<'_>, index: usize) -> Option> { + let mut cursor = node.walk(); + node.named_children(&mut cursor).nth(index) +} + +pub(crate) fn extract_runtime(tree: &tree_sitter::Tree, src: &str, default: &str) -> String { + let root = tree.root_node(); + let mut cursor = root.walk(); + for child in root.named_children(&mut cursor) { + let text = match child.kind() { + "hash_bang_line" | "comment" => match child.utf8_text(src.as_bytes()) { + Ok(t) => t, + Err(_) => continue, + }, + _ => break, + }; + + if let Some(cmd) = text.strip_prefix("#!") { + let cmd = cmd.trim(); + if let Some(after_env) = cmd.strip_prefix("/usr/bin/env ") { + return after_env.trim().to_string(); + } + return cmd.to_string(); + } + + break; + } + default.to_string() +} + +pub(crate) fn generate_declarations( + lang: &L, + src: &str, + file_name: &str, + is_tool: bool, +) -> Result> { + let mut parser = tree_sitter::Parser::new(); + let language = lang.ts_language(); + parser.set_language(&language).with_context(|| { + format!( + "failed to initialize {} tree-sitter parser", + lang.lang_name() + ) + })?; + + let tree = parser + .parse(src.as_bytes(), None) + .ok_or_else(|| anyhow!("failed to parse {}: {file_name}", lang.lang_name()))?; + + if tree.root_node().has_error() { + bail!( + "failed to parse {}: syntax error in {file_name}", + lang.lang_name() + ); + } + + let _runtime = extract_runtime(&tree, src, lang.default_runtime()); + + let mut out = Vec::new(); + for (wrapper, func) in lang.find_functions(tree.root_node(), src) { + let func_name = lang.function_name(func, src)?; + + if func_name.starts_with('_') && func_name != "_instructions" { + continue; + } + if is_tool && func_name != "run" { + continue; + } + + let description = lang + .extract_description(wrapper, func, src) + .unwrap_or_default(); + let params = lang + .extract_params(func, src, &description) + .with_context(|| format!("in function '{func_name}' in {file_name}"))?; + let schema = build_parameters_schema(¶ms, &description); + + let name = if is_tool && func_name == "run" { + underscore(file_name) + } else { + underscore(func_name) + }; + + let desc_trim = description.trim().to_string(); + if desc_trim.is_empty() { + bail!("Missing or empty description on function: {func_name}"); + } + + out.push(FunctionDeclaration { + name, + description: desc_trim, + parameters: schema, + agent: !is_tool, + }); + } + Ok(out) +} diff --git a/src/parsers/mod.rs b/src/parsers/mod.rs index 57ae15a..5b872a9 100644 --- a/src/parsers/mod.rs +++ b/src/parsers/mod.rs @@ -1,2 +1,3 @@ pub(crate) mod bash; +pub(crate) mod common; pub(crate) mod python; diff --git a/src/parsers/python.rs b/src/parsers/python.rs index ff411b9..d8eb169 100644 --- a/src/parsers/python.rs +++ b/src/parsers/python.rs @@ -1,20 +1,124 @@ -use crate::function::{FunctionDeclaration, JsonSchema}; +use crate::function::FunctionDeclaration; +use crate::parsers::common::{self, Param, ScriptedLanguage}; use anyhow::{Context, Result, anyhow, bail}; use indexmap::IndexMap; use serde_json::Value; use std::fs::File; use std::io::Read; use std::path::Path; -use tree_sitter::{Node, Parser, Tree}; +use tree_sitter::Node; -#[derive(Debug)] -struct Param { - name: String, - ty_hint: String, - required: bool, - default: Option, - doc_type: Option, - doc_desc: Option, +pub(crate) struct PythonLanguage; + +impl ScriptedLanguage for PythonLanguage { + fn ts_language(&self) -> tree_sitter::Language { + tree_sitter_python::LANGUAGE.into() + } + + fn default_runtime(&self) -> &str { + "python" + } + + fn lang_name(&self) -> &str { + "python" + } + + fn find_functions<'a>(&self, root: Node<'a>, _src: &str) -> Vec<(Node<'a>, Node<'a>)> { + let mut cursor = root.walk(); + root.named_children(&mut cursor) + .filter_map(|stmt| unwrap_function_definition(stmt).map(|fd| (stmt, fd))) + .collect() + } + + fn function_name<'a>(&self, func_node: Node<'a>, src: &'a str) -> Result<&'a str> { + let name_node = func_node + .child_by_field_name("name") + .ok_or_else(|| anyhow!("function_definition missing name"))?; + common::node_text(name_node, src) + } + + fn extract_description( + &self, + _wrapper_node: Node<'_>, + func_node: Node<'_>, + src: &str, + ) -> Option { + get_docstring_from_function(func_node, src) + } + + fn extract_params( + &self, + func_node: Node<'_>, + src: &str, + description: &str, + ) -> Result> { + let parameters = func_node + .child_by_field_name("parameters") + .ok_or_else(|| anyhow!("function_definition missing parameters"))?; + let mut out = Vec::new(); + let mut cursor = parameters.walk(); + + for param in parameters.named_children(&mut cursor) { + match param.kind() { + "identifier" => out.push(Param { + name: common::node_text(param, src)?.to_string(), + ty_hint: String::new(), + required: true, + default: None, + doc_type: None, + doc_desc: None, + }), + "typed_parameter" => out.push(common::build_param( + parameter_name(param, src)?, + get_arg_type(param.child_by_field_name("type"), src)?, + true, + None, + )), + "default_parameter" => out.push(common::build_param( + parameter_name(param, src)?, + String::new(), + false, + Some(Value::Null), + )), + "typed_default_parameter" => out.push(common::build_param( + parameter_name(param, src)?, + get_arg_type(param.child_by_field_name("type"), src)?, + false, + Some(Value::Null), + )), + "list_splat_pattern" | "dictionary_splat_pattern" | "positional_separator" => { + let line = param.start_position().row + 1; + bail!( + "line {line}: *args/*kwargs/positional-only parameters are not supported in tool functions" + ) + } + "keyword_separator" => continue, + other => { + let line = param.start_position().row + 1; + bail!("line {line}: unsupported parameter type: {other}") + } + } + } + + let meta = parse_docstring_args(description); + for p in &mut out { + if let Some((t, d)) = meta.get(&p.name) { + if !t.is_empty() { + p.doc_type = Some(t.clone()); + } + + if !d.is_empty() { + p.doc_desc = Some(d.clone()); + } + + if t.ends_with('?') { + p.required = false; + } + } + } + + Ok(out) + } } pub fn generate_python_declarations( @@ -26,80 +130,12 @@ pub fn generate_python_declarations( tool_file .read_to_string(&mut src) .with_context(|| format!("Failed to load script at '{tool_file:?}'"))?; - let tree = parse_tree(&src, file_name)?; let is_tool = parent .and_then(|p| p.file_name()) .is_some_and(|n| n == "tools"); - python_to_function_declarations(file_name, &src, &tree, is_tool) -} - -fn parse_tree(src: &str, filename: &str) -> Result { - let mut parser = Parser::new(); - let language = tree_sitter_python::LANGUAGE.into(); - parser - .set_language(&language) - .context("failed to initialize python tree-sitter parser")?; - - let tree = parser - .parse(src.as_bytes(), None) - .ok_or_else(|| anyhow!("failed to parse python: {filename}"))?; - - if tree.root_node().has_error() { - bail!("failed to parse python: syntax error in {filename}"); - } - - Ok(tree) -} - -fn python_to_function_declarations( - file_name: &str, - src: &str, - tree: &Tree, - is_tool: bool, -) -> Result> { - let mut out = Vec::new(); - let root = tree.root_node(); - let mut cursor = root.walk(); - - for stmt in root.named_children(&mut cursor) { - let Some(fd) = unwrap_function_definition(stmt) else { - continue; - }; - - let func_name = function_name(fd, src)?.to_string(); - - if func_name.starts_with('_') && func_name != "_instructions" { - continue; - } - - if is_tool && func_name != "run" { - continue; - } - - let description = get_docstring_from_function(fd, src).unwrap_or_default(); - let params = collect_params(fd, src)?; - let schema = build_parameters_schema(¶ms, &description); - let name = if is_tool && func_name == "run" { - underscore(file_name) - } else { - underscore(&func_name) - }; - let desc_trim = description.trim().to_string(); - if desc_trim.is_empty() { - bail!("Missing or empty description on function: {func_name}"); - } - - out.push(FunctionDeclaration { - name, - description: desc_trim, - parameters: schema, - agent: !is_tool, - }); - } - - Ok(out) + common::generate_declarations(&PythonLanguage, &src, file_name, is_tool) } fn unwrap_function_definition(node: Node<'_>) -> Option> { @@ -114,13 +150,6 @@ fn unwrap_function_definition(node: Node<'_>) -> Option> { } } -fn function_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> { - let name_node = node - .child_by_field_name("name") - .ok_or_else(|| anyhow!("function_definition missing name"))?; - node_text(name_node, src) -} - fn get_docstring_from_function(node: Node<'_>, src: &str) -> Option { let body = node.child_by_field_name("body")?; let mut cursor = body.walk(); @@ -135,7 +164,7 @@ fn get_docstring_from_function(node: Node<'_>, src: &str) -> Option { return None; } - let text = node_text(expr, src).ok()?; + let text = common::node_text(expr, src).ok()?; strip_string_quotes(text) } @@ -171,99 +200,16 @@ fn strip_string_quotes(text: &str) -> Option { Some(literal[quote.len()..literal.len() - quote.len()].to_string()) } -fn collect_params(node: Node<'_>, src: &str) -> Result> { - let parameters = node - .child_by_field_name("parameters") - .ok_or_else(|| anyhow!("function_definition missing parameters"))?; - let mut out = Vec::new(); - let mut cursor = parameters.walk(); - - for param in parameters.named_children(&mut cursor) { - match param.kind() { - "identifier" => out.push(Param { - name: node_text(param, src)?.to_string(), - ty_hint: String::new(), - required: true, - default: None, - doc_type: None, - doc_desc: None, - }), - "typed_parameter" => out.push(build_param( - parameter_name(param, src)?, - get_arg_type(param.child_by_field_name("type"), src)?, - true, - None, - )), - "default_parameter" => out.push(build_param( - parameter_name(param, src)?, - String::new(), - false, - Some(Value::Null), - )), - "typed_default_parameter" => out.push(build_param( - parameter_name(param, src)?, - get_arg_type(param.child_by_field_name("type"), src)?, - false, - Some(Value::Null), - )), - "list_splat_pattern" | "dictionary_splat_pattern" | "positional_separator" => { - bail!( - "Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions" - ) - } - "keyword_separator" => continue, - other => bail!("Unsupported parameter type: {other}"), - } - } - - if let Some(doc) = get_docstring_from_function(node, src) { - let meta = parse_docstring_args(&doc); - for p in &mut out { - if let Some((t, d)) = meta.get(&p.name) { - if !t.is_empty() { - p.doc_type = Some(t.clone()); - } - - if !d.is_empty() { - p.doc_desc = Some(d.clone()); - } - - if t.ends_with('?') { - p.required = false; - } - } - } - } - - Ok(out) -} - -fn build_param(name: &str, mut ty: String, mut required: bool, default: Option) -> Param { - if ty.ends_with('?') { - ty.pop(); - required = false; - } - - Param { - name: name.to_string(), - ty_hint: ty, - required, - default, - doc_type: None, - doc_desc: None, - } -} - fn parameter_name<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> { if let Some(name) = node.child_by_field_name("name") { - return node_text(name, src); + return common::node_text(name, src); } let mut cursor = node.walk(); node.named_children(&mut cursor) .find(|child| child.kind() == "identifier") .ok_or_else(|| anyhow!("parameter missing name")) - .and_then(|name| node_text(name, src)) + .and_then(|name| common::node_text(name, src)) } fn get_arg_type(annotation: Option>, src: &str) -> Result { @@ -272,14 +218,14 @@ fn get_arg_type(annotation: Option>, src: &str) -> Result { }; match annotation.kind() { - "type" => get_arg_type(named_child(annotation, 0), src), + "type" => get_arg_type(common::named_child(annotation, 0), src), "generic_type" => { let value = annotation .child_by_field_name("type") - .or_else(|| named_child(annotation, 0)) + .or_else(|| common::named_child(annotation, 0)) .ok_or_else(|| anyhow!("generic_type missing value"))?; let value_name = if value.kind() == "identifier" { - node_text(value, src)? + common::node_text(value, src)? } else { return Ok("any".to_string()); }; @@ -287,7 +233,7 @@ fn get_arg_type(annotation: Option>, src: &str) -> Result { let inner = annotation .child_by_field_name("type_parameter") .or_else(|| annotation.child_by_field_name("parameters")) - .or_else(|| named_child(annotation, 1)) + .or_else(|| common::named_child(annotation, 1)) .ok_or_else(|| anyhow!("generic_type missing inner type"))?; match value_name { @@ -300,14 +246,14 @@ fn get_arg_type(annotation: Option>, src: &str) -> Result { _ => Ok("any".to_string()), } } - "identifier" => Ok(node_text(annotation, src)?.to_string()), + "identifier" => Ok(common::node_text(annotation, src)?.to_string()), "subscript" => { let value = annotation .child_by_field_name("value") - .or_else(|| named_child(annotation, 0)) + .or_else(|| common::named_child(annotation, 0)) .ok_or_else(|| anyhow!("subscript missing value"))?; let value_name = if value.kind() == "identifier" { - node_text(value, src)? + common::node_text(value, src)? } else { return Ok("any".to_string()); }; @@ -315,7 +261,7 @@ fn get_arg_type(annotation: Option>, src: &str) -> Result { let inner = annotation .child_by_field_name("subscript") .or_else(|| annotation.child_by_field_name("slice")) - .or_else(|| named_child(annotation, 1)) + .or_else(|| common::named_child(annotation, 1)) .ok_or_else(|| anyhow!("subscript missing inner type"))?; match value_name { "Optional" => Ok(format!("{}?", get_arg_type(Some(inner), src)?)), @@ -333,7 +279,7 @@ fn get_arg_type(annotation: Option>, src: &str) -> Result { fn generic_inner_type(node: Node<'_>, src: &str) -> Result { if node.kind() == "type_parameter" { - return get_arg_type(named_child(node, 0), src); + return get_arg_type(common::named_child(node, 0), src); } get_arg_type(Some(node), src) @@ -342,7 +288,7 @@ fn generic_inner_type(node: Node<'_>, src: &str) -> Result { fn literal_members(node: Node<'_>, src: &str) -> Result> { if node.kind() == "type" { return literal_members( - named_child(node, 0).ok_or_else(|| anyhow!("type missing inner literal"))?, + common::named_child(node, 0).ok_or_else(|| anyhow!("type missing inner literal"))?, src, ); } @@ -367,25 +313,15 @@ fn literal_members(node: Node<'_>, src: &str) -> Result> { fn expr_to_str(node: Node<'_>, src: &str) -> Result { match node.kind() { "type" => expr_to_str( - named_child(node, 0).ok_or_else(|| anyhow!("type missing expression"))?, + common::named_child(node, 0).ok_or_else(|| anyhow!("type missing expression"))?, src, ), "string" | "integer" | "float" | "true" | "false" | "none" | "identifier" - | "unary_operator" => Ok(node_text(node, src)?.trim().to_string()), + | "unary_operator" => Ok(common::node_text(node, src)?.trim().to_string()), _ => Ok("any".to_string()), } } -fn named_child(node: Node<'_>, index: usize) -> Option> { - let mut cursor = node.walk(); - node.named_children(&mut cursor).nth(index) -} - -fn node_text<'a>(node: Node<'_>, src: &'a str) -> Result<&'a str> { - node.utf8_text(src.as_bytes()) - .map_err(|err| anyhow!("invalid utf-8 in python source: {err}")) -} - fn parse_docstring_args(doc: &str) -> IndexMap { let mut out = IndexMap::new(); let mut in_args = false; @@ -421,109 +357,10 @@ fn parse_docstring_args(doc: &str) -> IndexMap { out } -fn underscore(s: &str) -> String { - s.chars() - .map(|c| { - if c.is_ascii_alphanumeric() { - c.to_ascii_lowercase() - } else { - '_' - } - }) - .collect::() - .split('_') - .filter(|t| !t.is_empty()) - .collect::>() - .join("_") -} - -fn build_parameters_schema(params: &[Param], _description: &str) -> JsonSchema { - let mut props: IndexMap = IndexMap::new(); - let mut req: Vec = Vec::new(); - - for p in params { - let name = p.name.replace('-', "_"); - let mut schema = JsonSchema::default(); - - let ty = if !p.ty_hint.is_empty() { - p.ty_hint.as_str() - } else if let Some(t) = &p.doc_type { - t.as_str() - } else { - "str" - }; - - if let Some(d) = &p.doc_desc - && !d.is_empty() - { - schema.description = Some(d.clone()); - } - - apply_type_to_schema(ty, &mut schema); - - if p.default.is_none() && p.required { - req.push(name.clone()); - } - - props.insert(name, schema); - } - - JsonSchema { - type_value: Some("object".into()), - description: None, - properties: Some(props), - items: None, - any_of: None, - enum_value: None, - default: None, - required: if req.is_empty() { None } else { Some(req) }, - } -} - -fn apply_type_to_schema(ty: &str, s: &mut JsonSchema) { - let t = ty.trim_end_matches('?'); - if let Some(rest) = t.strip_prefix("list[") { - s.type_value = Some("array".into()); - let inner = rest.trim_end_matches(']'); - let mut item = JsonSchema::default(); - - apply_type_to_schema(inner, &mut item); - - if item.type_value.is_none() { - item.type_value = Some("string".into()); - } - s.items = Some(Box::new(item)); - - return; - } - - if let Some(rest) = t.strip_prefix("literal:") { - s.type_value = Some("string".into()); - let vals = rest - .split('|') - .map(|x| x.trim().trim_matches('"').trim_matches('\'').to_string()) - .collect::>(); - if !vals.is_empty() { - s.enum_value = Some(vals); - } - return; - } - - s.type_value = Some( - match t { - "bool" => "boolean", - "int" => "integer", - "float" => "number", - "str" | "any" | "" => "string", - _ => "string", - } - .into(), - ); -} - #[cfg(test)] mod tests { use super::*; + use crate::function::JsonSchema; use std::fs; use std::time::{SystemTime, UNIX_EPOCH}; @@ -844,9 +681,9 @@ def run(*args): "#; let err = parse_source(source, "reject_varargs", Path::new("tools")).unwrap_err(); - assert!(err - .to_string() - .contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions")); + let msg = format!("{err:#}"); + assert!(msg.contains("*args/*kwargs/positional-only parameters are not supported")); + assert!(msg.contains("in function 'run'")); } #[test] @@ -858,9 +695,9 @@ def run(**kwargs): "#; let err = parse_source(source, "reject_kwargs", Path::new("tools")).unwrap_err(); - assert!(err - .to_string() - .contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions")); + let msg = format!("{err:#}"); + assert!(msg.contains("*args/*kwargs/positional-only parameters are not supported")); + assert!(msg.contains("in function 'run'")); } #[test] @@ -872,9 +709,9 @@ def run(x, /, y): "#; let err = parse_source(source, "reject_positional_only", Path::new("tools")).unwrap_err(); - assert!(err - .to_string() - .contains("Unsupported parameter type: *args/*kwargs/positional-only parameters are not supported in tool functions")); + let msg = format!("{err:#}"); + assert!(msg.contains("*args/*kwargs/positional-only parameters are not supported")); + assert!(msg.contains("in function 'run'")); } #[test]