This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6a3d6cc32b4e3922d259c889460fe82e0ebf3663
Author: Alex Ott <[email protected]>
AuthorDate: Sat Apr 23 12:38:06 2022 +0200

    Update to the released version of DBSQL connector
    
    Also added additional parameters for further customization of the
    connection, if required
---
 .../providers/databricks/hooks/databricks_sql.py   | 47 ++++++++++++-------
 .../databricks/operators/databricks_sql.py         | 53 ++++++++++++++++------
 .../operators/copy_into.rst                        | 28 +++++++-----
 .../operators/sql.rst                              |  8 ++--
 setup.py                                           |  2 +-
 .../databricks/operators/test_databricks_sql.py    | 23 +++++++++-
 6 files changed, 112 insertions(+), 49 deletions(-)
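
In short: the hook's former `metadata` argument is renamed to `http_headers`, and any extra
keyword arguments are now forwarded to the Databricks SQL Connector. A minimal usage sketch
under that assumption (connection id, endpoint name and header value are placeholders):

    from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

    hook = DatabricksSqlHook(
        databricks_conn_id="databricks_default",         # placeholder connection id
        sql_endpoint_name="my-endpoint",                 # placeholder SQL endpoint name
        http_headers=[("X-Request-Source", "airflow")],  # was `metadata` before this commit
        catalog="main",                                  # requires DBR 9.0+
        schema="default",                                # requires DBR 9.0+
    )
    # standard DbApiHook helpers work on top of the new connection settings
    rows = hook.get_records("SELECT 1")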

diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index 14d6e4bf68..aa8245772a 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -23,29 +23,17 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from databricks import sql  # type: ignore[attr-defined]
 from databricks.sql.client import Connection  # type: ignore[attr-defined]
 
+from airflow import __version__
 from airflow.exceptions import AirflowException
 from airflow.hooks.dbapi import DbApiHook
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
 
 LIST_SQL_ENDPOINTS_ENDPOINT = ('GET', 'api/2.0/sql/endpoints')
+USER_AGENT_STRING = f'airflow-{__version__}'
 
 
 class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
-    """
-    Interact with Databricks SQL.
-
-    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
-    :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster.
-        If not specified, it should be either specified in the Databricks connection's extra parameters,
-        or ``sql_endpoint_name`` must be specified.
-    :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path`` must
-        be provided as described above.
-    :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
-        If not specified, it could be specified in the Databricks connection's extra parameters.
-    :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
-    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
-    :param schema: An optional initial schema to use. Requires DBR version 9.0+
-    """
+    """Hook to interact with Databricks SQL."""
 
     hook_name = 'Databricks SQL'
 
@@ -55,10 +43,29 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
         session_configuration: Optional[Dict[str, str]] = None,
-        metadata: Optional[List[Tuple[str, str]]] = None,
+        http_headers: Optional[List[Tuple[str, str]]] = None,
         catalog: Optional[str] = None,
         schema: Optional[str] = None,
+        **kwargs,
     ) -> None:
+        """
+        Initializes DatabricksSqlHook
+
+        :param databricks_conn_id: Reference to the
+            :ref:`Databricks connection <howto/connection:databricks>`.
+        :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster.
+            If not specified, it should be either specified in the Databricks connection's extra parameters,
+            or ``sql_endpoint_name`` must be specified.
+        :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path``
+            must be provided as described above.
+        :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
+            If not specified, it could be specified in the Databricks connection's extra parameters.
+        :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers
+            on every request
+        :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
+        :param schema: An optional initial schema to use. Requires DBR version 9.0+
+        :param kwargs: Additional parameters passed through to the Databricks SQL connector
+        """
         super().__init__(databricks_conn_id)
         self._sql_conn = None
         self._token: Optional[str] = None
@@ -66,9 +73,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         self._sql_endpoint_name = sql_endpoint_name
         self.supports_autocommit = True
         self.session_config = session_configuration
-        self.metadata = metadata
+        self.http_headers = http_headers
         self.catalog = catalog
         self.schema = schema
+        self.additional_params = kwargs
 
     def _get_extra_config(self) -> Dict[str, Optional[Any]]:
         extra_params = copy(self.databricks_conn.extra_dejson)
@@ -122,8 +130,13 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
                 self.host,
                 self._http_path,
                 self._token,
+                schema=self.schema,
+                catalog=self.catalog,
                 session_configuration=self.session_config,
+                http_headers=self.http_headers,
+                _user_agent_entry=USER_AGENT_STRING,
                 **self._get_extra_config(),
+                **self.additional_params,
             )
         return self._sql_conn
 
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index a7d157932d..9e6298bc21 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -53,16 +53,18 @@ class DatabricksSqlOperator(BaseOperator):
     :param parameters: (optional) the parameters to render the SQL query with.
     :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
         If not specified, it could be specified in the Databricks connection's extra parameters.
-    :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
-    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
-    :param schema: An optional initial schema to use. Requires DBR version 9.0+
+    :param client_parameters: Additional parameters passed through to the Databricks SQL connector
+    :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers on every request.
+         (templated)
+    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated)
+    :param schema: An optional initial schema to use. Requires DBR version 9.0+ (templated)
     :param output_path: optional string specifying the file to which write selected data. (templated)
     :param output_format: format of output data if ``output_path` is specified.
         Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
     :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
     """
 
-    template_fields: Sequence[str] = ('sql', '_output_path', 'schema', 'metadata')
+    template_fields: Sequence[str] = ('sql', '_output_path', 'schema', 'catalog', 'http_headers')
     template_ext: Sequence[str] = ('.sql',)
     template_fields_renderers = {'sql': 'sql'}
 
@@ -75,13 +77,14 @@ class DatabricksSqlOperator(BaseOperator):
         sql_endpoint_name: Optional[str] = None,
         parameters: Optional[Union[Mapping, Iterable]] = None,
         session_configuration=None,
-        metadata: Optional[List[Tuple[str, str]]] = None,
+        http_headers: Optional[List[Tuple[str, str]]] = None,
         catalog: Optional[str] = None,
         schema: Optional[str] = None,
         do_xcom_push: bool = False,
         output_path: Optional[str] = None,
         output_format: str = 'csv',
         csv_params: Optional[Dict[str, Any]] = None,
+        client_parameters: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Creates a new ``DatabricksSqlOperator``."""
@@ -96,9 +99,10 @@ class DatabricksSqlOperator(BaseOperator):
         self.parameters = parameters
         self.do_xcom_push = do_xcom_push
         self.session_config = session_configuration
-        self.metadata = metadata
+        self.http_headers = http_headers
         self.catalog = catalog
         self.schema = schema
+        self.client_parameters = client_parameters or {}
 
     def _get_hook(self) -> DatabricksSqlHook:
         return DatabricksSqlHook(
@@ -106,9 +110,10 @@ class DatabricksSqlOperator(BaseOperator):
             http_path=self._http_path,
             session_configuration=self.session_config,
             sql_endpoint_name=self._sql_endpoint_name,
-            metadata=self.metadata,
+            http_headers=self.http_headers,
             catalog=self.catalog,
             schema=self.schema,
+            **self.client_parameters,
         )
 
     def _format_output(self, schema, results):
@@ -179,12 +184,16 @@ class DatabricksCopyIntoOperator(BaseOperator):
         If not specified, ``http_path`` must be provided as described above.
     :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
         If not specified, it could be specified in the Databricks connection's extra parameters.
-    :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request
+    :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers on every request
+    :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
+    :param schema: An optional initial schema to use. Requires DBR version 9.0+
+    :param client_parameters: Additional parameters passed through to the Databricks SQL connector
     :param files: optional list of files to import. Can't be specified together with ``pattern``. (templated)
     :param pattern: optional regex string to match file names to import.
         Can't be specified together with ``files``.
     :param expression_list: optional string that will be used in the ``SELECT`` expression.
-    :param credential: optional credential configuration for authentication against a specified location.
+    :param credential: optional credential configuration for authentication against a source location.
+    :param storage_credential: optional Unity Catalog storage credential for destination.
     :param encryption: optional encryption configuration for a specified location.
     :param format_options: optional dictionary with options specific for a given file format.
     :param force_copy: optional bool to control forcing of data import
@@ -210,11 +219,15 @@ class DatabricksCopyIntoOperator(BaseOperator):
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
         session_configuration=None,
-        metadata: Optional[List[Tuple[str, str]]] = None,
+        http_headers: Optional[List[Tuple[str, str]]] = None,
+        client_parameters: Optional[Dict[str, Any]] = None,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
         files: Optional[List[str]] = None,
         pattern: Optional[str] = None,
         expression_list: Optional[str] = None,
         credential: Optional[Dict[str, str]] = None,
+        storage_credential: Optional[str] = None,
         encryption: Optional[Dict[str, str]] = None,
         format_options: Optional[Dict[str, str]] = None,
         force_copy: Optional[bool] = None,
@@ -240,14 +253,18 @@ class DatabricksCopyIntoOperator(BaseOperator):
         self._sql_endpoint_name = sql_endpoint_name
         self.session_config = session_configuration
         self._table_name = table_name
+        self._catalog = catalog
+        self._schema = schema
         self._file_location = file_location
         self._expression_list = expression_list
         self._credential = credential
+        self._storage_credential = storage_credential
         self._encryption = encryption
         self._format_options = format_options
         self._copy_options = copy_options or {}
         self._validate = validate
-        self.metadata = metadata
+        self._http_headers = http_headers
+        self._client_parameters = client_parameters or {}
         if force_copy is not None:
             self._copy_options["force"] = 'true' if force_copy else 'false'
 
@@ -257,7 +274,10 @@ class DatabricksCopyIntoOperator(BaseOperator):
             http_path=self._http_path,
             session_configuration=self.session_config,
             sql_endpoint_name=self._sql_endpoint_name,
-            metadata=self.metadata,
+            http_headers=self._http_headers,
+            catalog=self._catalog,
+            schema=self._schema,
+            **self._client_parameters,
         )
 
     @staticmethod
@@ -298,6 +318,9 @@ class DatabricksCopyIntoOperator(BaseOperator):
             files_or_pattern = f"FILES = {escaper.escape_item(self._files)}\n"
         format_options = self._generate_options("FORMAT_OPTIONS", escaper, self._format_options) + "\n"
         copy_options = self._generate_options("COPY_OPTIONS", escaper, self._copy_options) + "\n"
+        storage_cred = ""
+        if self._storage_credential:
+            storage_cred = f" WITH (CREDENTIAL {self._storage_credential})"
         validation = ""
         if self._validate is not None:
             if isinstance(self._validate, bool):
@@ -310,9 +333,11 @@ class DatabricksCopyIntoOperator(BaseOperator):
                     )
                 validation = f"VALIDATE {self._validate} ROWS\n"
             else:
-                raise AirflowException("Incorrect data type for validate parameter: " + type(self._validate))
+                raise AirflowException(
+                    "Incorrect data type for validate parameter: " + str(type(self._validate))
+                )
         # TODO: think on how to make sure that table_name and expression_list aren't used for SQL injection
-        sql = f"""COPY INTO {self._table_name}
+        sql = f"""COPY INTO {self._table_name}{storage_cred}
 FROM {location}
 FILEFORMAT = {self._file_format}
 {validation}{files_or_pattern}{format_options}{copy_options}
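
A minimal task sketch with the renamed and newly templated operator arguments (endpoint
name and header value are placeholders; the `client_parameters` key shown is hypothetical
and simply forwarded to the connector):

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    select_data = DatabricksSqlOperator(
        task_id="select_data",
        sql_endpoint_name="my-endpoint",                        # placeholder endpoint name
        sql="SELECT * FROM default.my_table LIMIT 10",
        http_headers=[("X-Correlation-Id", "{{ run_id }}")],    # templated
        catalog="main",                                         # templated, DBR 9.0+
        schema="default",                                       # templated, DBR 9.0+
        client_parameters={"some_connector_option": "value"},   # hypothetical pass-through kwarg
    )
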
diff --git a/docs/apache-airflow-providers-databricks/operators/copy_into.rst b/docs/apache-airflow-providers-databricks/operators/copy_into.rst
index 1b73bac8d2..1d4ef07de2 100644
--- a/docs/apache-airflow-providers-databricks/operators/copy_into.rst
+++ b/docs/apache-airflow-providers-databricks/operators/copy_into.rst
@@ -49,25 +49,29 @@ Operator loads data from a specified location into a table using a configured en
      - Optional HTTP path for Databricks SQL endpoint or Databricks cluster. If not specified, it should be provided in Databricks connection, or the ``sql_endpoint_name`` parameter must be set.
    * - session_configuration: dict[str,str]
      - optional dict specifying Spark configuration parameters that will be set for the session.
-   * - metadata: list[tuple[str, str]]
-     - Optional list of (k, v) pairs that will be set as Http headers on every request
-   * - files: Optional[List[str]]
+   * - http_headers: list[tuple[str, str]]
+     - Optional list of (k, v) pairs that will be set as HTTP headers on every request
+   * - client_parameters: dict[str,str]
+     - optional additional parameters passed through to the Databricks SQL connector
+   * - files: list[str]
      - optional list of files to import. Can't be specified together with ``pattern``.
-   * - pattern: Optional[str]
+   * - pattern: str
      - optional regex string to match file names to import. Can't be specified together with ``files``.
-   * - expression_list: Optional[str]
+   * - expression_list: str
      - optional string that will be used in the ``SELECT`` expression.
-   * - credential: Optional[Dict[str, str]]
+   * - credential: dict[str, str]
      - optional credential configuration for authentication against a specified location
-   * - encryption: Optional[Dict[str, str]]
+   * - encryption: dict[str, str]
      - optional encryption configuration for a specified location
-   * - format_options: Optional[Dict[str, str]]
+   * - storage_credential: str
+     - optional Unity Catalog storage credential name for the target table
+   * - format_options: dict[str, str]
      - optional dictionary with options specific for a given file format.
-   * - force_copy: Optional[bool]
-     - optional bool to control forcing of data import (could be also specified in ``copy_options``).
-   * - copy_options: Optional[Dict[str, str]]
+   * - force_copy: bool
+     - optional boolean parameter to control forcing of data import (could be also specified in ``copy_options``).
+   * - copy_options: dict[str, str]
      - optional dictionary of copy options. Right now only ``force`` option is supported.
-   * - validate: Optional[Union[bool, int]]
+   * - validate: Union[bool, int]
      - optional validation configuration. ``True`` forces validation of all rows, positive number - only N first rows. (requires Preview channel)
 
 Examples
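
A minimal sketch of the new ``storage_credential`` argument next to the existing source
``credential`` (table, location and credential names are placeholders):

    from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator

    import_csv = DatabricksCopyIntoOperator(
        task_id="import_csv",
        sql_endpoint_name="my-endpoint",                      # placeholder endpoint name
        table_name="main.default.target_table",               # placeholder target table
        file_location="abfss://[email protected]/data/",  # placeholder source location
        file_format="CSV",
        storage_credential="my_storage_credential",           # Unity Catalog credential for the target
        credential={"AZURE_SAS_TOKEN": "placeholder-token"},  # credential for the source location
        format_options={"header": "true"},
        force_copy=True,
    )
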
diff --git a/docs/apache-airflow-providers-databricks/operators/sql.rst b/docs/apache-airflow-providers-databricks/operators/sql.rst
index 953bbae245..d0a1d6d337 100644
--- a/docs/apache-airflow-providers-databricks/operators/sql.rst
+++ b/docs/apache-airflow-providers-databricks/operators/sql.rst
@@ -51,8 +51,10 @@ Operator executes given SQL queries against configured endpoint.  There are 3 wa
      - Optional parameters that will be used to substitute variable(s) in SQL query.
    * - session_configuration: dict[str,str]
      - optional dict specifying Spark configuration parameters that will be set for the session.
-   * - metadata: list[tuple[str, str]]
-     - Optional list of (k, v) pairs that will be set as Http headers on every request
+   * - http_headers: list[tuple[str, str]]
+     - Optional list of (k, v) pairs that will be set as HTTP headers on every request
+   * - client_parameters: dict[str,str]
+     - optional additional parameters passed through to the Databricks SQL connector
    * - catalog: str
      - Optional initial catalog to use. Requires DBR version 9.0+
    * - schema: str
@@ -63,7 +65,7 @@ Operator executes given SQL queries against configured endpoint.  There are 3 wa
      - Name of the format which will be used to write results.  Supported values are (case-insensitive): ``JSON`` (array of JSON objects), ``JSONL`` (each row as JSON object on a separate line), ``CSV`` (default).
    * - csv_params: dict[str, any]
      - Optional dictionary with parameters to customize Python CSV writer.
-   * - do_xcom_push: boolean
+   * - do_xcom_push: bool
      - whether we should push query results (last query if multiple queries are provided) to xcom. Default: false
 
 Examples
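
And a short sketch of the output options listed above (the output path is a placeholder):

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    export_events = DatabricksSqlOperator(
        task_id="export_events",
        sql_endpoint_name="my-endpoint",               # placeholder endpoint name
        sql="SELECT * FROM default.events LIMIT 100",
        output_path="/tmp/events.jsonl",               # placeholder path
        output_format="jsonl",                         # csv (default), json or jsonl
        do_xcom_push=False,
    )
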
diff --git a/setup.py b/setup.py
index e305aed6da..aa538bfe69 100644
--- a/setup.py
+++ b/setup.py
@@ -264,7 +264,7 @@ dask = [
 ]
 databricks = [
     'requests>=2.26.0, <3',
-    'databricks-sql-connector>=1.0.2, <2.0.0',
+    'databricks-sql-connector>=2.0.0, <3.0.0',
 ]
 datadog = [
     'datadog>=0.14.0',
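
With the pin moved to the 2.x connector line, a quick sanity check of the installed version
(assumes Python 3.8+ for importlib.metadata):

    from importlib.metadata import version

    connector_version = version("databricks-sql-connector")
    assert connector_version.split(".")[0] == "2", connector_version
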
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index afa25932a9..783fa520a7 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -57,7 +57,7 @@ class TestDatabricksSqlOperator(unittest.TestCase):
             http_path=None,
             session_configuration=None,
             sql_endpoint_name=None,
-            metadata=None,
+            http_headers=None,
             catalog=None,
             schema=None,
         )
@@ -88,7 +88,7 @@ class TestDatabricksSqlOperator(unittest.TestCase):
             http_path=None,
             session_configuration=None,
             sql_endpoint_name=None,
-            metadata=None,
+            http_headers=None,
             catalog=None,
             schema=None,
         )
@@ -153,6 +153,25 @@ COPY_OPTIONS ('force' = 'true')
             == f"""COPY INTO test
 FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ))
 FILEFORMAT = CSV
+""".strip()
+        )
+
+    def test_copy_with_target_credential(self):
+        expression = "col1, col2"
+        op = DatabricksCopyIntoOperator(
+            file_location=COPY_FILE_LOCATION,
+            file_format='CSV',
+            table_name='test',
+            task_id=TASK_ID,
+            expression_list=expression,
+            storage_credential='abc',
+            credential={'AZURE_SAS_TOKEN': 'abc'},
+        )
+        assert (
+            op._create_sql_query()
+            == f"""COPY INTO test WITH (CREDENTIAL abc)
+FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ))
+FILEFORMAT = CSV
 """.strip()
         )
 
