This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 67613f0c38b Replace models.BaseOperator with the Task SDK one for Common Providers (#52443)
67613f0c38b is described below
commit 67613f0c38be5b98f7d12d88a6b751d6e27590a6
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Jun 30 03:06:20 2025 +0530
Replace models.BaseOperator with the Task SDK one for Common Providers (#52443)
Part of https://github.com/apache/airflow/issues/52378
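For reviewers unfamiliar with the pattern: each provider now ships a version_compat
module that resolves BaseOperator once, at import time. A minimal sketch of the idea
(the helper name get_base_airflow_version_tuple matches the diff below; the exact
version parsing shown here is illustrative, not the committed implementation):

    from __future__ import annotations

    import airflow


    def get_base_airflow_version_tuple() -> tuple[int, int, int]:
        # Illustrative parsing: packaging's Version exposes major/minor/micro.
        from packaging.version import Version

        version = Version(airflow.__version__)
        return version.major, version.minor, version.micro


    AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)

    if AIRFLOW_V_3_0_PLUS:
        from airflow.sdk import BaseOperator  # Task SDK home of BaseOperator (Airflow 3+)
    else:
        from airflow.models import BaseOperator  # legacy import path (Airflow 2)

Operators import BaseOperator from their provider's version_compat module instead of
airflow.models, so the same code runs on Airflow 2 and 3. The test updates follow the
same goal: op.run(...) and hand-built DAG objects are replaced with the dag_maker
fixture and dag_maker.run_ti(...), which work under both execution models.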
---
.../providers/common/compat/standard/operators.py | 3 +-
.../providers/common/compat/version_compat.py | 12 +++++
.../providers/common/io/operators/file_transfer.py | 4 +-
.../airflow/providers/common/io/version_compat.py | 10 ++++
.../common/sql/operators/generic_transfer.py | 4 +-
.../airflow/providers/common/sql/operators/sql.py | 3 +-
.../airflow/providers/common/sql/version_compat.py | 12 +++++
.../common/sql/operators/test_generic_transfer.py | 59 ++++++++++------------
.../tests/unit/common/sql/operators/test_sql.py | 57 +++++++++++++++------
.../unit/snowflake/operators/test_snowflake.py | 17 +++----
10 files changed, 118 insertions(+), 63 deletions(-)
diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
index c8838fe3c21..b3d35f1aa14 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
@@ -46,5 +46,6 @@ else:
except (ImportError, ModuleNotFoundError):
from airflow.providers.standard.operators.python import get_current_context
+from airflow.providers.common.compat.version_compat import BaseOperator
-__all__ = ["PythonOperator", "_SERIALIZERS", "ShortCircuitOperator",
"get_current_context"]
+__all__ = ["BaseOperator", "PythonOperator", "_SERIALIZERS",
"ShortCircuitOperator", "get_current_context"]
diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
index 48d122b6696..02d0f1ac162 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
@@ -33,3 +33,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import BaseOperator
+else:
+ from airflow.models import BaseOperator
+
+__all__ = [
+ "AIRFLOW_V_3_0_PLUS",
+ "AIRFLOW_V_3_1_PLUS",
+ "BaseOperator",
+]
diff --git a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
index 0faec858d1c..00b23bde989 100644
--- a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
+++ b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
@@ -28,10 +28,8 @@ if TYPE_CHECKING:
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import ObjectStoragePath
- from airflow.sdk.bases.operator import BaseOperator
else:
from airflow.io.path import ObjectStoragePath # type: ignore[no-redef]
- from airflow.models import BaseOperator # type: ignore[no-redef]
class FileTransferOperator(BaseOperator):
diff --git a/providers/common/io/src/airflow/providers/common/io/version_compat.py b/providers/common/io/src/airflow/providers/common/io/version_compat.py
index 48d122b6696..e7a259afb35 100644
--- a/providers/common/io/src/airflow/providers/common/io/version_compat.py
+++ b/providers/common/io/src/airflow/providers/common/io/version_compat.py
@@ -33,3 +33,13 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import BaseOperator
+else:
+ from airflow.models import BaseOperator
+
+__all__ = [
+ "AIRFLOW_V_3_0_PLUS",
+ "BaseOperator",
+]
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
index c7839b28aec..4e3b0a87ce0 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
@@ -23,9 +23,9 @@ from typing import TYPE_CHECKING, Any
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger
+from airflow.providers.common.sql.version_compat import BaseOperator
if TYPE_CHECKING:
import jinja2
@@ -192,7 +192,7 @@ class GenericTransfer(BaseOperator):
)
self.log.info("Offset increased to %d", offset)
- self.xcom_push(context=context, key="offset", value=offset)
+ context["ti"].xcom_push(key="offset", value=offset)
self.log.info("Inserting %d rows into %s", len(results),
self.destination_conn_id)
self.destination_hook.insert_rows(
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index af04349532a..250a249d5af 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -25,9 +25,10 @@ from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, SupportsAbs
from airflow.exceptions import AirflowException, AirflowFailException
from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator, SkipMixin
+from airflow.models import SkipMixin
from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, return_single_query_results
from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.version_compat import BaseOperator
from airflow.utils.helpers import merge_dicts
if TYPE_CHECKING:
diff --git a/providers/common/sql/src/airflow/providers/common/sql/version_compat.py b/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
index 48d122b6696..b326387fea2 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
@@ -33,3 +33,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import BaseOperator, BaseSensorOperator
+else:
+ from airflow.models import BaseOperator
+ from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef]
+
+__all__ = [
+ "AIRFLOW_V_3_0_PLUS",
+ "BaseOperator",
+ "BaseSensorOperator",
+]
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
index 92b118d703a..fe01d68d2f1 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
@@ -55,11 +55,6 @@ counter = 0
@pytest.mark.backend("mysql")
class TestMySql:
- def setup_method(self):
- args = {"owner": "airflow", "start_date": DEFAULT_DATE}
- dag = DAG(TEST_DAG_ID, schedule=None, default_args=args)
- self.dag = dag
-
def teardown_method(self):
from airflow.providers.mysql.hooks.mysql import MySqlHook
@@ -77,7 +72,7 @@ class TestMySql:
"mysql-connector-python",
],
)
- def test_mysql_to_mysql(self, client):
+ def test_mysql_to_mysql(self, client, dag_maker):
class MySqlContext:
def __init__(self, client):
self.client = client
@@ -92,6 +87,25 @@ class TestMySql:
with MySqlContext(client):
sql = "SELECT * FROM connection;"
+ with dag_maker(f"TEST_DAG_ID_{client}", start_date=DEFAULT_DATE):
+ op = GenericTransfer(
+ task_id="test_m2m",
+ preoperator=[
+ "DROP TABLE IF EXISTS test_mysql_to_mysql",
+ "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE
connection",
+ ],
+ source_conn_id="airflow_db",
+ destination_conn_id="airflow_db",
+ destination_table="test_mysql_to_mysql",
+ sql=sql,
+ )
+
+ dag_maker.run_ti(op.task_id)
+
+ @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
+ def test_mysql_to_mysql_replace(self, mock_insert, dag_maker):
+ sql = "SELECT * FROM connection LIMIT 10;"
+ with dag_maker("TEST_DAG_ID", start_date=DEFAULT_DATE):
op = GenericTransfer(
task_id="test_m2m",
preoperator=[
@@ -102,27 +116,10 @@ class TestMySql:
destination_conn_id="airflow_db",
destination_table="test_mysql_to_mysql",
sql=sql,
- dag=self.dag,
+ insert_args={"replace": True},
)
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
- def test_mysql_to_mysql_replace(self, mock_insert):
- sql = "SELECT * FROM connection LIMIT 10;"
- op = GenericTransfer(
- task_id="test_m2m",
- preoperator=[
- "DROP TABLE IF EXISTS test_mysql_to_mysql",
- "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE
connection",
- ],
- source_conn_id="airflow_db",
- destination_conn_id="airflow_db",
- destination_table="test_mysql_to_mysql",
- sql=sql,
- dag=self.dag,
- insert_args={"replace": True},
- )
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ dag_maker.run_ti(op.task_id)
assert mock_insert.called
_, kwargs = mock_insert.call_args
assert "replace" in kwargs
@@ -140,7 +137,7 @@ class TestPostgres:
def test_postgres_to_postgres(self, dag_maker):
sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;"
with dag_maker(default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True):
- op = GenericTransfer(
+ _ = GenericTransfer(
task_id="test_p2p",
preoperator=[
"DROP TABLE IF EXISTS test_postgres_to_postgres",
@@ -151,14 +148,14 @@ class TestPostgres:
destination_table="test_postgres_to_postgres",
sql=sql,
)
- dag_maker.create_dagrun()
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ dr = dag_maker.create_dagrun()
+ dag_maker.run_ti("test_p2p", dr)
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
sql = "SELECT id, conn_id, conn_type FROM connection LIMIT 10;"
with dag_maker(default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True):
- op = GenericTransfer(
+ _ = GenericTransfer(
task_id="test_p2p",
preoperator=[
"DROP TABLE IF EXISTS test_postgres_to_postgres",
@@ -174,8 +171,8 @@ class TestPostgres:
"replace_index": "id",
},
)
- dag_maker.create_dagrun()
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ dr = dag_maker.create_dagrun()
+ dag_maker.run_ti("test_p2p", dr)
assert mock_insert.called
_, kwargs = mock_insert.call_args
assert "replace" in kwargs
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
index c5760376304..b3e02f8d7f8 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
@@ -1095,6 +1095,30 @@ class TestSqlBranch:
self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
self.branch_3 = None
+ def get_ti(self, task_id, dr=None):
+ if dr is None:
+ if AIRFLOW_V_3_0_PLUS:
+ dagrun_kwargs = {
+ "logical_date": DEFAULT_DATE,
+ "run_after": DEFAULT_DATE,
+ "triggered_by": DagRunTriggeredByType.TEST,
+ }
+ else:
+ dagrun_kwargs = {"execution_date": DEFAULT_DATE}
+ dr = self.dag.create_dagrun(
+ run_id=f"manual__{timezone.utcnow().isoformat()}",
+ run_type=DagRunType.MANUAL,
+ start_date=timezone.utcnow(),
+ state=State.RUNNING,
+ data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ **dagrun_kwargs,
+ )
+
+ ti = dr.get_task_instance(task_id)
+ ti.task = self.dag.get_task(ti.task_id)
+
+ return ti
+
def teardown_method(self):
with create_session() as session:
session.query(DagRun).delete()
@@ -1124,7 +1148,7 @@ class TestSqlBranch:
)
with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.execute({})
def test_invalid_conn(self):
"""Check if BranchSQLOperator throws an exception for invalid
connection"""
@@ -1138,7 +1162,7 @@ class TestSqlBranch:
)
with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.execute({})
def test_invalid_follow_task_true(self):
"""Check if BranchSQLOperator throws an exception for invalid
connection"""
@@ -1152,7 +1176,7 @@ class TestSqlBranch:
)
with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.execute({})
def test_invalid_follow_task_false(self):
"""Check if BranchSQLOperator throws an exception for invalid
connection"""
@@ -1166,12 +1190,13 @@ class TestSqlBranch:
)
with pytest.raises(AirflowException):
- op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ op.execute({})
@pytest.mark.backend("mysql")
def test_sql_branch_operator_mysql(self, branch_op):
"""Check if BranchSQLOperator works with backend"""
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ branch_op.execute({"ti": mock.MagicMock(task=branch_op)})
@pytest.mark.backend("postgres")
def test_sql_branch_operator_postgres(self):
@@ -1184,7 +1209,7 @@ class TestSqlBranch:
follow_task_ids_if_false=["branch_2"],
dag=self.dag,
)
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ self.get_ti(branch_op.task_id).run()
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
def test_branch_single_value_with_dag_run(self, mock_get_db_hook, branch_op):
@@ -1223,8 +1248,9 @@ class TestSqlBranch:
assert exc_info.value.tasks == [("branch_2", -1)]
else:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.get_ti(branch_op.task_id, dr).run()
tis = dr.get_task_instances()
+
for ti in tis:
if ti.task_id == "make_choice":
assert ti.state == State.SUCCESS
@@ -1267,11 +1293,11 @@ class TestSqlBranch:
from airflow.exceptions import DownstreamTasksSkipped
with pytest.raises(DownstreamTasksSkipped) as exc_info:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ branch_op.execute({})
assert exc_info.value.tasks == [("branch_2", -1)]
else:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.get_ti(branch_op.task_id, dr).run()
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
@@ -1315,11 +1341,12 @@ class TestSqlBranch:
from airflow.exceptions import DownstreamTasksSkipped
with pytest.raises(DownstreamTasksSkipped) as exc_info:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ branch_op.execute({})
assert exc_info.value.tasks == [("branch_1", -1)]
else:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.get_ti(branch_op.task_id, dr).run()
tis = dr.get_task_instances()
+
for ti in tis:
if ti.task_id == "make_choice":
assert ti.state == State.SUCCESS
@@ -1375,7 +1402,7 @@ class TestSqlBranch:
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
assert exc_info.value.tasks == [("branch_3", -1)]
else:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.get_ti(branch_op.task_id, dr).run()
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
@@ -1416,7 +1443,7 @@ class TestSqlBranch:
mock_get_records.return_value = ["Invalid Value"]
with pytest.raises(AirflowException):
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ branch_op.execute({})
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook, branch_op):
@@ -1447,7 +1474,7 @@ class TestSqlBranch:
for true_value in SUPPORTED_TRUE_VALUES:
mock_get_records.return_value = [true_value]
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.get_ti(branch_op.task_id, dr).run()
tis = dr.get_task_instances()
for ti in tis:
@@ -1493,7 +1520,7 @@ class TestSqlBranch:
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
assert exc_info.value.tasks == [("branch_1", -1)]
else:
- branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ self.get_ti(branch_op.task_id, dr).run()
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
index b61b699774e..9f52f80a4eb 100644
--- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
@@ -58,23 +58,20 @@ SINGLE_STMT = "select i from user_test order by i;"
@pytest.mark.db_test
class TestSnowflakeOperator:
- def setup_method(self):
- args = {"owner": "airflow", "start_date": DEFAULT_DATE}
- dag = DAG(TEST_DAG_ID, schedule=None, default_args=args)
- self.dag = dag
-
@mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
- def test_snowflake_operator(self, mock_get_db_hook):
+ def test_snowflake_operator(self, mock_get_db_hook, dag_maker):
sql = """
CREATE TABLE IF NOT EXISTS test_airflow (
dummy VARCHAR(50)
);
"""
- operator = SQLExecuteQueryOperator(
- task_id="basic_snowflake", sql=sql, dag=self.dag,
do_xcom_push=False, conn_id="snowflake_default"
- )
+
+ with dag_maker(TEST_DAG_ID):
+ operator = SQLExecuteQueryOperator(
+ task_id="basic_snowflake", sql=sql, do_xcom_push=False,
conn_id="snowflake_default"
+ )
# do_xcom_push=False because otherwise the XCom test will fail due to the mocking (it actually works)
- operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+ dag_maker.run_ti(operator.task_id)
class TestSnowflakeOperatorForParams: