This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7b3bf4e43558999af29a4ce7f60f2f9ef55f2ebf
Author: Alex Ott <[email protected]>
AuthorDate: Sun Mar 13 17:34:05 2022 +0100

    DatabricksSqlOperator - switch to databricks-sql-connector 2.x
---
 .../providers/databricks/hooks/databricks_sql.py   | 11 +++++++++-
 .../databricks/operators/databricks_sql.py         | 24 +++++++++++++++++++---
 docs/apache-airflow-providers-databricks/index.rst |  2 +-
 .../operators/copy_into.rst                        |  2 ++
 .../operators/sql.rst                              |  6 ++++++
 .../databricks/operators/test_databricks_sql.py    | 16 +++++++++++++--
 6 files changed, 54 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index e77fb3cf1b..14d6e4bf68 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -18,7 +18,7 @@
 import re
 from contextlib import closing
 from copy import copy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from databricks import sql  # type: ignore[attr-defined]
 from databricks.sql.client import Connection  # type: ignore[attr-defined]
@@ -42,6 +42,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         be provided as described above.
     :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
         If not specified, it could be specified in the Databricks connection's extra parameters.
+    :param metadata: An optional list of (k, v) pairs that will be set as HTTP headers on every request
+    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
+    :param schema: An optional initial schema to use. Requires DBR version 9.0+
     """
 
     hook_name = 'Databricks SQL'
@@ -52,6 +55,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
         session_configuration: Optional[Dict[str, str]] = None,
+        metadata: Optional[List[Tuple[str, str]]] = None,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
     ) -> None:
         super().__init__(databricks_conn_id)
         self._sql_conn = None
@@ -60,6 +66,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         self._sql_endpoint_name = sql_endpoint_name
         self.supports_autocommit = True
         self.session_config = session_configuration
+        self.metadata = metadata
+        self.catalog = catalog
+        self.schema = schema
 
     def _get_extra_config(self) -> Dict[str, Optional[Any]]:
         extra_params = copy(self.databricks_conn.extra_dejson)
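
For illustration, a minimal sketch of passing the new hook arguments; the connection id is the provider's default, while the endpoint name, header, catalog and schema values are placeholders.

    from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

    # Placeholder values throughout; metadata entries are sent as HTTP headers on every request.
    hook = DatabricksSqlHook(
        'databricks_default',
        sql_endpoint_name='my-endpoint',
        metadata=[('x-requested-by', 'airflow')],
        catalog='main',    # initial catalog, requires DBR 9.0+
        schema='default',  # initial schema, requires DBR 9.0+
    )
    hook.run('SELECT 1', parameters=None)  # same call shape the operator and its tests use
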
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 48a6713517..a7d157932d 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -20,9 +20,9 @@
 
 import csv
 import json
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
 
-from databricks.sql.common import ParamEscaper
+from databricks.sql.utils import ParamEscaper
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -53,13 +53,16 @@ class DatabricksSqlOperator(BaseOperator):
     :param parameters: (optional) the parameters to render the SQL query with.
     :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
         If not specified, it could be specified in the Databricks connection's extra parameters.
+    :param metadata: An optional list of (k, v) pairs that will be set as HTTP headers on every request
+    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
+    :param schema: An optional initial schema to use. Requires DBR version 9.0+
     :param output_path: optional string specifying the file to which selected data will be written. (templated)
     :param output_format: format of output data if ``output_path`` is specified.
         Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
     :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
     """
 
-    template_fields: Sequence[str] = ('sql', '_output_path')
+    template_fields: Sequence[str] = ('sql', '_output_path', 'schema', 'metadata')
     template_ext: Sequence[str] = ('.sql',)
     template_fields_renderers = {'sql': 'sql'}
 
@@ -72,6 +75,9 @@ class DatabricksSqlOperator(BaseOperator):
         sql_endpoint_name: Optional[str] = None,
         parameters: Optional[Union[Mapping, Iterable]] = None,
         session_configuration=None,
+        metadata: Optional[List[Tuple[str, str]]] = None,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
         do_xcom_push: bool = False,
         output_path: Optional[str] = None,
         output_format: str = 'csv',
@@ -90,6 +96,9 @@ class DatabricksSqlOperator(BaseOperator):
         self.parameters = parameters
         self.do_xcom_push = do_xcom_push
         self.session_config = session_configuration
+        self.metadata = metadata
+        self.catalog = catalog
+        self.schema = schema
 
     def _get_hook(self) -> DatabricksSqlHook:
         return DatabricksSqlHook(
@@ -97,6 +106,9 @@ class DatabricksSqlOperator(BaseOperator):
             http_path=self._http_path,
             session_configuration=self.session_config,
             sql_endpoint_name=self._sql_endpoint_name,
+            metadata=self.metadata,
+            catalog=self.catalog,
+            schema=self.schema,
         )
 
     def _format_output(self, schema, results):
@@ -165,6 +177,9 @@ class DatabricksCopyIntoOperator(BaseOperator):
         or ``sql_endpoint_name`` must be specified.
     :param sql_endpoint_name: Optional name of Databricks SQL Endpoint.
         If not specified, ``http_path`` must be provided as described above.
+    :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
+        If not specified, it could be specified in the Databricks connection's extra parameters.
+    :param metadata: An optional list of (k, v) pairs that will be set as HTTP headers on every request
     :param files: optional list of files to import. Can't be specified together with ``pattern``. (templated)
     :param pattern: optional regex string to match file names to import.
         Can't be specified together with ``files``.
@@ -195,6 +210,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
         session_configuration=None,
+        metadata: Optional[List[Tuple[str, str]]] = None,
         files: Optional[List[str]] = None,
         pattern: Optional[str] = None,
         expression_list: Optional[str] = None,
@@ -231,6 +247,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
         self._format_options = format_options
         self._copy_options = copy_options or {}
         self._validate = validate
+        self.metadata = metadata
         if force_copy is not None:
             self._copy_options["force"] = 'true' if force_copy else 'false'
 
@@ -240,6 +257,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
             http_path=self._http_path,
             session_configuration=self.session_config,
             sql_endpoint_name=self._sql_endpoint_name,
+            metadata=self.metadata,
         )
 
     @staticmethod
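
For illustration, a hedged sketch of DatabricksSqlOperator with the new arguments; the task id, connection, endpoint name, query, header and output path are all placeholders.

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    select_into_file = DatabricksSqlOperator(
        task_id='select_data',
        databricks_conn_id='databricks_default',
        sql_endpoint_name='my-endpoint',
        sql='SELECT * FROM my_table LIMIT 10',
        metadata=[('x-requested-by', 'airflow')],  # extra HTTP headers on every request
        catalog='main',                            # initial catalog, DBR 9.0+
        schema='default',                          # initial schema, DBR 9.0+
        output_path='/tmp/query_results.csv',      # optionally write results to a file
        output_format='csv',
    )
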
diff --git a/docs/apache-airflow-providers-databricks/index.rst b/docs/apache-airflow-providers-databricks/index.rst
index 51c8381f0c..968b94149b 100644
--- a/docs/apache-airflow-providers-databricks/index.rst
+++ b/docs/apache-airflow-providers-databricks/index.rst
@@ -80,7 +80,7 @@ PIP requirements
 PIP package                   Version required
 ============================  ===================
 ``apache-airflow``            ``>=2.1.0``
-``databricks-sql-connector``  ``>=1.0.2, <2.0.0``
+``databricks-sql-connector``  ``>=2.0.0, <3.0.0``
 ``requests``                  ``>=2.26.0, <3``
 ============================  ===================
 
diff --git a/docs/apache-airflow-providers-databricks/operators/copy_into.rst b/docs/apache-airflow-providers-databricks/operators/copy_into.rst
index c2db25992f..1b73bac8d2 100644
--- a/docs/apache-airflow-providers-databricks/operators/copy_into.rst
+++ b/docs/apache-airflow-providers-databricks/operators/copy_into.rst
@@ -49,6 +49,8 @@ Operator loads data from a specified location into a table using a configured en
     - Optional HTTP path for Databricks SQL endpoint or Databricks cluster. If not specified, it should be provided in Databricks connection, or the ``sql_endpoint_name`` parameter must be set.
    * - session_configuration: dict[str,str]
     - optional dict specifying Spark configuration parameters that will be set for the session.
+   * - metadata: list[tuple[str, str]]
+     - Optional list of (k, v) pairs that will be set as HTTP headers on every request
    * - files: Optional[List[str]]
     - optional list of files to import. Can't be specified together with ``pattern``.
    * - pattern: Optional[str]
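
For illustration, a sketch of the new ``metadata`` argument on DatabricksCopyIntoOperator; the endpoint, table and file location values are placeholders, and ``table_name``, ``file_location`` and ``file_format`` belong to the operator's existing signature rather than this change.

    from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator

    copy_task = DatabricksCopyIntoOperator(
        task_id='copy_data',
        sql_endpoint_name='my-endpoint',
        table_name='my_schema.my_table',
        file_location='s3://my-bucket/incoming/',
        file_format='CSV',
        metadata=[('x-requested-by', 'airflow')],  # forwarded to the hook as HTTP headers
    )
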
diff --git a/docs/apache-airflow-providers-databricks/operators/sql.rst b/docs/apache-airflow-providers-databricks/operators/sql.rst
index d47af44d6c..953bbae245 100644
--- a/docs/apache-airflow-providers-databricks/operators/sql.rst
+++ b/docs/apache-airflow-providers-databricks/operators/sql.rst
@@ -51,6 +51,12 @@ Operator executes given SQL queries against configured endpoint.  There are 3 wa
     - Optional parameters that will be used to substitute variable(s) in SQL query.
    * - session_configuration: dict[str,str]
     - optional dict specifying Spark configuration parameters that will be set for the session.
+   * - metadata: list[tuple[str, str]]
+     - Optional list of (k, v) pairs that will be set as HTTP headers on every request
+   * - catalog: str
+     - Optional initial catalog to use. Requires DBR version 9.0+
+   * - schema: str
+     - Optional initial schema to use. Requires DBR version 9.0+
    * - output_path: str
      - Optional path to the file to which results will be written.
    * - output_format: str
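
Since ``schema`` and ``metadata`` are now template fields on DatabricksSqlOperator, they can be rendered from the DAG context. A hypothetical sketch (the variable name, header and query are made up):

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    templated_select = DatabricksSqlOperator(
        task_id='select_templated',
        sql_endpoint_name='my-endpoint',
        sql='SELECT COUNT(*) FROM my_table',
        schema="{{ var.value.get('databricks_schema', 'default') }}",  # rendered at runtime
        metadata=[('x-dag-run', '{{ run_id }}')],  # nested values in template fields are rendered as well
    )
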
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 6b9fb43701..afa25932a9 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -53,7 +53,13 @@ class TestDatabricksSqlOperator(unittest.TestCase):
 
         assert results == mock_results
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, http_path=None, session_configuration=None, sql_endpoint_name=None
+            DEFAULT_CONN_ID,
+            http_path=None,
+            session_configuration=None,
+            sql_endpoint_name=None,
+            metadata=None,
+            catalog=None,
+            schema=None,
         )
         db_mock.run.assert_called_once_with(sql, parameters=None)
 
@@ -78,7 +84,13 @@ class TestDatabricksSqlOperator(unittest.TestCase):
 
         assert results == ["id,value", "1,value1"]
         db_mock_class.assert_called_once_with(
-            DEFAULT_CONN_ID, http_path=None, session_configuration=None, sql_endpoint_name=None
+            DEFAULT_CONN_ID,
+            http_path=None,
+            session_configuration=None,
+            sql_endpoint_name=None,
+            metadata=None,
+            catalog=None,
+            schema=None,
         )
         db_mock.run.assert_called_once_with(sql, parameters=None)
 
