This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4dd6b2cba2 Fix conflicting Oracle and SSH providers tests (#28337)
4dd6b2cba2 is described below

commit 4dd6b2cba27833de599513950965e382d92399c9
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Dec 13 19:05:50 2022 +0400

    Fix conflicting Oracle and SSH providers tests (#28337)
---
 tests/providers/oracle/operators/test_oracle.py | 65 ++++++-------------------
 tests/providers/ssh/operators/test_ssh.py       | 57 ++++------------------
 2 files changed, 25 insertions(+), 97 deletions(-)

diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index debf920f79..a48e980b6d 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -16,53 +16,17 @@
 # under the License.
 from __future__ import annotations
 
+import re
 from random import randrange
 from unittest import mock
 
 import oracledb
-import pendulum
 import pytest
 
-from airflow.models import DAG, DagModel, DagRun, TaskInstance
+from airflow.models import TaskInstance
 from airflow.providers.common.sql.hooks.sql import fetch_all_handler
 from airflow.providers.oracle.hooks.oracle import OracleHook
 from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator
-from airflow.utils.session import create_session
-from airflow.utils.timezone import datetime
-from airflow.utils.types import DagRunType
-
-DEFAULT_DATE = datetime(2017, 1, 1)
-
-
-def create_context(task, persist_to_db=False, map_index=None):
-    if task.has_dag():
-        dag = task.dag
-    else:
-        dag = DAG(dag_id="dag", start_date=pendulum.now())
-        dag.add_task(task)
-    dag_run = DagRun(
-        run_id=DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE),
-        run_type=DagRunType.MANUAL,
-        dag_id=dag.dag_id,
-    )
-    task_instance = TaskInstance(task=task, run_id=dag_run.run_id)
-    task_instance.dag_run = dag_run
-    if map_index is not None:
-        task_instance.map_index = map_index
-    if persist_to_db:
-        with create_session() as session:
-            session.add(DagModel(dag_id=dag.dag_id))
-            session.add(dag_run)
-            session.add(task_instance)
-            session.commit()
-    return {
-        "dag": dag,
-        "ts": DEFAULT_DATE.isoformat(),
-        "task": task,
-        "ti": task_instance,
-        "task_instance": task_instance,
-        "run_id": "test",
-    }
 
 
 class TestOracleOperator:
@@ -120,21 +84,22 @@ class TestOracleStoredProcedureOperator:
         )
 
     @mock.patch.object(OracleHook, "callproc", autospec=OracleHook.callproc)
-    def test_push_oracle_exit_to_xcom(self, mock_callproc):
+    def test_push_oracle_exit_to_xcom(self, mock_callproc, request, dag_maker):
         # Test pulls the value previously pushed to xcom and checks if it's the same
         procedure = "test_push"
         oracle_conn_id = "oracle_default"
         parameters = {"parameter": "value"}
         task_id = "test_push"
         ora_exit_code = "%05d" % randrange(10**5)
-        task = OracleStoredProcedureOperator(
-            procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id
-        )
-        context = create_context(task, persist_to_db=True)
-        mock_callproc.side_effect = oracledb.DatabaseError(
-            "ORA-" + ora_exit_code + ": This is a five-digit ORA error code"
-        )
-        try:
-            task.execute(context=context)
-        except oracledb.DatabaseError:
-            assert task.xcom_pull(key="ORA", context=context, task_ids=[task_id])[0] == ora_exit_code
+        error = f"ORA-{ora_exit_code}: This is a five-digit ORA error code"
+        mock_callproc.side_effect = oracledb.DatabaseError(error)
+
+        with dag_maker(dag_id=f"dag_{request.node.name}"):
+            task = OracleStoredProcedureOperator(
+                procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id
+            )
+        dr = dag_maker.create_dagrun(run_id=task_id)
+        ti = TaskInstance(task=task, run_id=dr.run_id)
+        with pytest.raises(oracledb.DatabaseError, match=re.escape(error)):
+            ti.run()
+        assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index 140389b798..9065df9a3d 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -20,17 +20,14 @@ from __future__ import annotations
 from random import randrange
 from unittest import mock
 
-import pendulum
 import pytest
 from paramiko.client import SSHClient
 
 from airflow.exceptions import AirflowException
-from airflow.models import DAG, DagModel, DagRun, TaskInstance
+from airflow.models import TaskInstance
 from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.providers.ssh.operators.ssh import SSHOperator
-from airflow.utils.session import create_session
 from airflow.utils.timezone import datetime
-from airflow.utils.types import DagRunType
 from tests.test_utils.config import conf_vars
 
 TEST_DAG_ID = "unit_tests_ssh_test_op"
@@ -44,37 +41,6 @@ COMMAND = "echo -n airflow"
 COMMAND_WITH_SUDO = "sudo " + COMMAND
 
 
-def create_context(task, persist_to_db=False, map_index=None):
-    if task.has_dag():
-        dag = task.dag
-    else:
-        dag = DAG(dag_id="dag", start_date=pendulum.now())
-        dag.add_task(task)
-    dag_run = DagRun(
-        run_id=DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE),
-        run_type=DagRunType.MANUAL,
-        dag_id=dag.dag_id,
-    )
-    task_instance = TaskInstance(task=task, run_id=dag_run.run_id)
-    task_instance.dag_run = dag_run
-    if map_index is not None:
-        task_instance.map_index = map_index
-    if persist_to_db:
-        with create_session() as session:
-            session.add(DagModel(dag_id=dag.dag_id))
-            session.add(dag_run)
-            session.add(task_instance)
-            session.commit()
-    return {
-        "dag": dag,
-        "ts": DEFAULT_DATE.isoformat(),
-        "task": task,
-        "ti": task_instance,
-        "task_instance": task_instance,
-        "run_id": "test",
-    }
-
-
 class SSHClientSideEffect:
     def __init__(self, hook):
         self.hook = hook
@@ -235,19 +201,16 @@ class TestSSHOperator:
         with pytest.raises(AirflowException, match="SSH operator error: exit status = 1"):
             task.execute(None)
 
-    def test_push_ssh_exit_to_xcom(self, context=None):
+    def test_push_ssh_exit_to_xcom(self, request, dag_maker):
         # Test pulls the value previously pushed to xcom and checks if it's the same
         command = "not_a_real_command"
         ssh_exit_code = randrange(0, 100)
-        task_push = SSHOperator(task_id="test_push", ssh_hook=self.hook, command=command)
-        task_context = create_context(task_push, persist_to_db=True)
         self.exec_ssh_client_command.return_value = (ssh_exit_code, b"", b"ssh output")
-        try:
-            task_push.execute(context=task_context)
-        except AirflowException:
-            pass
-        finally:
-            assert (
-                task_push.xcom_pull(key="ssh_exit", context=task_context, task_ids=["test_push"])[0]
-                == ssh_exit_code
-            )
+
+        with dag_maker(dag_id=f"dag_{request.node.name}"):
+            task = SSHOperator(task_id="push_xcom", ssh_hook=self.hook, command=command)
+        dr = dag_maker.create_dagrun(run_id="push_xcom")
+        ti = TaskInstance(task=task, run_id=dr.run_id)
+        with pytest.raises(AirflowException, match=f"SSH operator error: exit status = {ssh_exit_code}"):
+            ti.run()
+        assert ti.xcom_pull(task_ids=task.task_id, key="ssh_exit") == ssh_exit_code

Reply via email to