This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7b3bf4e43558999af29a4ce7f60f2f9ef55f2ebf
Author: Alex Ott <[email protected]>
AuthorDate: Sun Mar 13 17:34:05 2022 +0100

    DatabricksSqlOperator - switch to databricks-sql-connector 2.x
---
 .../providers/databricks/hooks/databricks_sql.py   | 11 +++++++++-
 .../databricks/operators/databricks_sql.py         | 24 +++++++++++++++++++---
 docs/apache-airflow-providers-databricks/index.rst |  3 ++-
 .../operators/copy_into.rst                        |  2 ++
 .../operators/sql.rst                              |  6 ++++++
 .../databricks/operators/test_databricks_sql.py    | 16 +++++++++++++--
 6 files changed, 55 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index e77fb3cf1b..14d6e4bf68 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -18,7 +18,7 @@
 import re
 from contextlib import closing
 from copy import copy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from databricks import sql  # type: ignore[attr-defined]
 from databricks.sql.client import Connection  # type: ignore[attr-defined]
@@ -42,6 +42,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         be provided as described above.
     :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
         If not specified, it could be specified in the Databricks connection's extra parameters.
+    :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
+    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
+    :param schema: An optional initial schema to use. Requires DBR version 9.0+
     """

     hook_name = 'Databricks SQL'
@@ -52,6 +55,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
         session_configuration: Optional[Dict[str, str]] = None,
+        metadata: Optional[List[Tuple[str, str]]] = None,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
     ) -> None:
         super().__init__(databricks_conn_id)
         self._sql_conn = None
@@ -60,6 +66,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         self._sql_endpoint_name = sql_endpoint_name
         self.supports_autocommit = True
         self.session_config = session_configuration
+        self.metadata = metadata
+        self.catalog = catalog
+        self.schema = schema

     def _get_extra_config(self) -> Dict[str, Optional[Any]]:
         extra_params = copy(self.databricks_conn.extra_dejson)
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 48a6713517..a7d157932d 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -20,9 +20,9 @@
 import csv
 import json
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

-from databricks.sql.common import ParamEscaper
+from databricks.sql.utils import ParamEscaper

 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -53,13 +53,16 @@ class DatabricksSqlOperator(BaseOperator):
     :param parameters: (optional) the parameters to render the SQL query with.
     :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
         If not specified, it could be specified in the Databricks connection's extra parameters.
+    :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
+    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
+    :param schema: An optional initial schema to use. Requires DBR version 9.0+
     :param output_path: optional string specifying the file to which write selected data. (templated)
     :param output_format: format of output data if ``output_path` is specified. Possible values are ``csv``,
         ``json``, ``jsonl``. Default is ``csv``.
     :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
     """

-    template_fields: Sequence[str] = ('sql', '_output_path')
+    template_fields: Sequence[str] = ('sql', '_output_path', 'schema', 'metadata')
     template_ext: Sequence[str] = ('.sql',)
     template_fields_renderers = {'sql': 'sql'}

@@ -72,6 +75,9 @@ class DatabricksSqlOperator(BaseOperator):
         sql_endpoint_name: Optional[str] = None,
         parameters: Optional[Union[Mapping, Iterable]] = None,
         session_configuration=None,
+        metadata: Optional[List[Tuple[str, str]]] = None,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
         do_xcom_push: bool = False,
         output_path: Optional[str] = None,
         output_format: str = 'csv',
@@ -90,6 +96,9 @@ class DatabricksSqlOperator(BaseOperator):
         self.parameters = parameters
         self.do_xcom_push = do_xcom_push
         self.session_config = session_configuration
+        self.metadata = metadata
+        self.catalog = catalog
+        self.schema = schema

     def _get_hook(self) -> DatabricksSqlHook:
         return DatabricksSqlHook(
@@ -97,6 +106,9 @@ class DatabricksSqlOperator(BaseOperator):
             http_path=self._http_path,
             session_configuration=self.session_config,
             sql_endpoint_name=self._sql_endpoint_name,
+            metadata=self.metadata,
+            catalog=self.catalog,
+            schema=self.schema,
         )

     def _format_output(self, schema, results):
@@ -165,6 +177,9 @@ class DatabricksCopyIntoOperator(BaseOperator):
         or ``sql_endpoint_name`` must be specified.
     :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path`` must
         be provided as described above.
+    :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
+        If not specified, it could be specified in the Databricks connection's extra parameters.
+    :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
     :param files: optional list of files to import. Can't be specified together with ``pattern``. (templated)
     :param pattern: optional regex string to match file names to import.
         Can't be specified together with ``files``.
@@ -195,6 +210,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
         session_configuration=None,
+        metadata: Optional[List[Tuple[str, str]]] = None,
         files: Optional[List[str]] = None,
         pattern: Optional[str] = None,
         expression_list: Optional[str] = None,
@@ -231,6 +247,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
         self._format_options = format_options
         self._copy_options = copy_options or {}
         self._validate = validate
+        self.metadata = metadata
         if force_copy is not None:
             self._copy_options["force"] = 'true' if force_copy else 'false'

@@ -240,6 +257,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
             http_path=self._http_path,
             session_configuration=self.session_config,
             sql_endpoint_name=self._sql_endpoint_name,
+            metadata=self.metadata,
         )

     @staticmethod
diff --git a/docs/apache-airflow-providers-databricks/index.rst b/docs/apache-airflow-providers-databricks/index.rst
index 51c8381f0c..968b94149b 100644
--- a/docs/apache-airflow-providers-databricks/index.rst
+++ b/docs/apache-airflow-providers-databricks/index.rst
@@ -80,7 +80,8 @@ PIP requirements
 PIP package                  Version required
 ============================ ===================
 ``apache-airflow``           ``>=2.1.0``
-``databricks-sql-connector`` ``>=1.0.2, <2.0.0``
+``databricks-sql-connector`` ``>=2.0.0, <3.0.0``
+>>>>>>> DatabricksSqlOperator - switch to databricks-sql-connector 2.x
 ``requests``                 ``>=2.26.0, <3``
 ============================ ===================
diff --git a/docs/apache-airflow-providers-databricks/operators/copy_into.rst b/docs/apache-airflow-providers-databricks/operators/copy_into.rst
index c2db25992f..1b73bac8d2 100644
--- a/docs/apache-airflow-providers-databricks/operators/copy_into.rst
+++ b/docs/apache-airflow-providers-databricks/operators/copy_into.rst
@@ -49,6 +49,8 @@ Operator loads data from a specified location into a table using a configured en
     - Optional HTTP path for Databricks SQL endpoint or Databricks cluster. If not specified, it should be provided in Databricks connection, or the ``sql_endpoint_name`` parameter must be set.
   * - session_configuration: dict[str,str]
     - optional dict specifying Spark configuration parameters that will be set for the session.
+   * - metadata: list[tuple[str, str]]
+     - Optional list of (k, v) pairs that will be set as Http headers on every request
   * - files: Optional[List[str]]
     - optional list of files to import. Can't be specified together with ``pattern``.
   * - pattern: Optional[str]
diff --git a/docs/apache-airflow-providers-databricks/operators/sql.rst b/docs/apache-airflow-providers-databricks/operators/sql.rst
index d47af44d6c..953bbae245 100644
--- a/docs/apache-airflow-providers-databricks/operators/sql.rst
+++ b/docs/apache-airflow-providers-databricks/operators/sql.rst
@@ -51,6 +51,12 @@ Operator executes given SQL queries against configured endpoint. There are 3 wa
     - Optional parameters that will be used to substitute variable(s) in SQL query.
   * - session_configuration: dict[str,str]
     - optional dict specifying Spark configuration parameters that will be set for the session.
+   * - metadata: list[tuple[str, str]]
+     - Optional list of (k, v) pairs that will be set as Http headers on every request
+   * - catalog: str
+     - Optional initial catalog to use. Requires DBR version 9.0+
+   * - schema: str
+     - Optional initial schema to use. Requires DBR version 9.0+
   * - output_path: str
     - Optional path to the file to which results will be written.
   * - output_format: str
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 6b9fb43701..afa25932a9 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -53,7 +53,13 @@ class TestDatabricksSqlOperator(unittest.TestCase):
         assert results == mock_results

         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, http_path=None, session_configuration=None, sql_endpoint_name=None
+            DEFAULT_CONN_ID,
+            http_path=None,
+            session_configuration=None,
+            sql_endpoint_name=None,
+            metadata=None,
+            catalog=None,
+            schema=None,
         )
         db_mock.run.assert_called_once_with(sql, parameters=None)
@@ -78,7 +84,13 @@ class TestDatabricksSqlOperator(unittest.TestCase):
         assert results == ["id,value", "1,value1"]

         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, http_path=None, session_configuration=None, sql_endpoint_name=None
+            DEFAULT_CONN_ID,
+            http_path=None,
+            session_configuration=None,
+            sql_endpoint_name=None,
+            metadata=None,
+            catalog=None,
+            schema=None,
         )
         db_mock.run.assert_called_once_with(sql, parameters=None)
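
The hunks above add catalog, schema, and metadata parameters to DatabricksSqlOperator. Below is a minimal usage sketch, not part of the commit: it assumes an already configured Databricks connection, and the DAG id, connection id, endpoint name, query, and header values are hypothetical placeholders.

# A sketch only -- names below are placeholders, not taken from the commit.
from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

with DAG(
    dag_id="example_databricks_sql",            # hypothetical DAG id
    start_date=datetime(2022, 3, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    select_data = DatabricksSqlOperator(
        task_id="select_data",
        databricks_conn_id="databricks_default",   # assumed connection id
        sql_endpoint_name="my-endpoint",           # hypothetical SQL endpoint name
        sql="SELECT * FROM my_table LIMIT 10",     # hypothetical query
        catalog="hive_metastore",                  # new: initial catalog, requires DBR 9.0+
        schema="default",                          # new: initial schema, requires DBR 9.0+
        metadata=[("X-Custom-Header", "value")],   # new: HTTP headers set on every request
    )

Since 'schema' and 'metadata' were added to template_fields in this change, both can be rendered with Jinja in the same way as 'sql'.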

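The same options are exposed at the hook level. Another sketch under the same assumptions, with a placeholder connection id, HTTP path, and query; run() is the DB-API style execution method and its exact return value depends on the hook implementation.

# A sketch only -- connection id, HTTP path, and query are placeholders.
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook(
    databricks_conn_id="databricks_default",          # assumed connection id
    http_path="/sql/1.0/endpoints/1234567890abcdef",  # hypothetical HTTP path
    catalog="samples",                                # new: initial catalog, requires DBR 9.0+
    schema="default",                                 # new: initial schema, requires DBR 9.0+
    metadata=[("X-Request-Source", "airflow")],       # new: HTTP headers set on every request
)
result = hook.run("SELECT 1")  # executes through the databricks-sql-connector 2.x client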