This is an automated email from the ASF dual-hosted git repository.
jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b100b5ef59c Add direct GCS export to DatabricksSqlOperator with Parquet/Avro support #55128 (#60543)
b100b5ef59c is described below
commit b100b5ef59c9344531f213f77d107d613ed4381d
Author: Haseeb Malik <[email protected]>
AuthorDate: Sat Jan 17 05:52:12 2026 -0500
Add direct GCS export to DatabricksSqlOperator with Parquet/Avro support #55128 (#60543)
---
dev/breeze/tests/test_selective_checks.py | 8 +-
providers/databricks/docs/index.rst | 1 +
providers/databricks/pyproject.toml | 7 +
.../databricks/operators/databricks_sql.py | 193 +++++++++++++++++----
.../databricks/operators/test_databricks_sql.py | 139 +++++++++++++++
5 files changed, 315 insertions(+), 33 deletions(-)
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 66a356dda4d..1271e73384f 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -1948,7 +1948,7 @@ def test_expected_output_push(
),
{
"selected-providers-list-as-string": "amazon apache.beam
apache.cassandra apache.kafka "
- "cncf.kubernetes common.compat common.sql "
+ "cncf.kubernetes common.compat common.sql databricks "
"facebook google hashicorp http microsoft.azure
microsoft.mssql mysql "
"openlineage oracle postgres presto salesforce samba sftp ssh
standard trino",
"all-python-versions":
f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
@@ -1960,7 +1960,7 @@ def test_expected_output_push(
"skip-providers-tests": "false",
"docs-build": "true",
"docs-list-as-string": "apache-airflow helm-chart amazon
apache.beam apache.cassandra "
- "apache.kafka cncf.kubernetes common.compat common.sql
facebook google hashicorp http microsoft.azure "
+ "apache.kafka cncf.kubernetes common.compat common.sql
databricks facebook google hashicorp http microsoft.azure "
"microsoft.mssql mysql openlineage oracle postgres "
"presto salesforce samba sftp ssh standard trino",
"skip-prek-hooks": ALL_SKIPPED_COMMITS_IF_NO_UI,
@@ -1974,7 +1974,7 @@ def test_expected_output_push(
{
"description": "amazon...standard",
"test_types": "Providers[amazon]
Providers[apache.beam,apache.cassandra,"
-
"apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook,"
+
"apache.kafka,cncf.kubernetes,common.compat,common.sql,databricks,facebook,"
"hashicorp,http,microsoft.azure,microsoft.mssql,mysql,"
"openlineage,oracle,postgres,presto,salesforce,samba,sftp,ssh,trino] "
"Providers[google] "
@@ -2245,7 +2245,7 @@ def test_upgrade_to_newer_dependencies(
("providers/google/docs/some_file.rst",),
{
"docs-list-as-string": "amazon apache.beam apache.cassandra
apache.kafka "
- "cncf.kubernetes common.compat common.sql facebook google
hashicorp http "
+ "cncf.kubernetes common.compat common.sql databricks facebook
google hashicorp http "
"microsoft.azure microsoft.mssql mysql openlineage oracle "
"postgres presto salesforce samba sftp ssh standard trino",
},
diff --git a/providers/databricks/docs/index.rst b/providers/databricks/docs/index.rst
index 4df2fce6ad5..f366162012b 100644
--- a/providers/databricks/docs/index.rst
+++ b/providers/databricks/docs/index.rst
@@ -132,6 +132,7 @@ Dependent package
================================================================================================================== =================
`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
+`apache-airflow-providers-google <https://airflow.apache.org/docs/apache-airflow-providers-google>`_ ``google``
`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_ ``openlineage``
================================================================================================================== =================
diff --git a/providers/databricks/pyproject.toml b/providers/databricks/pyproject.toml
index f9a777eb085..42837674e09 100644
--- a/providers/databricks/pyproject.toml
+++ b/providers/databricks/pyproject.toml
@@ -93,6 +93,12 @@ dependencies = [
"sqlalchemy" = [
"databricks-sqlalchemy>=1.0.2",
]
+"google" = [
+ "apache-airflow-providers-google>=10.24.0"
+]
+"avro" = [
+ "fastavro>=1.9.0"
+]
[dependency-groups]
dev = [
@@ -101,6 +107,7 @@ dev = [
"apache-airflow-devel-common",
"apache-airflow-providers-common-compat",
"apache-airflow-providers-common-sql",
+ "apache-airflow-providers-google",
"apache-airflow-providers-openlineage",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
# Need to exclude 1.3.0 due to missing aarch64 binaries, fixed with 1.3.1++
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index 999b173540c..840fa7bc8b8 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -21,13 +21,20 @@ from __future__ import annotations
import csv
import json
+import os
from collections.abc import Sequence
from functools import cached_property
+from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, ClassVar
+from urllib.parse import urlparse
from databricks.sql.utils import ParamEscaper
-from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
+from airflow.providers.common.compat.sdk import (
+ AirflowException,
+ AirflowOptionalProviderFeatureException,
+ BaseOperator,
+)
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
@@ -62,13 +69,27 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated)
:param schema: An optional initial schema to use. Requires DBR version 9.0+ (templated)
:param output_path: optional string specifying the file to which write selected data. (templated)
- :param output_format: format of output data if ``output_path` is specified.
- Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
+ Supports local file paths and GCS URIs (e.g., ``gs://bucket/path/file.parquet``).
+ When using GCS URIs, requires the ``apache-airflow-providers-google`` package.
+ :param output_format: format of output data if ``output_path`` is specified.
+ Possible values are ``csv``, ``json``, ``jsonl``, ``parquet``, ``avro``. Default is ``csv``.
:param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
+ :param gcp_conn_id: The connection ID to use for connecting to Google Cloud when using GCS output path.
+ Default is ``google_cloud_default``.
+ :param gcs_impersonation_chain: Optional service account to impersonate using short-term
+ credentials for GCS upload, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request. (templated)
"""
template_fields: Sequence[str] = tuple(
- {"_output_path", "schema", "catalog", "http_headers",
"databricks_conn_id"}
+ {
+ "_output_path",
+ "schema",
+ "catalog",
+ "http_headers",
+ "databricks_conn_id",
+ "_gcs_impersonation_chain",
+ }
| set(SQLExecuteQueryOperator.template_fields)
)
@@ -90,6 +111,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
output_format: str = "csv",
csv_params: dict[str, Any] | None = None,
client_parameters: dict[str, Any] | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ gcs_impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(conn_id=databricks_conn_id, **kwargs)
@@ -105,6 +128,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
+ self._gcp_conn_id = gcp_conn_id
+ self._gcs_impersonation_chain = gcs_impersonation_chain
@cached_property
def _hook(self) -> DatabricksSqlHook:
@@ -127,41 +152,151 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
def _should_run_output_processing(self) -> bool:
return self.do_xcom_push or bool(self._output_path)
+ @property
+ def _is_gcs_output(self) -> bool:
+ """Check if the output path is a GCS URI."""
+ return self._output_path.startswith("gs://") if self._output_path else False
+
+ def _parse_gcs_path(self, path: str) -> tuple[str, str]:
+ """Parse a GCS URI into bucket and object name."""
+ parsed = urlparse(path)
+ bucket = parsed.netloc
+ object_name = parsed.path.lstrip("/")
+ return bucket, object_name
+
+ def _upload_to_gcs(self, local_path: str, gcs_path: str) -> None:
+ """Upload a local file to GCS."""
+ try:
+ from airflow.providers.google.cloud.hooks.gcs import GCSHook
+ except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "The 'apache-airflow-providers-google' package is required for
GCS output. "
+ "Install it with: pip install apache-airflow-providers-google"
+ )
+
+ bucket, object_name = self._parse_gcs_path(gcs_path)
+ hook = GCSHook(
+ gcp_conn_id=self._gcp_conn_id,
+ impersonation_chain=self._gcs_impersonation_chain,
+ )
+ hook.upload(
+ bucket_name=bucket,
+ object_name=object_name,
+ filename=local_path,
+ )
+ self.log.info("Uploaded output to %s", gcs_path)
+
+ def _write_parquet(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
+ """Write data to a Parquet file."""
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+
+ data: dict[str, list] = {name: [] for name in field_names}
+ for row in rows:
+ row_dict = row._asdict()
+ for name in field_names:
+ data[name].append(row_dict[name])
+
+ table = pa.Table.from_pydict(data)
+ pq.write_table(table, file_path)
+
+ def _write_avro(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
+ """Write data to an Avro file using fastavro."""
+ try:
+ from fastavro import writer
+ except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "The 'fastavro' package is required for Avro output. Install
it with: pip install fastavro"
+ )
+
+ data: dict[str, list] = {name: [] for name in field_names}
+ for row in rows:
+ row_dict = row._asdict()
+ for name in field_names:
+ data[name].append(row_dict[name])
+
+ schema_fields = []
+ for name in field_names:
+ sample_val = next(
+ (data[name][i] for i in range(len(data[name])) if data[name][i] is not None), None
+ )
+ if sample_val is None:
+ avro_type = ["null", "string"]
+ elif isinstance(sample_val, bool):
+ avro_type = ["null", "boolean"]
+ elif isinstance(sample_val, int):
+ avro_type = ["null", "long"]
+ elif isinstance(sample_val, float):
+ avro_type = ["null", "double"]
+ else:
+ avro_type = ["null", "string"]
+ schema_fields.append({"name": name, "type": avro_type})
+
+ avro_schema = {
+ "type": "record",
+ "name": "QueryResult",
+ "fields": schema_fields,
+ }
+
+ records = [row._asdict() for row in rows]
+ with open(file_path, "wb") as f:
+ writer(f, avro_schema, records)
+
def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
if not self._output_path:
return list(zip(descriptions, results))
if not self._output_format:
raise AirflowException("Output format should be specified!")
- # Output to a file only the result of last query
+
last_description = descriptions[-1]
last_results = results[-1]
if last_description is None:
- raise AirflowException("There is missing description present for
the output file. .")
+ raise AirflowException("There is missing description present for
the output file.")
field_names = [field[0] for field in last_description]
- if self._output_format.lower() == "csv":
- with open(self._output_path, "w", newline="") as file:
- if self._csv_params:
- csv_params = self._csv_params
- else:
- csv_params = {}
- write_header = csv_params.get("header", True)
- if "header" in csv_params:
- del csv_params["header"]
- writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
- if write_header:
- writer.writeheader()
- for row in last_results:
- writer.writerow(row._asdict())
- elif self._output_format.lower() == "json":
- with open(self._output_path, "w") as file:
- file.write(json.dumps([row._asdict() for row in last_results]))
- elif self._output_format.lower() == "jsonl":
- with open(self._output_path, "w") as file:
- for row in last_results:
- file.write(json.dumps(row._asdict()))
- file.write("\n")
+
+ if self._is_gcs_output:
+ suffix = f".{self._output_format.lower()}"
+ tmp_file = NamedTemporaryFile(mode="w", suffix=suffix, delete=False, newline="")
+ local_path = tmp_file.name
+ tmp_file.close()
else:
- raise AirflowException(f"Unsupported output format:
'{self._output_format}'")
+ local_path = self._output_path
+
+ try:
+ output_format = self._output_format.lower()
+ if output_format == "csv":
+ with open(local_path, "w", newline="") as file:
+ if self._csv_params:
+ csv_params = self._csv_params.copy()
+ else:
+ csv_params = {}
+ write_header = csv_params.pop("header", True)
+ writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
+ if write_header:
+ writer.writeheader()
+ for row in last_results:
+ writer.writerow(row._asdict())
+ elif output_format == "json":
+ with open(local_path, "w") as file:
+ file.write(json.dumps([row._asdict() for row in last_results]))
+ elif output_format == "jsonl":
+ with open(local_path, "w") as file:
+ for row in last_results:
+ file.write(json.dumps(row._asdict()))
+ file.write("\n")
+ elif output_format == "parquet":
+ self._write_parquet(local_path, field_names, last_results)
+ elif output_format == "avro":
+ self._write_avro(local_path, field_names, last_results)
+ else:
+ raise ValueError(f"Unsupported output format:
'{self._output_format}'")
+
+ if self._is_gcs_output:
+ self._upload_to_gcs(local_path, self._output_path)
+ finally:
+ if self._is_gcs_output and os.path.exists(local_path):
+ os.unlink(local_path)
+
return list(zip(descriptions, results))
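
For readers skimming the operator changes above, here is a minimal usage sketch of the new output options; the task id, SQL, connection ids, and bucket/object names are illustrative assumptions, not taken from this commit:

    # Hedged sketch: export the last query result straight to GCS as Parquet.
    # Connection ids and the gs:// path are hypothetical placeholders.
    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    export_to_gcs = DatabricksSqlOperator(
        task_id="export_to_gcs",                                # hypothetical task id
        databricks_conn_id="databricks_default",                # assumed Databricks connection
        sql="SELECT id, value FROM my_schema.my_table",         # hypothetical query
        output_path="gs://my-bucket/exports/my_table.parquet",  # hypothetical bucket/object
        output_format="parquet",                                # also accepts csv, json, jsonl, avro
        gcp_conn_id="google_cloud_default",                     # used only for gs:// output paths
    )

With a local output_path the file is written in place; with a gs:// path the operator writes to a temporary file and uploads it via GCSHook, as the _process_output changes above show.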
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
index 58cc0e59607..e216c56bea2 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
@@ -314,3 +314,142 @@ def test_hook_is_cached():
hook = op.get_db_hook()
hook2 = op.get_db_hook()
assert hook is hook2
+
+
+def test_exec_write_parquet_file(tmp_path):
+ """Test writing output to Parquet format."""
+ with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
+ path = tmp_path / "testfile.parquet"
+ op = DatabricksSqlOperator(
+ task_id=TASK_ID,
+ sql="select * from dummy",
+ output_path=os.fspath(path),
+ output_format="parquet",
+ return_last=True,
+ do_xcom_push=True,
+ split_statements=False,
+ )
+ db_mock = db_mock_class.return_value
+ db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+ db_mock.descriptions = [[("id",), ("value",)]]
+
+ op.execute(None)
+
+ import pyarrow.parquet as pq
+
+ table = pq.read_table(path)
+ assert table.num_rows == 2
+ assert table.column_names == ["id", "value"]
+ assert table.column("id").to_pylist() == [1, 2]
+ assert table.column("value").to_pylist() == ["value1", "value2"]
+
+
+def test_exec_write_avro_file_with_fastavro(tmp_path):
+ """Test writing output to Avro format using fastavro."""
+ pytest.importorskip("fastavro")
+ with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
+ path = tmp_path / "testfile.avro"
+ op = DatabricksSqlOperator(
+ task_id=TASK_ID,
+ sql="select * from dummy",
+ output_path=os.fspath(path),
+ output_format="avro",
+ return_last=True,
+ do_xcom_push=True,
+ split_statements=False,
+ )
+ db_mock = db_mock_class.return_value
+ db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+ db_mock.descriptions = [[("id",), ("value",)]]
+
+ op.execute(None)
+
+ from fastavro import reader
+
+ with open(path, "rb") as f:
+ records = list(reader(f))
+ assert len(records) == 2
+ assert records[0] == {"id": 1, "value": "value1"}
+ assert records[1] == {"id": 2, "value": "value2"}
+
+
+def test_exec_write_gcs_output(tmp_path):
+ """Test writing output to GCS."""
+ with (
+ patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class,
+ patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as gcs_mock_class,
+ ):
+ op = DatabricksSqlOperator(
+ task_id=TASK_ID,
+ sql="select * from dummy",
+ output_path="gs://my-bucket/path/to/output.csv",
+ output_format="csv",
+ return_last=True,
+ do_xcom_push=True,
+ split_statements=False,
+ gcp_conn_id="my_gcp_conn",
+ )
+ db_mock = db_mock_class.return_value
+ db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+ db_mock.descriptions = [[("id",), ("value",)]]
+
+ op.execute(None)
+
+ gcs_mock_class.assert_called_once_with(
+ gcp_conn_id="my_gcp_conn",
+ impersonation_chain=None,
+ )
+ gcs_mock_class.return_value.upload.assert_called_once()
+ call_kwargs = gcs_mock_class.return_value.upload.call_args[1]
+ assert call_kwargs["bucket_name"] == "my-bucket"
+ assert call_kwargs["object_name"] == "path/to/output.csv"
+
+
+def test_exec_write_gcs_parquet_output(tmp_path):
+ """Test writing Parquet output to GCS."""
+ with (
+ patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class,
+ patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as gcs_mock_class,
+ ):
+ op = DatabricksSqlOperator(
+ task_id=TASK_ID,
+ sql="select * from dummy",
+ output_path="gs://my-bucket/data/results.parquet",
+ output_format="parquet",
+ return_last=True,
+ do_xcom_push=True,
+ split_statements=False,
+ )
+ db_mock = db_mock_class.return_value
+ db_mock.run.return_value = [SerializableRow(1, "value1"), SerializableRow(2, "value2")]
+ db_mock.descriptions = [[("id",), ("value",)]]
+
+ op.execute(None)
+
+ gcs_mock_class.return_value.upload.assert_called_once()
+ call_kwargs = gcs_mock_class.return_value.upload.call_args[1]
+ assert call_kwargs["bucket_name"] == "my-bucket"
+ assert call_kwargs["object_name"] == "data/results.parquet"
+
+
+def test_is_gcs_output():
+ """Test _is_gcs_output property."""
+ op_gcs = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", output_path="gs://bucket/path")
+ assert op_gcs._is_gcs_output is True
+
+ op_local = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", output_path="/local/path")
+ assert op_local._is_gcs_output is False
+
+ op_s3 = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", output_path="s3://bucket/path")
+ assert op_s3._is_gcs_output is False
+
+ op_none = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+ assert op_none._is_gcs_output is False
+
+
+def test_parse_gcs_path():
+ """Test _parse_gcs_path method."""
+ op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+ bucket, object_name = op._parse_gcs_path("gs://my-bucket/path/to/file.parquet")
+ assert bucket == "my-bucket"
+ assert object_name == "path/to/file.parquet"
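
A side note on the last test: the bucket/object split comes straight from urllib.parse.urlparse, which treats the gs:// authority as netloc. A standalone sketch, with an illustrative URI:

    # Same split the operator's _parse_gcs_path performs, shown in isolation.
    from urllib.parse import urlparse

    parsed = urlparse("gs://my-bucket/path/to/file.parquet")  # illustrative URI
    bucket = parsed.netloc                                    # "my-bucket"
    object_name = parsed.path.lstrip("/")                     # "path/to/file.parquet"
    assert (bucket, object_name) == ("my-bucket", "path/to/file.parquet")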