This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 67613f0c38b Replace models.BaseOperator to Task SDK one for Common Providers (#52443)
67613f0c38b is described below

commit 67613f0c38be5b98f7d12d88a6b751d6e27590a6
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Jun 30 03:06:20 2025 +0530

    Replace models.BaseOperator to Task SDK one for Common Providers (#52443)
    
    Part of https://github.com/apache/airflow/issues/52378
---
 .../providers/common/compat/standard/operators.py  |  3 +-
 .../providers/common/compat/version_compat.py      | 12 +++++
 .../providers/common/io/operators/file_transfer.py |  4 +-
 .../airflow/providers/common/io/version_compat.py  | 10 ++++
 .../common/sql/operators/generic_transfer.py       |  4 +-
 .../airflow/providers/common/sql/operators/sql.py  |  3 +-
 .../airflow/providers/common/sql/version_compat.py | 12 +++++
 .../common/sql/operators/test_generic_transfer.py  | 59 ++++++++++------------
 .../tests/unit/common/sql/operators/test_sql.py    | 57 +++++++++++++++------
 .../unit/snowflake/operators/test_snowflake.py     | 17 +++----
 10 files changed, 118 insertions(+), 63 deletions(-)

diff --git a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
index c8838fe3c21..b3d35f1aa14 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
@@ -46,5 +46,6 @@ else:
     except (ImportError, ModuleNotFoundError):
         from airflow.providers.standard.operators.python import 
get_current_context
 
+from airflow.providers.common.compat.version_compat import BaseOperator
 
-__all__ = ["PythonOperator", "_SERIALIZERS", "ShortCircuitOperator", 
"get_current_context"]
+__all__ = ["BaseOperator", "PythonOperator", "_SERIALIZERS", 
"ShortCircuitOperator", "get_current_context"]
diff --git a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
index 48d122b6696..02d0f1ac162 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/version_compat.py
@@ -33,3 +33,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperator
+else:
+    from airflow.models import BaseOperator
+
+__all__ = [
+    "AIRFLOW_V_3_0_PLUS",
+    "AIRFLOW_V_3_1_PLUS",
+    "BaseOperator",
+]
diff --git a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
index 0faec858d1c..00b23bde989 100644
--- a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
+++ b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS, 
BaseOperator
 
 if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors import OperatorLineage
@@ -28,10 +28,8 @@ if TYPE_CHECKING:
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk import ObjectStoragePath
-    from airflow.sdk.bases.operator import BaseOperator
 else:
     from airflow.io.path import ObjectStoragePath  # type: ignore[no-redef]
-    from airflow.models import BaseOperator  # type: ignore[no-redef]
 
 
 class FileTransferOperator(BaseOperator):
diff --git a/providers/common/io/src/airflow/providers/common/io/version_compat.py b/providers/common/io/src/airflow/providers/common/io/version_compat.py
index 48d122b6696..e7a259afb35 100644
--- a/providers/common/io/src/airflow/providers/common/io/version_compat.py
+++ b/providers/common/io/src/airflow/providers/common/io/version_compat.py
@@ -33,3 +33,13 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperator
+else:
+    from airflow.models import BaseOperator
+
+__all__ = [
+    "AIRFLOW_V_3_0_PLUS",
+    "BaseOperator",
+]
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
index c7839b28aec..4e3b0a87ce0 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
@@ -23,9 +23,9 @@ from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger
+from airflow.providers.common.sql.version_compat import BaseOperator
 
 if TYPE_CHECKING:
     import jinja2
@@ -192,7 +192,7 @@ class GenericTransfer(BaseOperator):
                 )
 
                 self.log.info("Offset increased to %d", offset)
-                self.xcom_push(context=context, key="offset", value=offset)
+                context["ti"].xcom_push(key="offset", value=offset)
 
                 self.log.info("Inserting %d rows into %s", len(results), 
self.destination_conn_id)
                 self.destination_hook.insert_rows(
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index af04349532a..250a249d5af 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -25,9 +25,10 @@ from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, SupportsAbs
 
 from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator, SkipMixin
+from airflow.models import SkipMixin
 from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, 
return_single_query_results
 from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.version_compat import BaseOperator
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
diff --git a/providers/common/sql/src/airflow/providers/common/sql/version_compat.py b/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
index 48d122b6696..b326387fea2 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/version_compat.py
@@ -33,3 +33,15 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseOperator, BaseSensorOperator
+else:
+    from airflow.models import BaseOperator
+    from airflow.sensors.base import BaseSensorOperator  # type: 
ignore[no-redef]
+
+__all__ = [
+    "AIRFLOW_V_3_0_PLUS",
+    "BaseOperator",
+    "BaseSensorOperator",
+]
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
index 92b118d703a..fe01d68d2f1 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
@@ -55,11 +55,6 @@ counter = 0
 
 @pytest.mark.backend("mysql")
 class TestMySql:
-    def setup_method(self):
-        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
-        dag = DAG(TEST_DAG_ID, schedule=None, default_args=args)
-        self.dag = dag
-
     def teardown_method(self):
         from airflow.providers.mysql.hooks.mysql import MySqlHook
 
@@ -77,7 +72,7 @@ class TestMySql:
             "mysql-connector-python",
         ],
     )
-    def test_mysql_to_mysql(self, client):
+    def test_mysql_to_mysql(self, client, dag_maker):
         class MySqlContext:
             def __init__(self, client):
                 self.client = client
@@ -92,6 +87,25 @@ class TestMySql:
 
         with MySqlContext(client):
             sql = "SELECT * FROM connection;"
+            with dag_maker(f"TEST_DAG_ID_{client}", start_date=DEFAULT_DATE):
+                op = GenericTransfer(
+                    task_id="test_m2m",
+                    preoperator=[
+                        "DROP TABLE IF EXISTS test_mysql_to_mysql",
+                        "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE 
connection",
+                    ],
+                    source_conn_id="airflow_db",
+                    destination_conn_id="airflow_db",
+                    destination_table="test_mysql_to_mysql",
+                    sql=sql,
+                )
+
+            dag_maker.run_ti(op.task_id)
+
+    @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
+    def test_mysql_to_mysql_replace(self, mock_insert, dag_maker):
+        sql = "SELECT * FROM connection LIMIT 10;"
+        with dag_maker("TEST_DAG_ID", start_date=DEFAULT_DATE):
             op = GenericTransfer(
                 task_id="test_m2m",
                 preoperator=[
@@ -102,27 +116,10 @@ class TestMySql:
                 destination_conn_id="airflow_db",
                 destination_table="test_mysql_to_mysql",
                 sql=sql,
-                dag=self.dag,
+                insert_args={"replace": True},
             )
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
 
-    @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
-    def test_mysql_to_mysql_replace(self, mock_insert):
-        sql = "SELECT * FROM connection LIMIT 10;"
-        op = GenericTransfer(
-            task_id="test_m2m",
-            preoperator=[
-                "DROP TABLE IF EXISTS test_mysql_to_mysql",
-                "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE 
connection",
-            ],
-            source_conn_id="airflow_db",
-            destination_conn_id="airflow_db",
-            destination_table="test_mysql_to_mysql",
-            sql=sql,
-            dag=self.dag,
-            insert_args={"replace": True},
-        )
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+        dag_maker.run_ti(op.task_id)
         assert mock_insert.called
         _, kwargs = mock_insert.call_args
         assert "replace" in kwargs
@@ -140,7 +137,7 @@ class TestPostgres:
     def test_postgres_to_postgres(self, dag_maker):
         sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES LIMIT 100;"
         with dag_maker(default_args={"owner": "airflow", "start_date": 
DEFAULT_DATE}, serialized=True):
-            op = GenericTransfer(
+            _ = GenericTransfer(
                 task_id="test_p2p",
                 preoperator=[
                     "DROP TABLE IF EXISTS test_postgres_to_postgres",
@@ -151,14 +148,14 @@ class TestPostgres:
                 destination_table="test_postgres_to_postgres",
                 sql=sql,
             )
-        dag_maker.create_dagrun()
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+        dr = dag_maker.create_dagrun()
+        dag_maker.run_ti("test_p2p", dr)
 
     @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows")
     def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
         sql = "SELECT id, conn_id, conn_type FROM connection LIMIT 10;"
         with dag_maker(default_args={"owner": "airflow", "start_date": 
DEFAULT_DATE}, serialized=True):
-            op = GenericTransfer(
+            _ = GenericTransfer(
                 task_id="test_p2p",
                 preoperator=[
                     "DROP TABLE IF EXISTS test_postgres_to_postgres",
@@ -174,8 +171,8 @@ class TestPostgres:
                     "replace_index": "id",
                 },
             )
-        dag_maker.create_dagrun()
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+        dr = dag_maker.create_dagrun()
+        dag_maker.run_ti("test_p2p", dr)
         assert mock_insert.called
         _, kwargs = mock_insert.call_args
         assert "replace" in kwargs
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
index c5760376304..b3e02f8d7f8 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
@@ -1095,6 +1095,30 @@ class TestSqlBranch:
         self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
         self.branch_3 = None
 
+    def get_ti(self, task_id, dr=None):
+        if dr is None:
+            if AIRFLOW_V_3_0_PLUS:
+                dagrun_kwargs = {
+                    "logical_date": DEFAULT_DATE,
+                    "run_after": DEFAULT_DATE,
+                    "triggered_by": DagRunTriggeredByType.TEST,
+                }
+            else:
+                dagrun_kwargs = {"execution_date": DEFAULT_DATE}
+            dr = self.dag.create_dagrun(
+                run_id=f"manual__{timezone.utcnow().isoformat()}",
+                run_type=DagRunType.MANUAL,
+                start_date=timezone.utcnow(),
+                state=State.RUNNING,
+                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+                **dagrun_kwargs,
+            )
+
+        ti = dr.get_task_instance(task_id)
+        ti.task = self.dag.get_task(ti.task_id)
+
+        return ti
+
     def teardown_method(self):
         with create_session() as session:
             session.query(DagRun).delete()
@@ -1124,7 +1148,7 @@ class TestSqlBranch:
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+            op.execute({})
 
     def test_invalid_conn(self):
         """Check if BranchSQLOperator throws an exception for invalid 
connection"""
@@ -1138,7 +1162,7 @@ class TestSqlBranch:
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+            op.execute({})
 
     def test_invalid_follow_task_true(self):
         """Check if BranchSQLOperator throws an exception for invalid 
connection"""
@@ -1152,7 +1176,7 @@ class TestSqlBranch:
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+            op.execute({})
 
     def test_invalid_follow_task_false(self):
         """Check if BranchSQLOperator throws an exception for invalid 
connection"""
@@ -1166,12 +1190,13 @@ class TestSqlBranch:
         )
 
         with pytest.raises(AirflowException):
-            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+            op.execute({})
 
     @pytest.mark.backend("mysql")
     def test_sql_branch_operator_mysql(self, branch_op):
         """Check if BranchSQLOperator works with backend"""
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+
+        branch_op.execute({"ti": mock.MagicMock(task=branch_op)})
 
     @pytest.mark.backend("postgres")
     def test_sql_branch_operator_postgres(self):
@@ -1184,7 +1209,7 @@ class TestSqlBranch:
             follow_task_ids_if_false=["branch_2"],
             dag=self.dag,
         )
-        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+        self.get_ti(branch_op.task_id).run()
 
     
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
     def test_branch_single_value_with_dag_run(self, mock_get_db_hook, 
branch_op):
@@ -1223,8 +1248,9 @@ class TestSqlBranch:
 
             assert exc_info.value.tasks == [("branch_2", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
+
             for ti in tis:
                 if ti.task_id == "make_choice":
                     assert ti.state == State.SUCCESS
@@ -1267,11 +1293,11 @@ class TestSqlBranch:
             from airflow.exceptions import DownstreamTasksSkipped
 
             with pytest.raises(DownstreamTasksSkipped) as exc_info:
-                branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+                branch_op.execute({})
 
             assert exc_info.value.tasks == [("branch_2", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
             for ti in tis:
                 if ti.task_id == "make_choice":
@@ -1315,11 +1341,12 @@ class TestSqlBranch:
             from airflow.exceptions import DownstreamTasksSkipped
 
             with pytest.raises(DownstreamTasksSkipped) as exc_info:
-                branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+                branch_op.execute({})
             assert exc_info.value.tasks == [("branch_1", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
+
             for ti in tis:
                 if ti.task_id == "make_choice":
                     assert ti.state == State.SUCCESS
@@ -1375,7 +1402,7 @@ class TestSqlBranch:
                 branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
             assert exc_info.value.tasks == [("branch_3", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
             for ti in tis:
                 if ti.task_id == "make_choice":
@@ -1416,7 +1443,7 @@ class TestSqlBranch:
         mock_get_records.return_value = ["Invalid Value"]
 
         with pytest.raises(AirflowException):
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            branch_op.execute({})
 
     
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
     def test_with_skip_in_branch_downstream_dependencies(self, 
mock_get_db_hook, branch_op):
@@ -1447,7 +1474,7 @@ class TestSqlBranch:
         for true_value in SUPPORTED_TRUE_VALUES:
             mock_get_records.return_value = [true_value]
 
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
 
             tis = dr.get_task_instances()
             for ti in tis:
@@ -1493,7 +1520,7 @@ class TestSqlBranch:
                 branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
             assert exc_info.value.tasks == [("branch_1", -1)]
         else:
-            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            self.get_ti(branch_op.task_id, dr).run()
             tis = dr.get_task_instances()
             for ti in tis:
                 if ti.task_id == "make_choice":
diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
index b61b699774e..9f52f80a4eb 100644
--- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
@@ -58,23 +58,20 @@ SINGLE_STMT = "select i from user_test order by i;"
 
 @pytest.mark.db_test
 class TestSnowflakeOperator:
-    def setup_method(self):
-        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
-        dag = DAG(TEST_DAG_ID, schedule=None, default_args=args)
-        self.dag = dag
-
     
@mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
-    def test_snowflake_operator(self, mock_get_db_hook):
+    def test_snowflake_operator(self, mock_get_db_hook, dag_maker):
         sql = """
         CREATE TABLE IF NOT EXISTS test_airflow (
             dummy VARCHAR(50)
         );
         """
-        operator = SQLExecuteQueryOperator(
-            task_id="basic_snowflake", sql=sql, dag=self.dag, 
do_xcom_push=False, conn_id="snowflake_default"
-        )
+
+        with dag_maker(TEST_DAG_ID):
+            operator = SQLExecuteQueryOperator(
+                task_id="basic_snowflake", sql=sql, do_xcom_push=False, 
conn_id="snowflake_default"
+            )
         # do_xcom_push=False because otherwise the XCom test will fail due to 
the mocking (it actually works)
-        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+        dag_maker.run_ti(operator.task_id)
 
 
 class TestSnowflakeOperatorForParams:

Reply via email to