This is an automated email from the ASF dual-hosted git repository.

jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b100b5ef59c Add direct GCS export to DatabricksSqlOperator with Parquet/Avro support #55128 (#60543)
b100b5ef59c is described below

commit b100b5ef59c9344531f213f77d107d613ed4381d
Author: Haseeb Malik <[email protected]>
AuthorDate: Sat Jan 17 05:52:12 2026 -0500

    Add direct GCS export to DatabricksSqlOperator with Parquet/Avro support #55128 (#60543)
---
 dev/breeze/tests/test_selective_checks.py          |   8 +-
 providers/databricks/docs/index.rst                |   1 +
 providers/databricks/pyproject.toml                |   7 +
 .../databricks/operators/databricks_sql.py         | 193 +++++++++++++++++----
 .../databricks/operators/test_databricks_sql.py    | 139 +++++++++++++++
 5 files changed, 315 insertions(+), 33 deletions(-)
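
For orientation, here is a minimal usage sketch of the feature this change adds. It only
illustrates the new arguments visible in the diff below; the task id, SQL, bucket name and
connection ids are hypothetical, not taken from the commit.

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    # Export the result of the last statement straight to GCS as Parquet.
    export_orders = DatabricksSqlOperator(
        task_id="export_orders_to_gcs",                             # hypothetical task id
        databricks_conn_id="databricks_default",
        sql="SELECT * FROM sales.orders",                           # hypothetical query
        output_path="gs://example-bucket/exports/orders.parquet",   # hypothetical bucket/object
        output_format="parquet",                                    # also: csv, json, jsonl, avro
        gcp_conn_id="google_cloud_default",
        gcs_impersonation_chain=None,
    )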

diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 66a356dda4d..1271e73384f 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -1948,7 +1948,7 @@ def test_expected_output_push(
             ),
             {
                 "selected-providers-list-as-string": "amazon apache.beam 
apache.cassandra apache.kafka "
-                "cncf.kubernetes common.compat common.sql "
+                "cncf.kubernetes common.compat common.sql databricks "
                 "facebook google hashicorp http microsoft.azure 
microsoft.mssql mysql "
                 "openlineage oracle postgres presto salesforce samba sftp ssh 
standard trino",
                 "all-python-versions": 
f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
@@ -1960,7 +1960,7 @@ def test_expected_output_push(
                 "skip-providers-tests": "false",
                 "docs-build": "true",
                 "docs-list-as-string": "apache-airflow helm-chart amazon 
apache.beam apache.cassandra "
-                "apache.kafka cncf.kubernetes common.compat common.sql 
facebook google hashicorp http microsoft.azure "
+                "apache.kafka cncf.kubernetes common.compat common.sql 
databricks facebook google hashicorp http microsoft.azure "
                 "microsoft.mssql mysql openlineage oracle postgres "
                 "presto salesforce samba sftp ssh standard trino",
                 "skip-prek-hooks": ALL_SKIPPED_COMMITS_IF_NO_UI,
@@ -1974,7 +1974,7 @@ def test_expected_output_push(
                         {
                             "description": "amazon...standard",
                             "test_types": "Providers[amazon] 
Providers[apache.beam,apache.cassandra,"
-                            
"apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook,"
+                            
"apache.kafka,cncf.kubernetes,common.compat,common.sql,databricks,facebook,"
                             
"hashicorp,http,microsoft.azure,microsoft.mssql,mysql,"
                             
"openlineage,oracle,postgres,presto,salesforce,samba,sftp,ssh,trino] "
                             "Providers[google] "
@@ -2245,7 +2245,7 @@ def test_upgrade_to_newer_dependencies(
             ("providers/google/docs/some_file.rst",),
             {
                 "docs-list-as-string": "amazon apache.beam apache.cassandra 
apache.kafka "
-                "cncf.kubernetes common.compat common.sql facebook google 
hashicorp http "
+                "cncf.kubernetes common.compat common.sql databricks facebook 
google hashicorp http "
                 "microsoft.azure microsoft.mssql mysql openlineage oracle "
                 "postgres presto salesforce samba sftp ssh standard trino",
             },
diff --git a/providers/databricks/docs/index.rst b/providers/databricks/docs/index.rst
index 4df2fce6ad5..f366162012b 100644
--- a/providers/databricks/docs/index.rst
+++ b/providers/databricks/docs/index.rst
@@ -132,6 +132,7 @@ Dependent package
 ==================================================================================================================  =================
 `apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_  ``common.compat``
 `apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_        ``common.sql``
+`apache-airflow-providers-google <https://airflow.apache.org/docs/apache-airflow-providers-google>`_                ``google``
 `apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_      ``openlineage``
 ==================================================================================================================  =================
 
diff --git a/providers/databricks/pyproject.toml b/providers/databricks/pyproject.toml
index f9a777eb085..42837674e09 100644
--- a/providers/databricks/pyproject.toml
+++ b/providers/databricks/pyproject.toml
@@ -93,6 +93,12 @@ dependencies = [
 "sqlalchemy" = [
     "databricks-sqlalchemy>=1.0.2",
 ]
+"google" = [
+    "apache-airflow-providers-google>=10.24.0"
+]
+"avro" = [
+    "fastavro>=1.9.0"
+]
 
 [dependency-groups]
 dev = [
@@ -101,6 +107,7 @@ dev = [
     "apache-airflow-devel-common",
     "apache-airflow-providers-common-compat",
     "apache-airflow-providers-common-sql",
+    "apache-airflow-providers-google",
     "apache-airflow-providers-openlineage",
     # Additional devel dependencies (do not remove this line and add extra development dependencies)
     # Need to exclude 1.3.0 due to missing aarch64 binaries, fixed with 1.3.1++
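
Side note: with the "google" and "avro" extras declared above, the optional dependencies can
presumably be pulled in with standard extras syntax, e.g.
pip install 'apache-airflow-providers-databricks[google,avro]' (the extra names come from this
pyproject change; the install command itself is an assumption, not part of the commit).
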
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index 999b173540c..840fa7bc8b8 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -21,13 +21,20 @@ from __future__ import annotations
 
 import csv
 import json
+import os
 from collections.abc import Sequence
 from functools import cached_property
+from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, ClassVar
+from urllib.parse import urlparse
 
 from databricks.sql.utils import ParamEscaper
 
-from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
+from airflow.providers.common.compat.sdk import (
+    AirflowException,
+    AirflowOptionalProviderFeatureException,
+    BaseOperator,
+)
 from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
 
@@ -62,13 +69,27 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
     :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated)
     :param schema: An optional initial schema to use. Requires DBR version 9.0+ (templated)
     :param output_path: optional string specifying the file to which write selected data. (templated)
-    :param output_format: format of output data if ``output_path` is specified.
-        Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
+        Supports local file paths and GCS URIs (e.g., ``gs://bucket/path/file.parquet``).
+        When using GCS URIs, requires the ``apache-airflow-providers-google`` package.
+    :param output_format: format of output data if ``output_path`` is specified.
+        Possible values are ``csv``, ``json``, ``jsonl``, ``parquet``, ``avro``. Default is ``csv``.
     :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
+    :param gcp_conn_id: The connection ID to use for connecting to Google Cloud when using GCS output path.
+        Default is ``google_cloud_default``.
+    :param gcs_impersonation_chain: Optional service account to impersonate using short-term
+        credentials for GCS upload, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request. (templated)
     """
 
     template_fields: Sequence[str] = tuple(
-        {"_output_path", "schema", "catalog", "http_headers", 
"databricks_conn_id"}
+        {
+            "_output_path",
+            "schema",
+            "catalog",
+            "http_headers",
+            "databricks_conn_id",
+            "_gcs_impersonation_chain",
+        }
         | set(SQLExecuteQueryOperator.template_fields)
     )
 
@@ -90,6 +111,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
         output_format: str = "csv",
         csv_params: dict[str, Any] | None = None,
         client_parameters: dict[str, Any] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        gcs_impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(conn_id=databricks_conn_id, **kwargs)
@@ -105,6 +128,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
         self.http_headers = http_headers
         self.catalog = catalog
         self.schema = schema
+        self._gcp_conn_id = gcp_conn_id
+        self._gcs_impersonation_chain = gcs_impersonation_chain
 
     @cached_property
     def _hook(self) -> DatabricksSqlHook:
@@ -127,41 +152,151 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
     def _should_run_output_processing(self) -> bool:
         return self.do_xcom_push or bool(self._output_path)
 
+    @property
+    def _is_gcs_output(self) -> bool:
+        """Check if the output path is a GCS URI."""
+        return self._output_path.startswith("gs://") if self._output_path else False
+
+    def _parse_gcs_path(self, path: str) -> tuple[str, str]:
+        """Parse a GCS URI into bucket and object name."""
+        parsed = urlparse(path)
+        bucket = parsed.netloc
+        object_name = parsed.path.lstrip("/")
+        return bucket, object_name
+
+    def _upload_to_gcs(self, local_path: str, gcs_path: str) -> None:
+        """Upload a local file to GCS."""
+        try:
+            from airflow.providers.google.cloud.hooks.gcs import GCSHook
+        except ImportError:
+            raise AirflowOptionalProviderFeatureException(
+                "The 'apache-airflow-providers-google' package is required for 
GCS output. "
+                "Install it with: pip install apache-airflow-providers-google"
+            )
+
+        bucket, object_name = self._parse_gcs_path(gcs_path)
+        hook = GCSHook(
+            gcp_conn_id=self._gcp_conn_id,
+            impersonation_chain=self._gcs_impersonation_chain,
+        )
+        hook.upload(
+            bucket_name=bucket,
+            object_name=object_name,
+            filename=local_path,
+        )
+        self.log.info("Uploaded output to %s", gcs_path)
+
+    def _write_parquet(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
+        """Write data to a Parquet file."""
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        data: dict[str, list] = {name: [] for name in field_names}
+        for row in rows:
+            row_dict = row._asdict()
+            for name in field_names:
+                data[name].append(row_dict[name])
+
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, file_path)
+
+    def _write_avro(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
+        """Write data to an Avro file using fastavro."""
+        try:
+            from fastavro import writer
+        except ImportError:
+            raise AirflowOptionalProviderFeatureException(
+                "The 'fastavro' package is required for Avro output. Install 
it with: pip install fastavro"
+            )
+
+        data: dict[str, list] = {name: [] for name in field_names}
+        for row in rows:
+            row_dict = row._asdict()
+            for name in field_names:
+                data[name].append(row_dict[name])
+
+        schema_fields = []
+        for name in field_names:
+            sample_val = next(
+                (data[name][i] for i in range(len(data[name])) if data[name][i] is not None), None
+            )
+            if sample_val is None:
+                avro_type = ["null", "string"]
+            elif isinstance(sample_val, bool):
+                avro_type = ["null", "boolean"]
+            elif isinstance(sample_val, int):
+                avro_type = ["null", "long"]
+            elif isinstance(sample_val, float):
+                avro_type = ["null", "double"]
+            else:
+                avro_type = ["null", "string"]
+            schema_fields.append({"name": name, "type": avro_type})
+
+        avro_schema = {
+            "type": "record",
+            "name": "QueryResult",
+            "fields": schema_fields,
+        }
+
+        records = [row._asdict() for row in rows]
+        with open(file_path, "wb") as f:
+            writer(f, avro_schema, records)
+
     def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
         if not self._output_path:
             return list(zip(descriptions, results))
         if not self._output_format:
             raise AirflowException("Output format should be specified!")
-        # Output to a file only the result of last query
+
         last_description = descriptions[-1]
         last_results = results[-1]
         if last_description is None:
-            raise AirflowException("There is missing description present for 
the output file. .")
+            raise AirflowException("There is missing description present for 
the output file.")
         field_names = [field[0] for field in last_description]
-        if self._output_format.lower() == "csv":
-            with open(self._output_path, "w", newline="") as file:
-                if self._csv_params:
-                    csv_params = self._csv_params
-                else:
-                    csv_params = {}
-                write_header = csv_params.get("header", True)
-                if "header" in csv_params:
-                    del csv_params["header"]
-                writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
-                if write_header:
-                    writer.writeheader()
-                for row in last_results:
-                    writer.writerow(row._asdict())
-        elif self._output_format.lower() == "json":
-            with open(self._output_path, "w") as file:
-                file.write(json.dumps([row._asdict() for row in last_results]))
-        elif self._output_format.lower() == "jsonl":
-            with open(self._output_path, "w") as file:
-                for row in last_results:
-                    file.write(json.dumps(row._asdict()))
-                    file.write("\n")
+
+        if self._is_gcs_output:
+            suffix = f".{self._output_format.lower()}"
+            tmp_file = NamedTemporaryFile(mode="w", suffix=suffix, delete=False, newline="")
+            local_path = tmp_file.name
+            tmp_file.close()
         else:
-            raise AirflowException(f"Unsupported output format: 
'{self._output_format}'")
+            local_path = self._output_path
+
+        try:
+            output_format = self._output_format.lower()
+            if output_format == "csv":
+                with open(local_path, "w", newline="") as file:
+                    if self._csv_params:
+                        csv_params = self._csv_params.copy()
+                    else:
+                        csv_params = {}
+                    write_header = csv_params.pop("header", True)
+                    writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
+                    if write_header:
+                        writer.writeheader()
+                    for row in last_results:
+                        writer.writerow(row._asdict())
+            elif output_format == "json":
+                with open(local_path, "w") as file:
+                    file.write(json.dumps([row._asdict() for row in last_results]))
+            elif output_format == "jsonl":
+                with open(local_path, "w") as file:
+                    for row in last_results:
+                        file.write(json.dumps(row._asdict()))
+                        file.write("\n")
+            elif output_format == "parquet":
+                self._write_parquet(local_path, field_names, last_results)
+            elif output_format == "avro":
+                self._write_avro(local_path, field_names, last_results)
+            else:
+                raise ValueError(f"Unsupported output format: 
'{self._output_format}'")
+
+            if self._is_gcs_output:
+                self._upload_to_gcs(local_path, self._output_path)
+        finally:
+            if self._is_gcs_output and os.path.exists(local_path):
+                os.unlink(local_path)
+
         return list(zip(descriptions, results))
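
As a side note on the Avro branch above: _write_avro infers the schema per column from the
first non-null value and declares every field as a nullable union, falling back to string.
A standalone sketch of that rule (the sample rows below are illustrative, not from the commit):

    # Mirrors the type-inference rule used by _write_avro in the diff above (illustrative only).
    def infer_avro_type(sample):
        if sample is None:
            return ["null", "string"]
        if isinstance(sample, bool):   # checked before int, since bool is a subclass of int
            return ["null", "boolean"]
        if isinstance(sample, int):
            return ["null", "long"]
        if isinstance(sample, float):
            return ["null", "double"]
        return ["null", "string"]      # dates, decimals, etc. fall back to string

    rows = [{"id": 1, "amount": 9.5, "note": None}]   # hypothetical query result
    schema = {
        "type": "record",
        "name": "QueryResult",
        "fields": [{"name": k, "type": infer_avro_type(v)} for k, v in rows[0].items()],
    }
    # -> id: long, amount: double, note: nullable string (all-null columns become string)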
 
 
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
index 58cc0e59607..e216c56bea2 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
@@ -314,3 +314,142 @@ def test_hook_is_cached():
     hook = op.get_db_hook()
     hook2 = op.get_db_hook()
     assert hook is hook2
+
+
+def test_exec_write_parquet_file(tmp_path):
+    """Test writing output to Parquet format."""
+    with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
+        path = tmp_path / "testfile.parquet"
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="select * from dummy",
+            output_path=os.fspath(path),
+            output_format="parquet",
+            return_last=True,
+            do_xcom_push=True,
+            split_statements=False,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+        db_mock.descriptions = [[("id",), ("value",)]]
+
+        op.execute(None)
+
+        import pyarrow.parquet as pq
+
+        table = pq.read_table(path)
+        assert table.num_rows == 2
+        assert table.column_names == ["id", "value"]
+        assert table.column("id").to_pylist() == [1, 2]
+        assert table.column("value").to_pylist() == ["value1", "value2"]
+
+
+def test_exec_write_avro_file_with_fastavro(tmp_path):
+    """Test writing output to Avro format using fastavro."""
+    pytest.importorskip("fastavro")
+    with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
+        path = tmp_path / "testfile.avro"
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="select * from dummy",
+            output_path=os.fspath(path),
+            output_format="avro",
+            return_last=True,
+            do_xcom_push=True,
+            split_statements=False,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+        db_mock.descriptions = [[("id",), ("value",)]]
+
+        op.execute(None)
+
+        from fastavro import reader
+
+        with open(path, "rb") as f:
+            records = list(reader(f))
+        assert len(records) == 2
+        assert records[0] == {"id": 1, "value": "value1"}
+        assert records[1] == {"id": 2, "value": "value2"}
+
+
+def test_exec_write_gcs_output(tmp_path):
+    """Test writing output to GCS."""
+    with (
+        patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class,
+        patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as 
gcs_mock_class,
+    ):
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="select * from dummy",
+            output_path="gs://my-bucket/path/to/output.csv",
+            output_format="csv",
+            return_last=True,
+            do_xcom_push=True,
+            split_statements=False,
+            gcp_conn_id="my_gcp_conn",
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+        db_mock.descriptions = [[("id",), ("value",)]]
+
+        op.execute(None)
+
+        gcs_mock_class.assert_called_once_with(
+            gcp_conn_id="my_gcp_conn",
+            impersonation_chain=None,
+        )
+        gcs_mock_class.return_value.upload.assert_called_once()
+        call_kwargs = gcs_mock_class.return_value.upload.call_args[1]
+        assert call_kwargs["bucket_name"] == "my-bucket"
+        assert call_kwargs["object_name"] == "path/to/output.csv"
+
+
+def test_exec_write_gcs_parquet_output(tmp_path):
+    """Test writing Parquet output to GCS."""
+    with (
+        patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class,
+        patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as 
gcs_mock_class,
+    ):
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="select * from dummy",
+            output_path="gs://my-bucket/data/results.parquet",
+            output_format="parquet",
+            return_last=True,
+            do_xcom_push=True,
+            split_statements=False,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+        db_mock.descriptions = [[("id",), ("value",)]]
+
+        op.execute(None)
+
+        gcs_mock_class.return_value.upload.assert_called_once()
+        call_kwargs = gcs_mock_class.return_value.upload.call_args[1]
+        assert call_kwargs["bucket_name"] == "my-bucket"
+        assert call_kwargs["object_name"] == "data/results.parquet"
+
+
+def test_is_gcs_output():
+    """Test _is_gcs_output property."""
+    op_gcs = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", output_path="gs://bucket/path")
+    assert op_gcs._is_gcs_output is True
+
+    op_local = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", output_path="/local/path")
+    assert op_local._is_gcs_output is False
+
+    op_s3 = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", output_path="s3://bucket/path")
+    assert op_s3._is_gcs_output is False
+
+    op_none = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+    assert op_none._is_gcs_output is False
+
+
+def test_parse_gcs_path():
+    """Test _parse_gcs_path method."""
+    op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+    bucket, object_name = op._parse_gcs_path("gs://my-bucket/path/to/file.parquet")
+    assert bucket == "my-bucket"
+    assert object_name == "path/to/file.parquet"
