This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4dd6b2cba2 Fix conflicting Oracle and SSH providers tests (#28337)
4dd6b2cba2 is described below
commit 4dd6b2cba27833de599513950965e382d92399c9
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Dec 13 19:05:50 2022 +0400
Fix conflicting Oracle and SSH providers tests (#28337)
---
tests/providers/oracle/operators/test_oracle.py | 65 ++++++-------------------
tests/providers/ssh/operators/test_ssh.py | 57 ++++------------------
2 files changed, 25 insertions(+), 97 deletions(-)
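Context for the change: both test modules used to persist a DagRun and TaskInstance through a local create_context(..., persist_to_db=True) helper with a hard-coded dag_id="dag", so the Oracle and SSH provider suites could collide on the same database rows when run together. The patch below drops that helper and builds the runtime state with the dag_maker pytest fixture, using a dag_id derived from the test node name. A minimal sketch of that fixture pattern, purely illustrative and not part of this commit (it assumes Airflow's own test environment, where the dag_maker fixture and a test database are available; BashOperator is only a stand-in operator here):

    from airflow.models import TaskInstance
    from airflow.operators.bash import BashOperator


    def test_xcom_roundtrip_sketch(request, dag_maker):
        # A dag_id derived from the test node name keeps the persisted
        # DagRun/TaskInstance rows unique per test, which is what removes
        # the conflict between the Oracle and SSH provider suites.
        with dag_maker(dag_id=f"dag_{request.node.name}"):
            task = BashOperator(task_id="push_xcom", bash_command="echo 42")
        dr = dag_maker.create_dagrun(run_id="push_xcom")
        ti = TaskInstance(task=task, run_id=dr.run_id)
        ti.run()
        # BashOperator pushes its last stdout line to the default XCom key.
        assert ti.xcom_pull(task_ids=task.task_id) == "42"
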
diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index debf920f79..a48e980b6d 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -16,53 +16,17 @@
# under the License.
from __future__ import annotations
+import re
from random import randrange
from unittest import mock
import oracledb
-import pendulum
import pytest
-from airflow.models import DAG, DagModel, DagRun, TaskInstance
+from airflow.models import TaskInstance
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.oracle.hooks.oracle import OracleHook
 from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator
-from airflow.utils.session import create_session
-from airflow.utils.timezone import datetime
-from airflow.utils.types import DagRunType
-
-DEFAULT_DATE = datetime(2017, 1, 1)
-
-
-def create_context(task, persist_to_db=False, map_index=None):
- if task.has_dag():
- dag = task.dag
- else:
- dag = DAG(dag_id="dag", start_date=pendulum.now())
- dag.add_task(task)
- dag_run = DagRun(
- run_id=DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE),
- run_type=DagRunType.MANUAL,
- dag_id=dag.dag_id,
- )
- task_instance = TaskInstance(task=task, run_id=dag_run.run_id)
- task_instance.dag_run = dag_run
- if map_index is not None:
- task_instance.map_index = map_index
- if persist_to_db:
- with create_session() as session:
- session.add(DagModel(dag_id=dag.dag_id))
- session.add(dag_run)
- session.add(task_instance)
- session.commit()
- return {
- "dag": dag,
- "ts": DEFAULT_DATE.isoformat(),
- "task": task,
- "ti": task_instance,
- "task_instance": task_instance,
- "run_id": "test",
- }
class TestOracleOperator:
@@ -120,21 +84,22 @@ class TestOracleStoredProcedureOperator:
)
@mock.patch.object(OracleHook, "callproc", autospec=OracleHook.callproc)
- def test_push_oracle_exit_to_xcom(self, mock_callproc):
+ def test_push_oracle_exit_to_xcom(self, mock_callproc, request, dag_maker):
         # Test pulls the value previously pushed to xcom and checks if it's the same
procedure = "test_push"
oracle_conn_id = "oracle_default"
parameters = {"parameter": "value"}
task_id = "test_push"
ora_exit_code = "%05d" % randrange(10**5)
- task = OracleStoredProcedureOperator(
-            procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id
- )
- context = create_context(task, persist_to_db=True)
- mock_callproc.side_effect = oracledb.DatabaseError(
- "ORA-" + ora_exit_code + ": This is a five-digit ORA error code"
- )
- try:
- task.execute(context=context)
- except oracledb.DatabaseError:
-            assert task.xcom_pull(key="ORA", context=context, task_ids=[task_id])[0] == ora_exit_code
+ error = f"ORA-{ora_exit_code}: This is a five-digit ORA error code"
+ mock_callproc.side_effect = oracledb.DatabaseError(error)
+
+ with dag_maker(dag_id=f"dag_{request.node.name}"):
+ task = OracleStoredProcedureOperator(
+                procedure=procedure, oracle_conn_id=oracle_conn_id, parameters=parameters, task_id=task_id
+ )
+ dr = dag_maker.create_dagrun(run_id=task_id)
+ ti = TaskInstance(task=task, run_id=dr.run_id)
+ with pytest.raises(oracledb.DatabaseError, match=re.escape(error)):
+ ti.run()
+ assert ti.xcom_pull(task_ids=task.task_id, key="ORA") == ora_exit_code
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index 140389b798..9065df9a3d 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -20,17 +20,14 @@ from __future__ import annotations
from random import randrange
from unittest import mock
-import pendulum
import pytest
from paramiko.client import SSHClient
from airflow.exceptions import AirflowException
-from airflow.models import DAG, DagModel, DagRun, TaskInstance
+from airflow.models import TaskInstance
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
-from airflow.utils.session import create_session
from airflow.utils.timezone import datetime
-from airflow.utils.types import DagRunType
from tests.test_utils.config import conf_vars
TEST_DAG_ID = "unit_tests_ssh_test_op"
@@ -44,37 +41,6 @@ COMMAND = "echo -n airflow"
COMMAND_WITH_SUDO = "sudo " + COMMAND
-def create_context(task, persist_to_db=False, map_index=None):
- if task.has_dag():
- dag = task.dag
- else:
- dag = DAG(dag_id="dag", start_date=pendulum.now())
- dag.add_task(task)
- dag_run = DagRun(
- run_id=DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE),
- run_type=DagRunType.MANUAL,
- dag_id=dag.dag_id,
- )
- task_instance = TaskInstance(task=task, run_id=dag_run.run_id)
- task_instance.dag_run = dag_run
- if map_index is not None:
- task_instance.map_index = map_index
- if persist_to_db:
- with create_session() as session:
- session.add(DagModel(dag_id=dag.dag_id))
- session.add(dag_run)
- session.add(task_instance)
- session.commit()
- return {
- "dag": dag,
- "ts": DEFAULT_DATE.isoformat(),
- "task": task,
- "ti": task_instance,
- "task_instance": task_instance,
- "run_id": "test",
- }
-
-
class SSHClientSideEffect:
def __init__(self, hook):
self.hook = hook
@@ -235,19 +201,16 @@ class TestSSHOperator:
         with pytest.raises(AirflowException, match="SSH operator error: exit status = 1"):
task.execute(None)
- def test_push_ssh_exit_to_xcom(self, context=None):
+ def test_push_ssh_exit_to_xcom(self, request, dag_maker):
         # Test pulls the value previously pushed to xcom and checks if it's the same
command = "not_a_real_command"
ssh_exit_code = randrange(0, 100)
-        task_push = SSHOperator(task_id="test_push", ssh_hook=self.hook, command=command)
- task_context = create_context(task_push, persist_to_db=True)
         self.exec_ssh_client_command.return_value = (ssh_exit_code, b"", b"ssh output")
- try:
- task_push.execute(context=task_context)
- except AirflowException:
- pass
- finally:
- assert (
-                task_push.xcom_pull(key="ssh_exit", context=task_context, task_ids=["test_push"])[0]
- == ssh_exit_code
- )
+
+ with dag_maker(dag_id=f"dag_{request.node.name}"):
+            task = SSHOperator(task_id="push_xcom", ssh_hook=self.hook, command=command)
+ dr = dag_maker.create_dagrun(run_id="push_xcom")
+ ti = TaskInstance(task=task, run_id=dr.run_id)
+        with pytest.raises(AirflowException, match=f"SSH operator error: exit status = {ssh_exit_code}"):
+ ti.run()
+        assert ti.xcom_pull(task_ids=task.task_id, key="ssh_exit") == ssh_exit_code
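
A small standalone illustration of how the Oracle test's error matching works: pytest.raises(..., match=...) applies re.search to the exception text, so the expected message is escaped to keep the match literal. Names mirror the test above; nothing here is part of the patch.

    import re
    from random import randrange

    # Same construction as in test_push_oracle_exit_to_xcom above.
    ora_exit_code = "%05d" % randrange(10**5)  # zero-padded, always five digits
    error = f"ORA-{ora_exit_code}: This is a five-digit ORA error code"

    # pytest.raises(oracledb.DatabaseError, match=re.escape(error)) boils down
    # to a re.search of the escaped pattern against str(exc):
    assert re.search(re.escape(error), error) is not None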