This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 6a3d6cc32b4e3922d259c889460fe82e0ebf3663 Author: Alex Ott <[email protected]> AuthorDate: Sat Apr 23 12:38:06 2022 +0200 Update to the released version of DBSQL connector Also added additional parameters for further customization of connection if it's required --- .../providers/databricks/hooks/databricks_sql.py | 47 ++++++++++++------- .../databricks/operators/databricks_sql.py | 53 ++++++++++++++++------ .../operators/copy_into.rst | 28 +++++++----- .../operators/sql.rst | 8 ++-- setup.py | 2 +- .../databricks/operators/test_databricks_sql.py | 23 +++++++++- 6 files changed, 112 insertions(+), 49 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py index 14d6e4bf68..aa8245772a 100644 --- a/airflow/providers/databricks/hooks/databricks_sql.py +++ b/airflow/providers/databricks/hooks/databricks_sql.py @@ -23,29 +23,17 @@ from typing import Any, Dict, List, Optional, Tuple, Union from databricks import sql # type: ignore[attr-defined] from databricks.sql.client import Connection # type: ignore[attr-defined] +from airflow import __version__ from airflow.exceptions import AirflowException from airflow.hooks.dbapi import DbApiHook from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook LIST_SQL_ENDPOINTS_ENDPOINT = ('GET', 'api/2.0/sql/endpoints') +USER_AGENT_STRING = f'airflow-{__version__}' class DatabricksSqlHook(BaseDatabricksHook, DbApiHook): - """ - Interact with Databricks SQL. - - :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`. - :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster. - If not specified, it should be either specified in the Databricks connection's extra parameters, - or ``sql_endpoint_name`` must be specified. - :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path`` must - be provided as described above. - :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. - If not specified, it could be specified in the Databricks connection's extra parameters. - :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request - :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ - :param schema: An optional initial schema to use. Requires DBR version 9.0+ - """ + """Hook to interact with Databricks SQL.""" hook_name = 'Databricks SQL' @@ -55,10 +43,29 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook): http_path: Optional[str] = None, sql_endpoint_name: Optional[str] = None, session_configuration: Optional[Dict[str, str]] = None, - metadata: Optional[List[Tuple[str, str]]] = None, + http_headers: Optional[List[Tuple[str, str]]] = None, catalog: Optional[str] = None, schema: Optional[str] = None, + **kwargs, ) -> None: + """ + Initializes DatabricksSqlHook + + :param databricks_conn_id: Reference to the + :ref:`Databricks connection <howto/connection:databricks>`. + :param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster. + If not specified, it should be either specified in the Databricks connection's extra parameters, + or ``sql_endpoint_name`` must be specified. + :param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path`` + must be provided as described above. + :param session_configuration: An optional dictionary of Spark session parameters. 
Defaults to None. + If not specified, it could be specified in the Databricks connection's extra parameters. + :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers + on every request + :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ + :param schema: An optional initial schema to use. Requires DBR version 9.0+ + :param kwargs: Additional parameters internal to Databricks SQL Connector parameters + """ super().__init__(databricks_conn_id) self._sql_conn = None self._token: Optional[str] = None @@ -66,9 +73,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook): self._sql_endpoint_name = sql_endpoint_name self.supports_autocommit = True self.session_config = session_configuration - self.metadata = metadata + self.http_headers = http_headers self.catalog = catalog self.schema = schema + self.additional_params = kwargs def _get_extra_config(self) -> Dict[str, Optional[Any]]: extra_params = copy(self.databricks_conn.extra_dejson) @@ -122,8 +130,13 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook): self.host, self._http_path, self._token, + schema=self.schema, + catalog=self.catalog, session_configuration=self.session_config, + http_headers=self.http_headers, + _user_agent_entry=USER_AGENT_STRING, **self._get_extra_config(), + **self.additional_params, ) return self._sql_conn diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py index a7d157932d..9e6298bc21 100644 --- a/airflow/providers/databricks/operators/databricks_sql.py +++ b/airflow/providers/databricks/operators/databricks_sql.py @@ -53,16 +53,18 @@ class DatabricksSqlOperator(BaseOperator): :param parameters: (optional) the parameters to render the SQL query with. :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. If not specified, it could be specified in the Databricks connection's extra parameters. - :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request - :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ - :param schema: An optional initial schema to use. Requires DBR version 9.0+ + :param client_parameters: Additional parameters internal to Databricks SQL Connector parameters + :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers on every request. + (templated) + :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated) + :param schema: An optional initial schema to use. Requires DBR version 9.0+ (templated) :param output_path: optional string specifying the file to which write selected data. (templated) :param output_format: format of output data if ``output_path` is specified. Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``. :param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data. 
""" - template_fields: Sequence[str] = ('sql', '_output_path', 'schema', 'metadata') + template_fields: Sequence[str] = ('sql', '_output_path', 'schema', 'catalog', 'http_headers') template_ext: Sequence[str] = ('.sql',) template_fields_renderers = {'sql': 'sql'} @@ -75,13 +77,14 @@ class DatabricksSqlOperator(BaseOperator): sql_endpoint_name: Optional[str] = None, parameters: Optional[Union[Mapping, Iterable]] = None, session_configuration=None, - metadata: Optional[List[Tuple[str, str]]] = None, + http_headers: Optional[List[Tuple[str, str]]] = None, catalog: Optional[str] = None, schema: Optional[str] = None, do_xcom_push: bool = False, output_path: Optional[str] = None, output_format: str = 'csv', csv_params: Optional[Dict[str, Any]] = None, + client_parameters: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: """Creates a new ``DatabricksSqlOperator``.""" @@ -96,9 +99,10 @@ class DatabricksSqlOperator(BaseOperator): self.parameters = parameters self.do_xcom_push = do_xcom_push self.session_config = session_configuration - self.metadata = metadata + self.http_headers = http_headers self.catalog = catalog self.schema = schema + self.client_parameters = client_parameters or {} def _get_hook(self) -> DatabricksSqlHook: return DatabricksSqlHook( @@ -106,9 +110,10 @@ class DatabricksSqlOperator(BaseOperator): http_path=self._http_path, session_configuration=self.session_config, sql_endpoint_name=self._sql_endpoint_name, - metadata=self.metadata, + http_headers=self.http_headers, catalog=self.catalog, schema=self.schema, + **self.client_parameters, ) def _format_output(self, schema, results): @@ -179,12 +184,16 @@ class DatabricksCopyIntoOperator(BaseOperator): If not specified, ``http_path`` must be provided as described above. :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. If not specified, it could be specified in the Databricks connection's extra parameters. - :param metadata: An optional list of (k, v) pairs that will be set as Http headers on every request + :param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers on every request + :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ + :param schema: An optional initial schema to use. Requires DBR version 9.0+ + :param client_parameters: Additional parameters internal to Databricks SQL Connector parameters :param files: optional list of files to import. Can't be specified together with ``pattern``. (templated) :param pattern: optional regex string to match file names to import. Can't be specified together with ``files``. :param expression_list: optional string that will be used in the ``SELECT`` expression. - :param credential: optional credential configuration for authentication against a specified location. + :param credential: optional credential configuration for authentication against a source location. + :param storage_credential: optional Unity Catalog storage credential for destination. :param encryption: optional encryption configuration for a specified location. :param format_options: optional dictionary with options specific for a given file format. 
:param force_copy: optional bool to control forcing of data import @@ -210,11 +219,15 @@ class DatabricksCopyIntoOperator(BaseOperator): http_path: Optional[str] = None, sql_endpoint_name: Optional[str] = None, session_configuration=None, - metadata: Optional[List[Tuple[str, str]]] = None, + http_headers: Optional[List[Tuple[str, str]]] = None, + client_parameters: Optional[Dict[str, Any]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, files: Optional[List[str]] = None, pattern: Optional[str] = None, expression_list: Optional[str] = None, credential: Optional[Dict[str, str]] = None, + storage_credential: Optional[str] = None, encryption: Optional[Dict[str, str]] = None, format_options: Optional[Dict[str, str]] = None, force_copy: Optional[bool] = None, @@ -240,14 +253,18 @@ class DatabricksCopyIntoOperator(BaseOperator): self._sql_endpoint_name = sql_endpoint_name self.session_config = session_configuration self._table_name = table_name + self._catalog = catalog + self._schema = schema self._file_location = file_location self._expression_list = expression_list self._credential = credential + self._storage_credential = storage_credential self._encryption = encryption self._format_options = format_options self._copy_options = copy_options or {} self._validate = validate - self.metadata = metadata + self._http_headers = http_headers + self._client_parameters = client_parameters or {} if force_copy is not None: self._copy_options["force"] = 'true' if force_copy else 'false' @@ -257,7 +274,10 @@ class DatabricksCopyIntoOperator(BaseOperator): http_path=self._http_path, session_configuration=self.session_config, sql_endpoint_name=self._sql_endpoint_name, - metadata=self.metadata, + http_headers=self._http_headers, + catalog=self._catalog, + schema=self._schema, + **self._client_parameters, ) @staticmethod @@ -298,6 +318,9 @@ class DatabricksCopyIntoOperator(BaseOperator): files_or_pattern = f"FILES = {escaper.escape_item(self._files)}\n" format_options = self._generate_options("FORMAT_OPTIONS", escaper, self._format_options) + "\n" copy_options = self._generate_options("COPY_OPTIONS", escaper, self._copy_options) + "\n" + storage_cred = "" + if self._storage_credential: + storage_cred = f" WITH (CREDENTIAL {self._storage_credential})" validation = "" if self._validate is not None: if isinstance(self._validate, bool): @@ -310,9 +333,11 @@ class DatabricksCopyIntoOperator(BaseOperator): ) validation = f"VALIDATE {self._validate} ROWS\n" else: - raise AirflowException("Incorrect data type for validate parameter: " + type(self._validate)) + raise AirflowException( + "Incorrect data type for validate parameter: " + str(type(self._validate)) + ) # TODO: think on how to make sure that table_name and expression_list aren't used for SQL injection - sql = f"""COPY INTO {self._table_name} + sql = f"""COPY INTO {self._table_name}{storage_cred} FROM {location} FILEFORMAT = {self._file_format} {validation}{files_or_pattern}{format_options}{copy_options} diff --git a/docs/apache-airflow-providers-databricks/operators/copy_into.rst b/docs/apache-airflow-providers-databricks/operators/copy_into.rst index 1b73bac8d2..1d4ef07de2 100644 --- a/docs/apache-airflow-providers-databricks/operators/copy_into.rst +++ b/docs/apache-airflow-providers-databricks/operators/copy_into.rst @@ -49,25 +49,29 @@ Operator loads data from a specified location into a table using a configured en - Optional HTTP path for Databricks SQL endpoint or Databricks cluster. 
If not specified, it should be provided in Databricks connection, or the ``sql_endpoint_name`` parameter must be set. * - session_configuration: dict[str,str] - optional dict specifying Spark configuration parameters that will be set for the session. - * - metadata: list[tuple[str, str]] - - Optional list of (k, v) pairs that will be set as Http headers on every request - * - files: Optional[List[str]] + * - http_headers: list[tuple[str, str]] + - Optional list of (k, v) pairs that will be set as HTTP headers on every request + * - client_parameters: dict[str,str] + - optional additional parameters internal to Databricks SQL Connector parameters + * - files: list[str]] - optional list of files to import. Can't be specified together with ``pattern``. - * - pattern: Optional[str] + * - pattern: str - optional regex string to match file names to import. Can't be specified together with ``files``. - * - expression_list: Optional[str] + * - expression_list: str - optional string that will be used in the ``SELECT`` expression. - * - credential: Optional[Dict[str, str]] + * - credential: dict[str, str] - optional credential configuration for authentication against a specified location - * - encryption: Optional[Dict[str, str]] + * - encryption: dict[str, str] - optional encryption configuration for a specified location - * - format_options: Optional[Dict[str, str]] + * - storage_credential: str + - optional Unity Catalog storage credential name for the target table + * - format_options: dict[str, str] - optional dictionary with options specific for a given file format. - * - force_copy: Optional[bool] - - optional bool to control forcing of data import (could be also specified in ``copy_options``). - * - copy_options: Optional[Dict[str, str]] + * - force_copy: bool + - optional boolean parameter to control forcing of data import (could be also specified in ``copy_options``). + * - copy_options: dict[str, str] - optional dictionary of copy options. Right now only ``force`` option is supported. - * - validate: Optional[Union[bool, int]] + * - validate: union[bool, int]] - optional validation configuration. ``True`` forces validation of all rows, positive number - only N first rows. (requires Preview channel) Examples diff --git a/docs/apache-airflow-providers-databricks/operators/sql.rst b/docs/apache-airflow-providers-databricks/operators/sql.rst index 953bbae245..d0a1d6d337 100644 --- a/docs/apache-airflow-providers-databricks/operators/sql.rst +++ b/docs/apache-airflow-providers-databricks/operators/sql.rst @@ -51,8 +51,10 @@ Operator executes given SQL queries against configured endpoint. There are 3 wa - Optional parameters that will be used to substitute variable(s) in SQL query. * - session_configuration: dict[str,str] - optional dict specifying Spark configuration parameters that will be set for the session. - * - metadata: list[tuple[str, str]] - - Optional list of (k, v) pairs that will be set as Http headers on every request + * - http_headers: list[tuple[str, str]] + - Optional list of (k, v) pairs that will be set as HTTP headers on every request + * - client_parameters: dict[str,str] + - optional additional parameters internal to Databricks SQL Connector parameters * - catalog: str - Optional initial catalog to use. Requires DBR version 9.0+ * - schema: str @@ -63,7 +65,7 @@ Operator executes given SQL queries against configured endpoint. There are 3 wa - Name of the format which will be used to write results. 
Supported values are (case-insensitive): ``JSON`` (array of JSON objects), ``JSONL`` (each row as JSON object on a separate line), ``CSV`` (default). * - csv_params: dict[str, any] - Optional dictionary with parameters to customize Python CSV writer. - * - do_xcom_push: boolean + * - do_xcom_push: bool - whether we should push query results (last query if multiple queries are provided) to xcom. Default: false Examples diff --git a/setup.py b/setup.py index e305aed6da..aa538bfe69 100644 --- a/setup.py +++ b/setup.py @@ -264,7 +264,7 @@ dask = [ ] databricks = [ 'requests>=2.26.0, <3', - 'databricks-sql-connector>=1.0.2, <2.0.0', + 'databricks-sql-connector>=2.0.0, <3.0.0', ] datadog = [ 'datadog>=0.14.0', diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py index afa25932a9..783fa520a7 100644 --- a/tests/providers/databricks/operators/test_databricks_sql.py +++ b/tests/providers/databricks/operators/test_databricks_sql.py @@ -57,7 +57,7 @@ class TestDatabricksSqlOperator(unittest.TestCase): http_path=None, session_configuration=None, sql_endpoint_name=None, - metadata=None, + http_headers=None, catalog=None, schema=None, ) @@ -88,7 +88,7 @@ class TestDatabricksSqlOperator(unittest.TestCase): http_path=None, session_configuration=None, sql_endpoint_name=None, - metadata=None, + http_headers=None, catalog=None, schema=None, ) @@ -153,6 +153,25 @@ COPY_OPTIONS ('force' = 'true') == f"""COPY INTO test FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') )) FILEFORMAT = CSV +""".strip() + ) + + def test_copy_with_target_credential(self): + expression = "col1, col2" + op = DatabricksCopyIntoOperator( + file_location=COPY_FILE_LOCATION, + file_format='CSV', + table_name='test', + task_id=TASK_ID, + expression_list=expression, + storage_credential='abc', + credential={'AZURE_SAS_TOKEN': 'abc'}, + ) + assert ( + op._create_sql_query() + == f"""COPY INTO test WITH (CREDENTIAL abc) +FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') )) +FILEFORMAT = CSV """.strip() )
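
For readers following the parameter rename in this commit, a minimal usage sketch of the updated operator is shown below. It is illustrative only: the connection id, endpoint name, header value, and the client_parameters key are hypothetical placeholders, not part of the commit.

# Illustrative sketch only -- the connection id, endpoint name, header value and the
# client_parameters key are hypothetical placeholders, not part of this commit.
from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

with DAG(
    dag_id="example_databricks_sql",
    start_date=datetime(2022, 4, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # 'metadata' is now 'http_headers'; extra connector arguments go into
    # 'client_parameters', which the operator forwards to DatabricksSqlHook as **kwargs
    # and the hook in turn passes on to databricks.sql.connect().
    select = DatabricksSqlOperator(
        task_id="run_select",
        databricks_conn_id="databricks_default",
        sql_endpoint_name="my-endpoint",                        # placeholder
        sql="SELECT * FROM my_table LIMIT 10",
        http_headers=[("X-Request-Source", "airflow")],         # placeholder header
        catalog="main",                                         # requires DBR 9.0+
        schema="default",
        client_parameters={"some_connector_option": "value"},  # hypothetical connector kwarg
    )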
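
The new storage_credential parameter adds a "WITH (CREDENTIAL ...)" clause right after the target table, as exercised by test_copy_with_target_credential above. A hedged sketch of how the COPY INTO operator might be invoked follows; the table, file location, endpoint and credential names are placeholders.

# Illustrative sketch; table, location, endpoint and credential names are placeholders.
from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator

copy_csv = DatabricksCopyIntoOperator(
    task_id="copy_csv_into_table",
    databricks_conn_id="databricks_default",
    sql_endpoint_name="my-endpoint",             # placeholder
    table_name="main.default.sales",             # placeholder target table
    file_location="s3://my-bucket/landing/",     # placeholder source location
    file_format="CSV",
    storage_credential="my_storage_credential",  # rendered as: COPY INTO main.default.sales WITH (CREDENTIAL my_storage_credential)
    format_options={"header": "true"},
    force_copy=True,                             # becomes COPY_OPTIONS ('force' = 'true')
)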
