This is an automated email from the ASF dual-hosted git repository. wenjin272 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 065d0971f568c556e00525b7742788130844bfce Author: WenjinXie <[email protected]> AuthorDate: Wed May 6 17:51:51 2026 +0800 [plan][java] Add bash tool in java. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> --- plan/pom.xml | 11 + .../flink/agents/plan/tools/bash/BashTool.java | 163 ++++++++++++++ .../agents/plan/tools/bash/BashValidator.java | 244 +++++++++++++++++++++ .../flink/agents/plan/tools/bash/BashToolTest.java | 90 ++++++++ .../agents/plan/tools/bash/BashValidatorTest.java | 116 ++++++++++ 5 files changed, 624 insertions(+) diff --git a/plan/pom.xml b/plan/pom.xml index ede26e08..02df3c2c 100644 --- a/plan/pom.xml +++ b/plan/pom.xml @@ -50,6 +50,17 @@ under the License. <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-databind</artifactId> </dependency> + <!-- Tree-sitter bash grammar bindings for BashValidator --> + <dependency> + <groupId>io.github.bonede</groupId> + <artifactId>tree-sitter</artifactId> + <version>0.25.3</version> + </dependency> + <dependency> + <groupId>io.github.bonede</groupId> + <artifactId>tree-sitter-bash</artifactId> + <version>0.23.3</version> + </dependency> <dependency> <groupId>com.alibaba</groupId> <artifactId>pemja</artifactId> diff --git a/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashTool.java b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashTool.java new file mode 100644 index 00000000..57f1dc08 --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashTool.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +/** + * Standalone bash execution tool. + * + * <p>Mirrors the Python {@code flink_agents.plan.tools.bash.bash_tool.BashTool}. The framework + * (e.g. {@code ChatModelAction}) injects {@code allowed_commands} and {@code allowed_script_dirs} + * at call time; the model only sees {@code command}, {@code timeout} and {@code cwd}. + */ +public class BashTool extends Tool { + + private static final String DESCRIPTION = + "Execute a shell command. Only commands on the allowed list or scripts under the allowed directories may run."; + + private static final String INPUT_SCHEMA = + "{\"type\":\"object\"," + + "\"properties\":{" + + "\"command\":{\"type\":\"string\",\"description\":\"The shell command to execute.\"}," + + "\"timeout\":{\"type\":\"integer\",\"description\":\"Timeout in seconds. Defaults to 60.\",\"default\":60}," + + "\"cwd\":{\"type\":\"string\",\"description\":\"The working directory to run the command in. Defaults to the current directory. Use this instead of `cd` commands.\"}" + + "}," + + "\"required\":[\"command\"]}"; + + public BashTool(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(new ToolMetadata("bash", DESCRIPTION, INPUT_SCHEMA)); + this.resourceContext = resourceContext; + } + + @Override + public ToolType getToolType() { + return ToolType.FUNCTION; + } + + @Override + public ToolResponse call(ToolParameters parameters) { + @SuppressWarnings("unchecked") + List<String> allowedCommands = + parameters.hasParameter("allowed_commands") + ? (List<String>) parameters.getParameter("allowed_commands") + : Collections.emptyList(); + @SuppressWarnings("unchecked") + List<String> allowedScriptDirs = + parameters.hasParameter("allowed_script_dirs") + ? (List<String>) parameters.getParameter("allowed_script_dirs") + : Collections.emptyList(); + + String command = parameters.getParameter("command", String.class); + int timeout = + parameters.hasParameter("timeout") + ? parameters.getParameter("timeout", Integer.class) + : 60; + String cwd = + parameters.hasParameter("cwd") + ? parameters.getParameter("cwd", String.class) + : null; + + if (cwd != null && !BashValidator.isUnderAllowedDirs(cwd, allowedScriptDirs, null)) { + List<String> sorted = new ArrayList<>(allowedScriptDirs); + Collections.sort(sorted); + return ToolResponse.success( + "Command rejected: cwd '" + + cwd + + "' is not under any allowed script dir. Allowed script dirs: " + + sorted + + "."); + } + + Optional<String> error = + BashValidator.validate(command, allowedCommands, allowedScriptDirs, cwd); + if (error.isPresent()) { + return ToolResponse.success("Command rejected: " + error.get()); + } + + try { + ProcessBuilder pb = new ProcessBuilder("bash", "-c", command); + if (cwd != null) { + pb.directory(new File(cwd)); + } + Process process = pb.start(); + ByteArrayOutputStream stdout = new ByteArrayOutputStream(); + ByteArrayOutputStream stderr = new ByteArrayOutputStream(); + // Drain output streams to avoid blocking on pipe buffer fill. + Thread tOut = drainAsync(process.getInputStream(), stdout); + Thread tErr = drainAsync(process.getErrorStream(), stderr); + boolean finished = process.waitFor(timeout, TimeUnit.SECONDS); + if (!finished) { + process.destroyForcibly(); + return ToolResponse.success( + "Error: Command timed out after " + timeout + " seconds"); + } + tOut.join(); + tErr.join(); + int exit = process.exitValue(); + String stdoutStr = new String(stdout.toByteArray(), StandardCharsets.UTF_8).strip(); + String stderrStr = new String(stderr.toByteArray(), StandardCharsets.UTF_8).strip(); + if (exit == 0) { + return ToolResponse.success(stdoutStr.isEmpty() ? "Success" : stdoutStr); + } + return ToolResponse.success("Error (exit code " + exit + "): " + stderrStr); + } catch (IOException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + return ToolResponse.success("Error: " + e.getMessage()); + } + } + + private static Thread drainAsync(InputStream stream, ByteArrayOutputStream sink) { + Thread t = + new Thread( + () -> { + try (InputStream in = stream) { + byte[] buf = new byte[4096]; + int n; + while ((n = in.read(buf)) > 0) { + sink.write(buf, 0, n); + } + } catch (IOException ignored) { + // process exit closes stream + } + }); + t.setDaemon(true); + t.start(); + return t; + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashValidator.java b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashValidator.java new file mode 100644 index 00000000..2d938ecf --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashValidator.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.treesitter.TSNode; +import org.treesitter.TSParser; +import org.treesitter.TSTree; +import org.treesitter.TreeSitterBash; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +/** + * AST-based bash command validator backed by tree-sitter-bash. + * + * <p>Mirrors the Python {@code flink_agents.plan.tools.bash.bash_validator}: walks the parsed AST + * and rejects any named node whose type is not on the {@link #ALLOWED_NAMED} allowlist (e.g. {@code + * command_substitution}, {@code subshell}, {@code for_statement}, etc.); for every {@code command} + * node it requires the executable to be either in {@code allowedCommands} or to resolve to a path + * under one of {@code allowedScriptDirs}. + */ +public final class BashValidator { + + /** + * Named AST node types we accept. Anything named but missing from this set is treated as a + * potentially dangerous shell construct and rejected. Unnamed nodes (literal punctuation like + * {@code |}, {@code &&}, {@code (}) are always allowed — they're just syntax tokens. + * + * <p>Kept in sync with the Python {@code _ALLOWED_NAMED} set. + */ + public static final Set<String> ALLOWED_NAMED = + Set.of( + "program", + "command", + "command_name", + // `export VAR=...`, `readonly`, `declare`, `local`, `typeset` + "declaration_command", + "pipeline", + "list", + "redirected_statement", + "file_redirect", + "file_descriptor", + "variable_assignment", + "variable_name", + "special_variable_name", // $@ $? $* $# + "word", + "string", + "string_content", + "raw_string", + "ansi_c_string", + "translated_string", + "concatenation", + "number", + "simple_expansion", // $VAR + "expansion", // ${VAR} + "arithmetic_expansion", // $((...)) + "binary_expression", + "unary_expression", + "parenthesized_expression", + "array"); + + private static final Object PARSER_LOCK = new Object(); + private static volatile TSParser parser; + + private BashValidator() {} + + private static TSParser parser() { + TSParser p = parser; + if (p == null) { + synchronized (PARSER_LOCK) { + p = parser; + if (p == null) { + TSParser created = new TSParser(); + created.setLanguage(new TreeSitterBash()); + parser = created; + p = created; + } + } + } + return p; + } + + /** + * Validate a bash command. Returns {@link Optional#empty()} when allowed, or a non-empty + * descriptive error otherwise. + */ + public static Optional<String> validate( + String command, + List<String> allowedCommands, + List<String> allowedScriptDirs, + @Nullable String cwd) { + if (command == null || command.trim().isEmpty()) { + return Optional.of("Empty command."); + } + TSTree tree; + synchronized (PARSER_LOCK) { + tree = parser().parseString(null, command); + } + TSNode root = tree.getRootNode(); + if (root.hasError()) { + return Optional.of("Command has syntax errors."); + } + if (root.getChildCount() == 0) { + return Optional.of("Empty command."); + } + return walk(root, command, allowedCommands, allowedScriptDirs, cwd); + } + + private static Optional<String> walk( + TSNode node, + String command, + List<String> allowedCommands, + List<String> allowedScriptDirs, + @Nullable String cwd) { + if (node.isNamed() && !ALLOWED_NAMED.contains(node.getType())) { + String snippet = nodeText(node, command); + if (snippet.length() > 80) { + snippet = snippet.substring(0, 80); + } + return Optional.of( + "Disallowed shell construct '" + node.getType() + "' in: '" + snippet + "'"); + } + if ("command".equals(node.getType())) { + Optional<String> err = + validateCommand(node, command, allowedCommands, allowedScriptDirs, cwd); + if (err.isPresent()) { + return err; + } + } + for (int i = 0; i < node.getChildCount(); i++) { + Optional<String> err = + walk(node.getChild(i), command, allowedCommands, allowedScriptDirs, cwd); + if (err.isPresent()) { + return err; + } + } + return Optional.empty(); + } + + private static Optional<String> validateCommand( + TSNode commandNode, + String command, + List<String> allowedCommands, + List<String> allowedScriptDirs, + @Nullable String cwd) { + TSNode nameNode = commandNode.getChildByFieldName("name"); + if (nameNode == null || nameNode.isNull()) { + // Bare variable-assignment parsed as command — nothing to validate. + return Optional.empty(); + } + String executable = nodeText(nameNode, command); + if (allowedCommands.contains(executable)) { + return Optional.empty(); + } + if (isUnderAllowedDirs(executable, allowedScriptDirs, cwd)) { + return Optional.empty(); + } + Set<String> sortedCommands = new HashSet<>(allowedCommands); + Set<String> sortedDirs = new HashSet<>(allowedScriptDirs); + return Optional.of( + "Command '" + + executable + + "' is not allowed. Allowed commands: " + + sortedCommands + + ". Allowed script dirs: " + + sortedDirs + + "."); + } + + /** Return true when {@code pathStr} resolves to a path under any of the allowed dirs. */ + public static boolean isUnderAllowedDirs( + String pathStr, List<String> allowedDirs, @Nullable String cwd) { + Path base; + try { + base = Path.of(pathStr); + } catch (Exception e) { + return false; + } + if (!base.isAbsolute() && cwd != null) { + base = Path.of(cwd).resolve(base); + } + Path resolved; + try { + resolved = base.toAbsolutePath().toRealPath(); + } catch (IOException e) { + try { + resolved = base.toAbsolutePath().normalize(); + } catch (Exception ee) { + return false; + } + } catch (Exception e) { + return false; + } + for (String allowed : allowedDirs) { + try { + Path allowedRoot; + try { + allowedRoot = Path.of(allowed).toAbsolutePath().toRealPath(); + } catch (IOException io) { + allowedRoot = Path.of(allowed).toAbsolutePath().normalize(); + } + if (resolved.startsWith(allowedRoot)) { + return true; + } + } catch (Exception ignored) { + // skip invalid allowed root + } + } + return false; + } + + private static String nodeText(TSNode node, String command) { + byte[] bytes = command.getBytes(StandardCharsets.UTF_8); + int start = node.getStartByte(); + int end = node.getEndByte(); + if (start < 0 || end > bytes.length || start > end) { + return ""; + } + return new String(bytes, start, end - start, StandardCharsets.UTF_8); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashToolTest.java b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashToolTest.java new file mode 100644 index 00000000..3a134923 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashToolTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BashToolTest { + + private static BashTool tool() { + return new BashTool( + new ResourceDescriptor(BashTool.class.getName(), Map.of()), + ResourceContext.fromGetResource((n, t) -> null)); + } + + private static ToolParameters args( + String command, List<String> allowedCommands, List<String> allowedScriptDirs) { + Map<String, Object> m = new HashMap<>(); + m.put("command", command); + m.put("allowed_commands", allowedCommands); + m.put("allowed_script_dirs", allowedScriptDirs); + return new ToolParameters(m); + } + + @Test + void allowedSimpleCommandRuns() { + ToolResponse r = tool().call(args("echo hello", List.of("echo"), List.of())); + assertEquals("hello", r.getResult()); + } + + @Test + void disallowedCommandRejected() { + ToolResponse r = tool().call(args("rm -rf /", List.of("echo"), List.of())); + String out = (String) r.getResult(); + assertTrue(out.startsWith("Command rejected:")); + assertTrue(out.contains("'rm' is not allowed")); + } + + @Test + void controlFlowRejected() { + ToolResponse r = + tool().call(args("for i in 1 2 3; do echo $i; done", List.of("echo"), List.of())); + String out = (String) r.getResult(); + assertTrue(out.startsWith("Command rejected:")); + } + + @Test + void successfulCommandWithEmptyOutput() { + ToolResponse r = tool().call(args("true", List.of("true"), List.of())); + assertEquals("Success", r.getResult()); + } + + @Test + void timeoutEnforced() { + Map<String, Object> m = new HashMap<>(); + m.put("command", "sleep 5"); + m.put("allowed_commands", List.of("sleep")); + m.put("allowed_script_dirs", List.of()); + m.put("timeout", 1); + ToolResponse r = tool().call(new ToolParameters(m)); + String out = (String) r.getResult(); + assertTrue(out.startsWith("Error: Command timed out")); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashValidatorTest.java b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashValidatorTest.java new file mode 100644 index 00000000..7ff28f07 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashValidatorTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BashValidatorTest { + + @Test + void emptyCommandRejected() { + assertEquals( + Optional.of("Empty command."), + BashValidator.validate("", List.of("echo"), List.of(), null)); + assertEquals( + Optional.of("Empty command."), + BashValidator.validate(" ", List.of("echo"), List.of(), null)); + } + + @Test + void simpleAllowedCommandPasses() { + assertEquals( + Optional.empty(), + BashValidator.validate("echo hello", List.of("echo"), List.of(), null)); + } + + @Test + void unknownCommandRejected() { + Optional<String> r = BashValidator.validate("rm -rf /", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("'rm' is not allowed")); + } + + @Test + void pipelineAllowedWhenAllPartsAllowed() { + assertEquals( + Optional.empty(), + BashValidator.validate( + "echo hi | tr a-z A-Z", List.of("echo", "tr"), List.of(), null)); + } + + @Test + void pipelineRejectedWhenAnyPartUnknown() { + Optional<String> r = + BashValidator.validate("echo hi | grep h", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("'grep'")); + } + + @Test + void variableExpansionAllowed() { + assertEquals( + Optional.empty(), + BashValidator.validate("echo $HOME", List.of("echo"), List.of(), null)); + } + + @Test + void arithmeticExpansionAllowed() { + assertEquals( + Optional.empty(), + BashValidator.validate("echo $((1+2))", List.of("echo"), List.of(), null)); + } + + @Test + void commandSubstitutionRejected() { + Optional<String> r = + BashValidator.validate("echo $(rm /etc/passwd)", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("command_substitution")); + } + + @Test + void backticksRejected() { + Optional<String> r = + BashValidator.validate("echo `whoami`", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + } + + @Test + void controlFlowRejected() { + Optional<String> r = + BashValidator.validate( + "for i in 1 2 3; do echo $i; done", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("for_statement")); + } + + @Test + void redirectAllowed() { + // basic redirect of allowed command should pass + Optional<String> r = + BashValidator.validate("echo hi > /tmp/x", List.of("echo"), List.of(), null); + assertEquals(Optional.empty(), r); + } +}
