This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-stable
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-stable by this push:
     new e333133  Fix Incorrect warning in upgrade check and error in reading file (#14344)
e333133 is described below

commit e33313302173b6bd872523ced233acc26ff9d6fe
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed Feb 24 20:14:23 2021 +0100

    Fix Incorrect warning in upgrade check and error in reading file (#14344)
    
    Closes: #14340, #14258, #14243
---
 ...in_between_dag_and_operator_not_allowed_rule.py | 57 ++++++++++++----------
 airflow/upgrade/rules/db_api_functions.py          |  8 +--
 airflow/upgrade/rules/import_changes.py            | 21 +++++---
 ...in_between_dag_and_operator_not_allowed_rule.py | 27 +++++++++-
 tests/upgrade/rules/test_import_changes.py         | 30 +++++++++++-
 5 files changed, 100 insertions(+), 43 deletions(-)

diff --git a/airflow/upgrade/rules/chain_between_dag_and_operator_not_allowed_rule.py b/airflow/upgrade/rules/chain_between_dag_and_operator_not_allowed_rule.py
index 79b9044..3291c03 100644
--- a/airflow/upgrade/rules/chain_between_dag_and_operator_not_allowed_rule.py
+++ b/airflow/upgrade/rules/chain_between_dag_and_operator_not_allowed_rule.py
@@ -17,7 +17,7 @@
 from __future__ import absolute_import
 
 import re
-
+import os
 from airflow import conf
 from airflow.upgrade.rules.base_rule import BaseRule
 from airflow.utils.dag_processing import list_py_file_paths
@@ -37,36 +37,41 @@ class ChainBetweenDAGAndOperatorNotAllowedRule(BaseRule):
     def _check_file(self, file_path):
         problems = []
         with open(file_path, "r") as file_pointer:
-            lines = file_pointer.readlines()
-            python_space = r"\s*\\?\s*\n?\s*"
-            # Find all the dag variable names.
-            dag_vars = re.findall(r"([A-Za-z0-9_]+){}={}DAG\(".format(python_space, python_space),
-                                  "".join(lines))
-            history = ""
-            for line_number, line in enumerate(lines, 1):
-                # Someone could have put the bitshift operator on a different line than the dag they
-                # were using it on, so search for dag >> or << dag in all previous lines that did
-                # not contain a logged issue.
-                history += line
-                matches = [
-                    re.search(r"DAG\([^\)]+\){}>>".format(python_space), history),
-                    re.search(r"<<{}DAG\(".format(python_space), history)
-                ]
-                for dag_var in dag_vars:
-                    matches.extend([
-                        re.search(r"(\s|^){}{}>>".format(dag_var, python_space), history),
-                        re.search(r"<<\s*{}{}".format(python_space, dag_var), history),
-                    ])
-                if any(matches):
-                    problems.append(self._change_info(file_path, line_number))
-                    # If we found a problem, clear our history so we don't re-log the problem
-                    # on the next line.
-                    history = ""
+            try:
+                lines = file_pointer.readlines()
+
+                python_space = r"\s*\\?\s*\n?\s*"
+                # Find all the dag variable names.
+                dag_vars = re.findall(r"([A-Za-z0-9_]+){}={}DAG\(".format(python_space, python_space),
+                                      "".join(lines))
+                history = ""
+                for line_number, line in enumerate(lines, 1):
+                    # Someone could have put the bitshift operator on a different line than the dag they
+                    # were using it on, so search for dag >> or << dag in all previous lines that did
+                    # not contain a logged issue.
+                    history += line
+                    matches = [
+                        re.search(r"DAG\([^\)]+\){}>>".format(python_space), history),
+                        re.search(r"<<{}DAG\(".format(python_space), history)
+                    ]
+                    for dag_var in dag_vars:
+                        matches.extend([
+                            re.search(r"(\s|^){}{}>>".format(dag_var, python_space), history),
+                            re.search(r"<<\s*{}{}".format(python_space, dag_var), history),
+                        ])
+                    if any(matches):
+                        problems.append(self._change_info(file_path, line_number))
+                        # If we found a problem, clear our history so we don't re-log the problem
+                        # on the next line.
+                        history = ""
+            except UnicodeDecodeError:
+                problems.append("Unable to read python file {}".format(file_path))
         return problems
 
     def check(self):
         dag_folder = conf.get("core", "dags_folder")
         file_paths = list_py_file_paths(directory=dag_folder, include_examples=False)
+        file_paths = [file for file in file_paths if os.path.splitext(file)[1] == ".py"]
         problems = []
         for file_path in file_paths:
             problems.extend(self._check_file(file_path))
diff --git a/airflow/upgrade/rules/db_api_functions.py b/airflow/upgrade/rules/db_api_functions.py
index 1801c36..0ca730f 100644
--- a/airflow/upgrade/rules/db_api_functions.py
+++ b/airflow/upgrade/rules/db_api_functions.py
@@ -33,20 +33,16 @@ def check_run(cls):
     try:
         cls.__new__(cls).run("fake SQL")
         return return_error_string(cls, "run")
-    except NotImplementedError:
-        pass
     except Exception:
-        return return_error_string(cls, "run")
+        pass
 
 
 def check_get_records(cls):
     try:
         cls.__new__(cls).get_records("fake SQL")
         return return_error_string(cls, "get_records")
-    except NotImplementedError:
-        pass
     except Exception:
-        return return_error_string(cls, "get_records")
+        pass
 
 
 def return_error_string(cls, method):
diff --git a/airflow/upgrade/rules/import_changes.py b/airflow/upgrade/rules/import_changes.py
index 317549d..24fb6e9 100644
--- a/airflow/upgrade/rules/import_changes.py
+++ b/airflow/upgrade/rules/import_changes.py
@@ -17,7 +17,7 @@
 
 import itertools
 from typing import NamedTuple, Optional, List
-
+import os
 from cached_property import cached_property
 from packaging.version import Version
 
@@ -87,7 +87,7 @@ class ImportChangesRule(BaseRule):
     if current_airflow_version < Version("2.0.0"):
 
         def _filter_incompatible_renames(arg):
-            new_path = arg[1]
+            new_path = arg[0]
             return (
                 not new_path.startswith("airflow.operators")
                 and not new_path.startswith("airflow.sensors")
@@ -111,12 +111,16 @@ class ImportChangesRule(BaseRule):
         problems = []
         providers = set()
         with open(file_path, "r") as file:
-            content = file.read()
-            for change in ImportChangesRule.ALL_CHANGES:
-                if change.old_class in content:
-                    problems.append(change.info(file_path))
-                    if change.providers_package:
-                        providers.add(change.providers_package)
+            try:
+                content = file.read()
+
+                for change in ImportChangesRule.ALL_CHANGES:
+                    if change.old_class in content:
+                        problems.append(change.info(file_path))
+                        if change.providers_package:
+                            providers.add(change.providers_package)
+            except UnicodeDecodeError:
+                problems.append("Unable to read python file {}".format(file_path))
         return problems, providers
 
     @staticmethod
@@ -138,6 +142,7 @@ class ImportChangesRule(BaseRule):
     def check(self):
         dag_folder = conf.get("core", "dags_folder")
         files = list_py_file_paths(directory=dag_folder, include_examples=False)
+        files = [file for file in files if os.path.splitext(file)[1] == ".py"]
         problems = []
         providers = set()
         # Split in to two groups - install backports first, then make changes
diff --git a/tests/upgrade/rules/test_chain_between_dag_and_operator_not_allowed_rule.py b/tests/upgrade/rules/test_chain_between_dag_and_operator_not_allowed_rule.py
index eba53c7..39ca7d1 100644
--- a/tests/upgrade/rules/test_chain_between_dag_and_operator_not_allowed_rule.py
+++ b/tests/upgrade/rules/test_chain_between_dag_and_operator_not_allowed_rule.py
@@ -25,8 +25,8 @@ from airflow.upgrade.rules.chain_between_dag_and_operator_not_allowed_rule impor
 
 
 @contextmanager
-def create_temp_file(mock_list_files, lines):
-    with NamedTemporaryFile("w+") as temp_file:
+def create_temp_file(mock_list_files, lines, extension=".py"):
+    with NamedTemporaryFile("w+", suffix=extension) as temp_file:
         mock_list_files.return_value = [temp_file.name]
         temp_file.writelines("\n".join(lines))
         temp_file.flush()
@@ -110,3 +110,26 @@ class TestChainBetweenDAGAndOperatorNotAllowedRule(TestCase):
             msgs = rule.check()
             expected_messages = [self.msg_template.format(rule.title, temp_file.name, 6)]
             assert expected_messages == msgs
+
+    def test_non_py_files_are_ignored(self, mock_list_files):
+        lines = ["dag = \\",
+                 "    DAG('my_dag')",
+                 "dummy = DummyOperator(task_id='dummy')",
+                 "",
+                 "dummy << \\",
+                 "dag"]
+
+        with create_temp_file(mock_list_files, lines, extension=".txt"):
+            rule = ChainBetweenDAGAndOperatorNotAllowedRule()
+            msgs = rule.check()
+            assert msgs == []
+
+    def test_decode_errors_are_handled(self, mock_list_files):
+
+        with NamedTemporaryFile("wb+", suffix=".py") as temp_file:
+            mock_list_files.return_value = [temp_file.name]
+            temp_file.write(b"    DAG('my_dag') \x03\x96")
+            temp_file.flush()
+            rule = ChainBetweenDAGAndOperatorNotAllowedRule()
+            msgs = rule.check()
+            assert msgs[0] == "Unable to read python file {}".format(temp_file.name)
diff --git a/tests/upgrade/rules/test_import_changes.py 
b/tests/upgrade/rules/test_import_changes.py
index 0b82485..b76ba33 100644
--- a/tests/upgrade/rules/test_import_changes.py
+++ b/tests/upgrade/rules/test_import_changes.py
@@ -51,7 +51,7 @@ class TestImportChangesRule:
         [ImportChange.from_new_old_paths(NEW_PATH, OLD_PATH)],
     )
     def test_check(self, mock_list_files):
-        with NamedTemporaryFile("w+") as temp:
+        with NamedTemporaryFile("w+", suffix=".py") as temp:
             mock_list_files.return_value = [temp.name]
 
             temp.write("from airflow.contrib import %s" % OLD_CLASS)
@@ -65,3 +65,31 @@ class TestImportChangesRule:
         assert temp.name in msg
         assert OLD_PATH in msg
         assert OLD_CLASS in msg
+
+    @mock.patch("airflow.upgrade.rules.import_changes.list_py_file_paths")
+    @mock.patch(
+        "airflow.upgrade.rules.import_changes.ImportChangesRule.ALL_CHANGES",
+        [ImportChange.from_new_old_paths(NEW_PATH, OLD_PATH)],
+    )
+    def test_non_py_files_are_ignored(self, mock_list_files):
+        with NamedTemporaryFile("w+", suffix=".txt") as temp:
+            mock_list_files.return_value = [temp.name]
+
+            temp.write("from airflow.contrib import %s" % OLD_CLASS)
+            temp.flush()
+            msgs = list(ImportChangesRule().check())
+        assert len(msgs) == 0
+
+    @mock.patch("airflow.upgrade.rules.import_changes.list_py_file_paths")
+    @mock.patch(
+        "airflow.upgrade.rules.import_changes.ImportChangesRule.ALL_CHANGES",
+        [ImportChange.from_new_old_paths(NEW_PATH, OLD_PATH)],
+    )
+    def test_decode_error_are_handled(self, mock_list_files):
+        with NamedTemporaryFile("wb+", suffix=".py") as temp:
+            mock_list_files.return_value = [temp.name]
+
+            temp.write(b"from airflow \x03\x96")
+            temp.flush()
+            msgs = list(ImportChangesRule().check())
+        assert msgs[0] == "Unable to read python file {}".format(temp.name)

Reply via email to