This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new db5375bea7 Fixing the behaviours of SQL Hooks and Operators finally (#27912)
db5375bea7 is described below
commit db5375bea7a0564c12f56c91e1c8c7b6c049698c
Author: Jarek Potiuk <[email protected]>
AuthorDate: Sat Nov 26 10:26:15 2022 +0100
Fixing the behaviours of SQL Hooks and Operators finally (#27912)
This is a bolder fix to common.sql and the related providers.
It implements more comprehensive tests and describes the intended
behaviours much more explicitly. It also simplifies the notion of
"scalar" behaviour by introducing a new method that checks when a
scalar return is expected.
A comprehensive suite of tests has been added (the original tests
were converted to pytest tests, and parametrization was used to
test all the combinations of parameters and returned values).
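The decision rule introduced here can be illustrated with a short, self-contained
sketch mirroring the helper added in airflow/providers/common/sql/hooks/sql.py
(the asserts are illustrative, not part of the commit):

    from typing import Iterable

    def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool) -> bool:
        # Single (unwrapped) results are returned only when sql is a plain
        # string and either return_last is requested or the string is not split.
        return isinstance(sql, str) and (return_last or not split_statements)

    # A string without splitting yields the single query's results, not wrapped in a list.
    assert return_single_query_results("select 1", return_last=False, split_statements=False)
    # A string with return_last yields single query results, even when split.
    assert return_single_query_results("select 1; select 2", return_last=True, split_statements=True)
    # An iterable of statements is always wrapped in a list of results.
    assert not return_single_query_results(["select 1"], return_last=True, split_statements=True)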
---
airflow/providers/common/sql/CHANGELOG.rst | 4 +-
airflow/providers/common/sql/hooks/sql.py | 103 ++++-
airflow/providers/common/sql/operators/sql.py | 76 ++--
.../providers/databricks/hooks/databricks_sql.py | 36 +-
.../databricks/operators/databricks_sql.py | 28 +-
airflow/providers/exasol/hooks/exasol.py | 5 +-
airflow/providers/snowflake/hooks/snowflake.py | 5 +-
tests/providers/common/sql/hooks/test_sql.py | 226 +++++++++++
tests/providers/common/sql/operators/test_sql.py | 2 +
.../common/sql/operators/test_sql_execute.py | 276 +++++++++++++
.../databricks/hooks/test_databricks_sql.py | 227 +++++++++--
.../databricks/operators/test_databricks_copy.py | 230 +++++++++++
.../databricks/operators/test_databricks_sql.py | 428 ++++++++++-----------
tests/providers/jdbc/operators/test_jdbc.py | 2 +
14 files changed, 1286 insertions(+), 362 deletions(-)
diff --git a/airflow/providers/common/sql/CHANGELOG.rst b/airflow/providers/common/sql/CHANGELOG.rst
index 7f48cf5573..b436b9bc43 100644
--- a/airflow/providers/common/sql/CHANGELOG.rst
+++ b/airflow/providers/common/sql/CHANGELOG.rst
@@ -32,7 +32,9 @@ This release fixes a few errors that were introduced in common.sql operator whil
* ``_process_output`` method in ``SQLExecuteQueryOperator`` has now consistent semantics and typing, it
  can also modify the returned (and stored in XCom) values in the operators that derive from the
  ``SQLExecuteQueryOperator``).
-* last description of the cursor whether to return scalar values are now stored in DBApiHook
+* descriptions of all returned results are stored as the descriptions property in the DBApiHook
+* the description of the cursor for the last statement executed is now exposed in the DBApiHook
+  via the last_description property.
Lack of consistency in the operator caused ``1.3.0`` to be yanked - the ``1.3.0`` should not be used - if
you have ``1.3.0`` installed, upgrade to ``1.3.1``.
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index df808430fd..06dbdb2f57 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -30,6 +30,33 @@ from airflow.hooks.base import BaseHook
from airflow.version import version
+def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
+    """
+    Determines when only the results of a single query should be returned.
+
+    For compatibility reasons, the behaviour of the DBAPIHook is somewhat confusing.
+    In cases when multiple queries are run, the return value will be an iterable (list) of results
+    - one for each query. However, in certain cases, when a single query is run, the results will be
+    just the results of that single query without wrapping the results in a list.
+
+    The cases when single query results are returned without wrapping them in a list are when:
+
+    a) sql is a string and return_last is True (regardless of the split_statements value)
+    b) sql is a string and split_statements is False
+
+    In all other cases, the results are wrapped in a list, even if there is only one statement to process:
+
+    a) always when sql is an iterable of string statements (regardless of the return_last value)
+    b) when sql is a string, split_statements is True and return_last is False
+
+    :param sql: sql to run (either a string or a list of strings)
+    :param return_last: whether only the output of the last statement should be returned
+    :param split_statements: whether to split string statements.
+    :return: True if the hook should return single query results
+    """
+    return isinstance(sql, str) and (return_last or not split_statements)
+
+
def fetch_all_handler(cursor) -> list[tuple] | None:
"""Handler for DbApiHook.run() to return results"""
if cursor.description is not None:
@@ -114,8 +141,7 @@ class DbApiHook(BaseForDbApiHook):
# Hook deriving from the DBApiHook to still have access to the field in its constructor
self.__schema = schema
self.log_sql = log_sql
- self.scalar_return_last = False
- self.last_description: Sequence[Sequence] | None = None
+ self.descriptions: list[Sequence[Sequence] | None] = []
def get_conn(self):
"""Returns a connection object"""
@@ -222,6 +248,12 @@ class DbApiHook(BaseForDbApiHook):
statements: list[str] = list(filter(None, splits))
return statements
+ @property
+ def last_description(self) -> Sequence[Sequence] | None:
+ if not self.descriptions:
+ return None
+ return self.descriptions[-1]
+
def run(
self,
sql: str | Iterable[str],
@@ -234,7 +266,42 @@ class DbApiHook(BaseForDbApiHook):
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
- sequentially
+ sequentially.
+
+ The method will return either single query results (typically list of
rows) or list of those results
+ where each element in the list are results of one of the queries
(typically list of list of rows :D)
+
+ For compatibility reasons, the behaviour of the DBAPIHook is somewhat
confusing.
+ In cases, when multiple queries are run, the return values will be an
iterable (list) of results
+ - one for each query. However, in certain cases, when single query is
run - the results will be just
+ the results of that query without wrapping the results in a list.
+
+ The cases when single query results are returned without wrapping them
in a list are when:
+
+ a) sql is string and last_statement is True (regardless what
split_statement value is)
+ b) sql is string and split_statement is False
+
+ In all other cases, the results are wrapped in a list, even if there
is only one statement to process:
+
+ a) always when sql is an iterable of string statements (regardless
what last_statement value is)
+ b) when sql is string, split_statement is True and last_statement is
False
+
+
+ In any of those cases, however you can access the following properties
of the Hook after running it:
+
+ * descriptions - has an array of cursor descriptions - each
statement executed contain the list
+ of descriptions executed. If ``return_last`` is used, this is
always a one-element array
+ * last_description - description of the last statement executed
+
+ Note that return value from the hook will ONLY be actually returned
when handler is provided. Setting
+ the ``handler`` to None, results in this method returning None.
+
+ Handler is a way to process the rows from cursor (Iterator) into a
value that is suitable to be
+ returned to XCom and generally fit in memory. As an optimization,
handler is usually not executed
+ by the SQLExecuteQuery operator if `do_xcom_push` is not specified.
+
+ You can use pre-defined handles (`fetch_all_handler``,
''fetch_one_handler``) or implement your
+ own handler.
:param sql: the sql statement to be executed (str) or a list of
sql statements to execute
@@ -246,31 +313,38 @@ class DbApiHook(BaseForDbApiHook):
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the ALL SQL expressions if handler was provided.
"""
- self.scalar_return_last = isinstance(sql, str) and return_last
+ self.descriptions = []
+
if isinstance(sql, str):
if split_statements:
- sql = self.split_sql_string(sql)
+ sql_list: Iterable[str] = self.split_sql_string(sql)
else:
- sql = [sql]
+ sql_list = [sql] if sql.strip() else []
+ else:
+ sql_list = sql
-        if sql:
-            self.log.debug("Executing following statements against DB: %s", list(sql))
+        if sql_list:
+            self.log.debug("Executing following statements against DB: %s", sql_list)
else:
raise ValueError("List of SQL statements is empty")
-
+ _last_result = None
with closing(self.get_conn()) as conn:
if self.supports_autocommit:
self.set_autocommit(conn, autocommit)
with closing(conn.cursor()) as cur:
results = []
- for sql_statement in sql:
+ for sql_statement in sql_list:
self._run_command(cur, sql_statement, parameters)
if handler is not None:
result = handler(cur)
- results.append(result)
- self.last_description = cur.description
+                        if return_single_query_results(sql, return_last, split_statements):
+ _last_result = result
+ _last_description = cur.description
+ else:
+ results.append(result)
+ self.descriptions.append(cur.description)
# If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
@@ -278,8 +352,9 @@ class DbApiHook(BaseForDbApiHook):
if handler is None:
return None
- elif self.scalar_return_last:
- return results[-1]
+ if return_single_query_results(sql, return_last, split_statements):
+ self.descriptions = [_last_description]
+ return _last_result
else:
return results
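A minimal usage sketch of the resulting hook semantics (SqliteHook is an arbitrary
DbApiHook subclass chosen for illustration; table-free queries keep it self-contained):

    from airflow.providers.common.sql.hooks.sql import fetch_all_handler
    from airflow.providers.sqlite.hooks.sqlite import SqliteHook

    hook = SqliteHook()

    # Plain string, no splitting: the single query's rows, not wrapped in a list.
    rows = hook.run(sql="select 1", handler=fetch_all_handler)

    # List of statements: one result entry per executed statement.
    all_results = hook.run(sql=["select 1", "select 2"], handler=fetch_all_handler)

    # After any run(), cursor descriptions are exposed on the hook.
    print(hook.descriptions)      # one description per executed statement
    print(hook.last_description)  # description of the last statement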
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 314af43003..09034b104c 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -19,14 +19,13 @@ from __future__ import annotations
import ast
import re
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs, overload
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowFailException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, SkipMixin
-from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
-from airflow.typing_compat import Literal
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, return_single_query_results
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -190,11 +189,17 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
Executes SQL code in a specific database
:param sql: the SQL code or string pointing to a template file to be executed (templated).
File must have a '.sql' extensions.
+
+    When implementing a specific Operator, you can also implement the ``_process_output`` method in the
+    operator to perform additional processing of the values returned by the DB Hook. For example, you
+    can join descriptions retrieved from the cursors of your statements with the returned values, or
+    save the output of your operator to a file.
+
:param autocommit: (optional) if True, each command is automatically committed (default: False).
:param parameters: (optional) the parameters to render the SQL query with.
:param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
:param split_statements: (optional) if split single SQL string into statements (default: False).
-    :param return_last: (optional) if return the result of only last statement (default: True).
+    :param return_last: (optional) return the result of only the last statement (default: True).
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -225,54 +230,42 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
self.split_statements = split_statements
self.return_last = return_last
-    @overload
-    def _process_output(
-        self, results: Any, description: Sequence[Sequence] | None, scalar_results: Literal[True]
-    ) -> Any:
-        pass
-
-    @overload
-    def _process_output(
-        self, results: list[Any], description: Sequence[Sequence] | None, scalar_results: Literal[False]
-    ) -> Any:
-        pass
-
-    def _process_output(
-        self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool
-    ) -> Any:
+    def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
"""
- Can be overridden by the subclass in case some extra processing is
needed.
+ Processes output before it is returned by the operator.
+
+ It can be overridden by the subclass in case some extra processing is
needed. Note that unlike
+ DBApiHook return values returned - the results passed and returned by
``_process_output`` should
+ always be lists of results - each element of the list is a result from
a single SQL statement
+ (typically this will be list of Rows). You have to make sure that this
is the same for returned
+ values = there should be one element in the list for each statement
executed by the hook..
+
The "process_output" method can override the returned output -
augmenting or processing the
output as needed - the output returned will be returned as execute
return value and if
- do_xcom_push is set to True, it will be set as XCom returned
+ do_xcom_push is set to True, it will be set as XCom returned.
:param results: results in the form of list of rows.
- :param description: as returned by ``cur.description`` in the Python
DBAPI
- :param scalar_results: True if result is single scalar value rather
than list of rows
+ :param descriptions: list of descriptions returned by
``cur.description`` in the Python DBAPI
"""
return results
def execute(self, context):
self.log.info("Executing: %s", self.sql)
hook = self.get_db_hook()
-        if self.do_xcom_push:
-            output = hook.run(
-                sql=self.sql,
-                autocommit=self.autocommit,
-                parameters=self.parameters,
-                handler=self.handler,
-                split_statements=self.split_statements,
-                return_last=self.return_last,
-            )
-        else:
-            output = hook.run(
-                sql=self.sql,
-                autocommit=self.autocommit,
-                parameters=self.parameters,
-                split_statements=self.split_statements,
-            )
-
-        return self._process_output(output, hook.last_description, hook.scalar_return_last)
+        output = hook.run(
+            sql=self.sql,
+            autocommit=self.autocommit,
+            parameters=self.parameters,
+            handler=self.handler if self.do_xcom_push else None,
+            split_statements=self.split_statements,
+            return_last=self.return_last,
+        )
+        if return_single_query_results(self.sql, self.return_last, self.split_statements):
+            # For simplicity, we always pass a list as input to _process_output, regardless of whether
+            # single query results are going to be returned, and in that case we return the last element
+            # of the (always) list returned by _process_output
+            return self._process_output([output], hook.descriptions)[-1]
+        return self._process_output(output, hook.descriptions)
def prepare_template(self) -> None:
"""Parse template file for attribute parameters."""
@@ -284,6 +277,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
"""
Performs one or more of the templated checks in the column_checks dictionary.
Checks are performed on a per-column basis specified by the column_mapping.
+
Each check can take one or more of the following options:
- equal_to: an exact value to equal, cannot be used with other comparison options
- greater_than: value that result should be strictly greater than
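Given the new ``_process_output`` contract above, a subclass override might look like the
following sketch (the operator name is hypothetical; the zip pattern mirrors the Databricks
operator changes below):

    from __future__ import annotations

    from typing import Any, Sequence

    from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

    class MyAuditedSqlOperator(SQLExecuteQueryOperator):
        """Hypothetical operator pairing each statement's description with its results."""

        def _process_output(
            self, results: list[Any], descriptions: list[Sequence[Sequence] | None]
        ) -> list[Any]:
            # Keep one element per executed statement, as the contract requires.
            return list(zip(descriptions, results))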
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index f042435943..245128aca9 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -24,7 +24,7 @@ from databricks import sql # type: ignore[attr-defined]
from databricks.sql.client import Connection # type: ignore[attr-defined]
from airflow.exceptions import AirflowException
-from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")
@@ -151,49 +151,57 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
- sequentially
+ sequentially.
:param sql: the sql statement to be executed (str) or a list of
sql statements to execute
:param autocommit: What to set the connection's autocommit setting to
-            before executing the query.
+            before executing the query. Note that currently there is no commit functionality
+            in Databricks SQL so this flag has no effect.
+
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
:param split_statements: Whether to split a single SQL string into statements and run separately
:param return_last: Whether to return result for only last statement or for all after split
-        :return: return only result of the LAST SQL expression if handler was provided.
+        :return: return only result of the LAST SQL expression if handler was provided unless return_last
+            is set to False.
"""
- self.scalar_return_last = isinstance(sql, str) and return_last
+ self.descriptions = []
if isinstance(sql, str):
if split_statements:
-                sql = self.split_sql_string(sql)
+                sql_list = [self.strip_sql_string(s) for s in self.split_sql_string(sql)]
else:
-                sql = [self.strip_sql_string(sql)]
+                sql_list = [self.strip_sql_string(sql)]
+        else:
+            sql_list = [self.strip_sql_string(s) for s in sql]
-        if sql:
-            self.log.debug("Executing following statements against Databricks DB: %s", list(sql))
+        if sql_list:
+            self.log.debug("Executing following statements against Databricks DB: %s", sql_list)
else:
raise ValueError("List of SQL statements is empty")
results = []
- for sql_statement in sql:
+ for sql_statement in sql_list:
# when using AAD tokens, it could expire if previous query run longer than token lifetime
with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)
with closing(conn.cursor()) as cur:
self._run_command(cur, sql_statement, parameters)
-
if handler is not None:
result = handler(cur)
- results.append(result)
- self.last_description = cur.description
+                        if return_single_query_results(sql, return_last, split_statements):
+ results = [result]
+ self.descriptions = [cur.description]
+ else:
+ results.append(result)
+ self.descriptions.append(cur.description)
self._sql_conn = None
if handler is None:
return None
- elif self.scalar_return_last:
+ if return_single_query_results(sql, return_last, split_statements):
return results[-1]
else:
return results
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 379b0fd2c9..178afc8d98 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -120,21 +120,17 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
}
return DatabricksSqlHook(self.databricks_conn_id, **hook_params)
-    def _process_output(
-        self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool
-    ) -> Any:
+    def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
if not self._output_path:
-            return description, results
+            return list(zip(descriptions, results))
if not self._output_format:
raise AirflowException("Output format should be specified!")
-        if description is None:
-            self.log.warning("Description of the cursor is missing. Will not process the output")
-            return description, results
-        field_names = [field[0] for field in description]
-        if scalar_results:
-            list_results: list[Any] = [results]
-        else:
-            list_results = results
+        # Output to a file only the result of the last query
+        last_description = descriptions[-1]
+        last_results = results[-1]
+        if last_description is None:
+            raise AirflowException("The description for the output file is missing.")
+        field_names = [field[0] for field in last_description]
if self._output_format.lower() == "csv":
with open(self._output_path, "w", newline="") as file:
if self._csv_params:
@@ -147,19 +143,19 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
if write_header:
writer.writeheader()
- for row in list_results:
+ for row in last_results:
writer.writerow(row.asDict())
elif self._output_format.lower() == "json":
with open(self._output_path, "w") as file:
- file.write(json.dumps([row.asDict() for row in list_results]))
+ file.write(json.dumps([row.asDict() for row in last_results]))
elif self._output_format.lower() == "jsonl":
with open(self._output_path, "w") as file:
- for row in list_results:
+ for row in last_results:
file.write(json.dumps(row.asDict()))
file.write("\n")
else:
raise AirflowException(f"Unsupported output format:
'{self._output_format}'")
- return description, results
+ return list(zip(descriptions, results))
COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index 49289df37a..3dc6b81973 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -24,7 +24,7 @@ import pandas as pd
import pyexasol
from pyexasol import ExaConnection
-from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
class ExasolHook(DbApiHook):
@@ -157,7 +157,6 @@ class ExasolHook(DbApiHook):
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided.
"""
- self.scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
sql = self.split_sql_string(sql)
@@ -187,7 +186,7 @@ class ExasolHook(DbApiHook):
if handler is None:
return None
- elif self.scalar_return_last:
+ elif return_single_query_results(sql, return_last, split_statements):
return results[-1]
else:
return results
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index e525efe763..c08247be21 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -32,7 +32,7 @@ from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine
from airflow import AirflowException
-from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.utils.strings import to_boolean
@@ -350,7 +350,6 @@ class SnowflakeHook(DbApiHook):
"""
self.query_ids = []
- self.scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
split_statements_tuple = util_text.split_statements(StringIO(sql))
@@ -387,7 +386,7 @@ class SnowflakeHook(DbApiHook):
if handler is None:
return None
- elif self.scalar_return_last:
+ elif return_single_query_results(sql, return_last, split_statements):
return results[-1]
else:
return results
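The Exasol, Snowflake and Databricks hooks now end run() with the same tail logic; a sketch
of that shared pattern, wrapped in a hypothetical helper for readability:

    from __future__ import annotations

    from typing import Any, Callable, Iterable

    from airflow.providers.common.sql.hooks.sql import return_single_query_results

    def finalize_run_results(  # hypothetical helper mirroring the shared tail of run()
        sql: str | Iterable[str],
        results: list[Any],
        handler: Callable | None,
        return_last: bool,
        split_statements: bool,
    ) -> Any:
        if handler is None:
            # Without a handler, run() returns nothing by design.
            return None
        if return_single_query_results(sql, return_last, split_statements):
            # Unwrap: the caller gets the single query's results directly.
            return results[-1]
        # Otherwise keep one element per executed statement.
        return results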
diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py
new file mode 100644
index 0000000000..72684e5e00
--- /dev/null
+++ b/tests/providers/common/sql/hooks/test_sql.py
@@ -0,0 +1,226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+from airflow.utils.session import provide_session
+
+TASK_ID = "sql-operator"
+HOST = "host"
+DEFAULT_CONN_ID = "sqlite_default"
+PASSWORD = "password"
+
+
+class DBApiHookForTests(DbApiHook):
+ conn_name_attr = "conn_id"
+ get_conn = MagicMock(name="conn")
+
+
+@provide_session
[email protected](autouse=True)
+def create_connection(session):
+    conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+ conn.host = HOST
+ conn.login = None
+ conn.password = PASSWORD
+ conn.extra = None
+ session.commit()
+
+
+def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
+ return [(field,) for field in fields]
+
+
+index = 0
+
+
[email protected](
+ "return_last, split_statements, sql, cursor_calls,"
+ "cursor_descriptions, cursor_results, hook_descriptions, hook_results, ",
+ [
+ pytest.param(
+ True,
+ False,
+ "select * from test.test",
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[1, 2], [11, 12]],
+            id="The return_last set and no split statements set on single query in string",
+ ),
+ pytest.param(
+ False,
+ False,
+ "select * from test.test;",
+ ["select * from test.test;"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[1, 2], [11, 12]],
+            id="The return_last not set and no split statements set on single query in string",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from test.test;",
+ ["select * from test.test;"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[1, 2], [11, 12]],
+            id="The return_last set and split statements set on single query in string",
+ ),
+ pytest.param(
+ False,
+ True,
+ "select * from test.test;",
+ ["select * from test.test;"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[[1, 2], [11, 12]]],
+            id="The return_last not set and split statements set on single query in string",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test;", "select * from test.test2;"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id2",), ("value2",)]],
+ [[3, 4], [13, 14]],
+            id="The return_last set and split statements set on multiple queries in string",
+ ), # Failing
+ pytest.param(
+ False,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test;", "select * from test.test2;"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+            id="The return_last not set and split statements set on multiple queries in string",
+ ),
+ pytest.param(
+ True,
+ True,
+ ["select * from test.test;"],
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[[1, 2], [11, 12]]],
+ id="The return_last set on single query in list",
+ ),
+ pytest.param(
+ False,
+ True,
+ ["select * from test.test;"],
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[[1, 2], [11, 12]]],
+ id="The return_last not set on single query in list",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test", "select * from test.test2"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id2",), ("value2",)]],
+ [[3, 4], [13, 14]],
+            id="The return_last set on multiple queries in list",
+ ),
+ pytest.param(
+ False,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test", "select * from test.test2"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+ id="The return_last not set on multiple queries not set",
+ ),
+ ],
+)
+def test_query(
+ return_last,
+ split_statements,
+ sql,
+ cursor_calls,
+ cursor_descriptions,
+ cursor_results,
+ hook_descriptions,
+ hook_results,
+):
+ modified_descriptions = [
+        get_cursor_descriptions(cursor_description) for cursor_description in cursor_descriptions
+ ]
+ dbapi_hook = DBApiHookForTests()
+ dbapi_hook.get_conn.return_value.cursor.return_value.rowcount = 2
+ dbapi_hook.get_conn.return_value.cursor.return_value._description_index = 0
+
+    def mock_execute(*args, **kwargs):
+        # the run method accesses the description property directly, and we need to modify it after
+        # every execute, to make sure that different descriptions are returned. I could not find an
+        # easier method with mocking
+        dbapi_hook.get_conn.return_value.cursor.return_value.description = modified_descriptions[
+            dbapi_hook.get_conn.return_value.cursor.return_value._description_index
+        ]
+        dbapi_hook.get_conn.return_value.cursor.return_value._description_index += 1
+
+ dbapi_hook.get_conn.return_value.cursor.return_value.execute = mock_execute
+    dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = cursor_results
+    results = dbapi_hook.run(
+        sql=sql, handler=fetch_all_handler, return_last=return_last, split_statements=split_statements
+ )
+
+ assert dbapi_hook.descriptions == hook_descriptions
+ assert dbapi_hook.last_description == hook_descriptions[-1]
+ assert results == hook_results
+
+ dbapi_hook.get_conn.return_value.cursor.return_value.close.assert_called()
+
+
[email protected](
+ "empty_statement",
+ [
+ pytest.param([], id="Empty list"),
+ pytest.param("", id="Empty string"),
+ pytest.param("\n", id="Only EOL"),
+ ],
+)
+def test_no_query(empty_statement):
+ dbapi_hook = DBApiHookForTests()
+ dbapi_hook.get_conn.return_value.cursor.rowcount = 0
+ with pytest.raises(ValueError) as err:
+ dbapi_hook.run(sql=empty_statement)
+ assert err.value.args[0] == "List of SQL statements is empty"
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 1770ed8f5e..216ac79280 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -88,6 +88,8 @@ class TestSQLExecuteQueryOperator:
autocommit=False,
parameters=None,
split_statements=False,
+ handler=None,
+ return_last=True,
)
diff --git a/tests/providers/common/sql/operators/test_sql_execute.py b/tests/providers/common/sql/operators/test_sql_execute.py
new file mode 100644
index 0000000000..0459472d3a
--- /dev/null
+++ b/tests/providers/common/sql/operators/test_sql_execute.py
@@ -0,0 +1,276 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, NamedTuple, Sequence
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+
+DATE = "2017-04-20"
+TASK_ID = "sql-operator"
+
+
+class Row(NamedTuple):
+ id: str
+ value: str
+
+
+class Row2(NamedTuple):
+ id2: str
+ value2: str
+
+
[email protected](
+    "sql, return_last, split_statement, hook_results, hook_descriptions, expected_results",
+ [
+ pytest.param(
+ "select * from dummy",
+ True,
+ True,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ id="Scalar: Single SQL statement, return_last, split statement",
+ ),
+ pytest.param(
+ "select * from dummy;select * from dummy2",
+ True,
+ True,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ id="Scalar: Multiple SQL statements, return_last, split statement",
+ ),
+ pytest.param(
+ "select * from dummy",
+ False,
+ False,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
+ ),
+ pytest.param(
+ "select * from dummy",
+ True,
+ False,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy"],
+ False,
+ False,
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ [[("id",), ("value",)]],
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy", "select * from dummy2"],
+ False,
+ False,
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+ ],
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+ ],
+            id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy", "select * from dummy2"],
+ True,
+ False,
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+ ],
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+ ],
+            id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
+ ),
+ ],
+)
+def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions, expected_results):
+ """
+ Test the execute function in case where SQL query was successful.
+ """
+
+ class SQLExecuteQueryOperatorForTest(SQLExecuteQueryOperator):
+ _mock_db_api_hook = MagicMock()
+
+ def get_db_hook(self):
+ return self._mock_db_api_hook
+
+ op = SQLExecuteQueryOperatorForTest(
+ task_id=TASK_ID,
+ sql=sql,
+ do_xcom_push=True,
+ return_last=return_last,
+ split_statements=split_statement,
+ )
+
+ op._mock_db_api_hook.run.return_value = hook_results
+ op._mock_db_api_hook.descriptions = hook_descriptions
+
+ execute_results = op.execute(None)
+
+ assert execute_results == expected_results
+ op._mock_db_api_hook.run.assert_called_once_with(
+ sql=sql,
+ parameters=None,
+ handler=fetch_all_handler,
+ autocommit=False,
+ return_last=return_last,
+ split_statements=split_statement,
+ )
+
+
[email protected](
+    "sql, return_last, split_statement, hook_results, hook_descriptions, expected_results",
+ [
+ pytest.param(
+ "select * from dummy",
+ True,
+ True,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+ id="Scalar: Single SQL statement, return_last, split statement",
+ ),
+ pytest.param(
+ "select * from dummy;select * from dummy2",
+ True,
+ True,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+ id="Scalar: Multiple SQL statements, return_last, split statement",
+ ),
+ pytest.param(
+ "select * from dummy",
+ False,
+ False,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
+ ),
+ pytest.param(
+ "select * from dummy",
+ True,
+ False,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy"],
+ False,
+ False,
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ [[("id",), ("value",)]],
+            [([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")])],
+            id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy", "select * from dummy2"],
+ False,
+ False,
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+ ],
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [
+                ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+                ([("id2",), ("value2",)], [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")]),
+ ],
+            id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy", "select * from dummy2"],
+ True,
+ False,
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+ ],
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [
+                ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+                ([("id2",), ("value2",)], [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")]),
+ ],
+            id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
+ ),
+ ],
+)
+def test_exec_success_with_process_output(
+    sql, return_last, split_statement, hook_results, hook_descriptions, expected_results
+):
+ """
+ Test the execute function in case where SQL query was successful.
+ """
+
+    class SQLExecuteQueryOperatorForTestWithProcessOutput(SQLExecuteQueryOperator):
+ _mock_db_api_hook = MagicMock()
+
+ def get_db_hook(self):
+ return self._mock_db_api_hook
+
+ def _process_output(
+            self, results: list[Any], descriptions: list[Sequence[Sequence] | None]
+ ) -> list[Any]:
+ return list(zip(descriptions, results))
+
+ op = SQLExecuteQueryOperatorForTestWithProcessOutput(
+ task_id=TASK_ID,
+ sql=sql,
+ do_xcom_push=True,
+ return_last=return_last,
+ split_statements=split_statement,
+ )
+
+ op._mock_db_api_hook.run.return_value = hook_results
+ op._mock_db_api_hook.descriptions = hook_descriptions
+
+ execute_results = op.execute(None)
+
+ assert execute_results == expected_results
+ op._mock_db_api_hook.run.assert_called_once_with(
+ sql=sql,
+ parameters=None,
+ handler=fetch_all_handler,
+ autocommit=False,
+ return_last=return_last,
+ split_statements=split_statement,
+ )
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py
index 673ac0bddd..ecc0385278 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -19,6 +19,7 @@
from __future__ import annotations
from unittest import mock
+from unittest.mock import patch
import pytest
@@ -34,25 +35,156 @@ HOST_WITH_SCHEME = "https://xx.cloud.databricks.com"
TOKEN = "token"
-class TestDatabricksSqlHookQueryByName:
- """
- Tests for DatabricksHook.
- """
-
- @provide_session
- def setup_method(self, method, session=None):
-        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
- conn.host = HOST
- conn.login = None
- conn.password = TOKEN
- conn.extra = None
- session.commit()
-
- self.hook = DatabricksSqlHook(sql_endpoint_name="Test")
-
-    @mock.patch("airflow.providers.databricks.hooks.databricks_sql.DatabricksSqlHook.get_conn")
- @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
- def test_query(self, mock_requests, mock_conn):
+@provide_session
[email protected](autouse=True)
+def create_connection(session):
+    conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+ conn.host = HOST
+ conn.login = None
+ conn.password = TOKEN
+ conn.extra = None
+ session.commit()
+
+
[email protected]
+def databricks_hook():
+ return DatabricksSqlHook(sql_endpoint_name="Test")
+
+
+def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
+ return [(field,) for field in fields]
+
+
[email protected](
+ "return_last, split_statements, sql, cursor_calls,"
+ "cursor_descriptions, cursor_results, hook_descriptions, hook_results, ",
+ [
+ pytest.param(
+ True,
+ False,
+ "select * from test.test",
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[1, 2], [11, 12]],
+            id="The return_last set and no split statements set on single query in string",
+ ),
+ pytest.param(
+ False,
+ False,
+ "select * from test.test;",
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[1, 2], [11, 12]],
+            id="The return_last not set and no split statements set on single query in string",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from test.test;",
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[1, 2], [11, 12]],
+            id="The return_last set and split statements set on single query in string",
+ ),
+ pytest.param(
+ False,
+ True,
+ "select * from test.test;",
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[[1, 2], [11, 12]]],
+            id="The return_last not set and split statements set on single query in string",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test", "select * from test.test2"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id2",), ("value2",)]],
+ [[3, 4], [13, 14]],
+            id="The return_last set and split statements set on multiple queries in string",
+ ),
+ pytest.param(
+ False,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test", "select * from test.test2"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+            id="The return_last not set and split statements set on multiple queries in string",
+ ),
+ pytest.param(
+ True,
+ True,
+ ["select * from test.test;"],
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[[1, 2], [11, 12]]],
+ id="The return_last set on single query in list",
+ ),
+ pytest.param(
+ False,
+ True,
+ ["select * from test.test;"],
+ ["select * from test.test"],
+ [["id", "value"]],
+ ([[1, 2], [11, 12]],),
+ [[("id",), ("value",)]],
+ [[[1, 2], [11, 12]]],
+ id="The return_last not set on single query in list",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test", "select * from test.test2"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id2",), ("value2",)]],
+ [[3, 4], [13, 14]],
+            id="The return_last set on multiple queries in list",
+ ),
+ pytest.param(
+ False,
+ True,
+ "select * from test.test;select * from test.test2;",
+ ["select * from test.test", "select * from test.test2"],
+ [["id", "value"], ["id2", "value2"]],
+ ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+ id="The return_last not set on multiple queries not set",
+ ),
+ ],
+)
+def test_query(
+ databricks_hook,
+ return_last,
+ split_statements,
+ sql,
+ cursor_calls,
+ cursor_descriptions,
+ cursor_results,
+ hook_descriptions,
+ hook_results,
+):
+ with patch(
+        "airflow.providers.databricks.hooks.databricks_sql.DatabricksSqlHook.get_conn"
+    ) as mock_conn, patch("airflow.providers.databricks.hooks.databricks_base.requests") as mock_requests:
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = {
"endpoints": [
@@ -68,26 +200,41 @@ class TestDatabricksSqlHookQueryByName:
}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.get.return_value).status_code = status_code_mock
-
- test_fields = ["id", "value"]
- test_description = [(field,) for field in test_fields]
-
- conn = mock_conn.return_value
- cur = mock.MagicMock(rowcount=0, description=test_description)
- cur.fetchall.return_value = []
- conn.cursor.return_value = cur
-
- query = "select * from test.test;"
- results = self.hook.run(sql=query, handler=fetch_all_handler)
-
- assert self.hook.last_description == test_description
- assert results == []
-
- cur.execute.assert_has_calls([mock.call(q) for q in [query]])
+ connections = []
+ cursors = []
+ for index in range(len(cursor_descriptions)):
+ conn = mock.MagicMock()
+ cur = mock.MagicMock(
+ rowcount=len(cursor_results[index]),
+                description=get_cursor_descriptions(cursor_descriptions[index]),
+ )
+ cur.fetchall.return_value = cursor_results[index]
+ conn.cursor.return_value = cur
+ cursors.append(cur)
+ connections.append(conn)
+ mock_conn.side_effect = connections
+ results = databricks_hook.run(
+            sql=sql, handler=fetch_all_handler, return_last=return_last, split_statements=split_statements
+ )
+
+ assert databricks_hook.descriptions == hook_descriptions
+ assert databricks_hook.last_description == hook_descriptions[-1]
+ assert results == hook_results
+
+ for index, cur in enumerate(cursors):
+ cur.execute.assert_has_calls([mock.call(cursor_calls[index])])
cur.close.assert_called()
- def test_no_query(self):
- for empty_statement in ([], "", "\n"):
- with pytest.raises(ValueError) as err:
- self.hook.run(sql=empty_statement)
- assert err.value.args[0] == "List of SQL statements is empty"
+
[email protected](
+ "empty_statement",
+ [
+ pytest.param([], id="Empty list"),
+ pytest.param("", id="Empty string"),
+ pytest.param("\n", id="Only EOL"),
+ ],
+)
+def test_no_query(databricks_hook, empty_statement):
+ with pytest.raises(ValueError) as err:
+ databricks_hook.run(sql=empty_statement)
+ assert err.value.args[0] == "List of SQL statements is empty"
diff --git a/tests/providers/databricks/operators/test_databricks_copy.py b/tests/providers/databricks/operators/test_databricks_copy.py
new file mode 100644
index 0000000000..510e0a6c2f
--- /dev/null
+++ b/tests/providers/databricks/operators/test_databricks_copy.py
@@ -0,0 +1,230 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow import AirflowException
+from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
+
+DATE = "2017-04-20"
+TASK_ID = "databricks-sql-operator"
+DEFAULT_CONN_ID = "databricks_default"
+COPY_FILE_LOCATION = "s3://my-bucket/jsonData"
+
+
+def test_copy_with_files():
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name="test",
+ files=["file1", "file2", "file3"],
+ format_options={"dateFormat": "yyyy-MM-dd"},
+ task_id=TASK_ID,
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test
+FROM '{COPY_FILE_LOCATION}'
+FILEFORMAT = JSON
+FILES = ('file1','file2','file3')
+FORMAT_OPTIONS ('dateFormat' = 'yyyy-MM-dd')
+""".strip()
+ )
+
+
+def test_copy_with_expression():
+ expression = "col1, col2"
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="CSV",
+ table_name="test",
+ task_id=TASK_ID,
+ pattern="folder1/file_[a-g].csv",
+ expression_list=expression,
+ format_options={"header": "true"},
+ force_copy=True,
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test
+FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}')
+FILEFORMAT = CSV
+PATTERN = 'folder1/file_[a-g].csv'
+FORMAT_OPTIONS ('header' = 'true')
+COPY_OPTIONS ('force' = 'true')
+""".strip()
+ )
+
+
+def test_copy_with_credential():
+ expression = "col1, col2"
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="CSV",
+ table_name="test",
+ task_id=TASK_ID,
+ expression_list=expression,
+ credential={"AZURE_SAS_TOKEN": "abc"},
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test
+FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ))
+FILEFORMAT = CSV
+""".strip()
+ )
+
+
+def test_copy_with_target_credential():
+ expression = "col1, col2"
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="CSV",
+ table_name="test",
+ task_id=TASK_ID,
+ expression_list=expression,
+ storage_credential="abc",
+ credential={"AZURE_SAS_TOKEN": "abc"},
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test WITH (CREDENTIAL abc)
+FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ))
+FILEFORMAT = CSV
+""".strip()
+ )
+
+
+def test_copy_with_encryption():
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="CSV",
+ table_name="test",
+ task_id=TASK_ID,
+ encryption={"TYPE": "AWS_SSE_C", "MASTER_KEY": "abc"},
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test
+FROM '{COPY_FILE_LOCATION}' WITH ( ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
+FILEFORMAT = CSV
+""".strip()
+ )
+
+
+def test_copy_with_encryption_and_credential():
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="CSV",
+ table_name="test",
+ task_id=TASK_ID,
+ encryption={"TYPE": "AWS_SSE_C", "MASTER_KEY": "abc"},
+ credential={"AZURE_SAS_TOKEN": "abc"},
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test
+FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') """
+ """ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
+FILEFORMAT = CSV
+""".strip()
+ )
+
+
+def test_copy_with_validate_all():
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name="test",
+ task_id=TASK_ID,
+ validate=True,
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test
+FROM '{COPY_FILE_LOCATION}'
+FILEFORMAT = JSON
+VALIDATE ALL
+""".strip()
+ )
+
+
+def test_copy_with_validate_N_rows():
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name="test",
+ task_id=TASK_ID,
+ validate=10,
+ )
+ assert (
+ op._create_sql_query()
+ == f"""COPY INTO test
+FROM '{COPY_FILE_LOCATION}'
+FILEFORMAT = JSON
+VALIDATE 10 ROWS
+""".strip()
+ )
+
+
+def test_incorrect_params_files_patterns():
+ exception_message = "Only one of 'pattern' or 'files' should be specified"
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksCopyIntoOperator(
+ task_id=TASK_ID,
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name="test",
+ files=["file1", "file2", "file3"],
+ pattern="abc",
+ )
+
+
+def test_incorrect_params_empty_table():
+ exception_message = "table_name shouldn't be empty"
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksCopyIntoOperator(
+ task_id=TASK_ID,
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name="",
+ )
+
+
+def test_incorrect_params_empty_location():
+ exception_message = "file_location shouldn't be empty"
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksCopyIntoOperator(
+ task_id=TASK_ID,
+ file_location="",
+ file_format="JSON",
+ table_name="abc",
+ )
+
+
+def test_incorrect_params_wrong_format():
+ file_format = "JSONL"
+ exception_message = f"file_format '{file_format}' isn't supported"
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksCopyIntoOperator(
+ task_id=TASK_ID,
+ file_location=COPY_FILE_LOCATION,
+ file_format=file_format,
+ table_name="abc",
+ )
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 2663deeec2..8489f45095 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -19,42 +19,118 @@ from __future__ import annotations
import os
import tempfile
-from unittest import mock
+from unittest.mock import patch
import pytest
from databricks.sql.types import Row
-from airflow import AirflowException
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
-from airflow.providers.databricks.operators.databricks_sql import (
- DatabricksCopyIntoOperator,
- DatabricksSqlOperator,
-)
+from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator
DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
DEFAULT_CONN_ID = "databricks_default"
-COPY_FILE_LOCATION = "s3://my-bucket/jsonData"
-class TestDatabricksSqlOperator:
-    @mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
- def test_exec_success(self, db_mock_class):
- """
- Test the execute function in case where SQL query was successful.
- """
- sql = "select * from dummy"
- op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql, do_xcom_push=True)
[email protected](
+    "sql, return_last, split_statement, hook_results, hook_descriptions, expected_results",
+ [
+ pytest.param(
+ "select * from dummy",
+ True,
+ True,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+ id="Scalar: Single SQL statement, return_last, split statement",
+ ),
+ pytest.param(
+ "select * from dummy;select * from dummy2",
+ True,
+ True,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+ id="Scalar: Multiple SQL statements, return_last, split statement",
+ ),
+ pytest.param(
+ "select * from dummy",
+ False,
+ False,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
+ ),
+ pytest.param(
+ "select * from dummy",
+ True,
+ False,
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [[("id",), ("value",)]],
+            ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2, value="value2")]),
+            id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy"],
+ False,
+ False,
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ [[("id",), ("value",)]],
+ [([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2,
value="value2")])],
+ id="Non-Scalar: Single SQL statements in list, no return_last, no
split statement",
+ ),
+ pytest.param(
+ ["select * from dummy", "select * from dummy2"],
+ False,
+ False,
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
+ ],
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [
+ ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2,
value="value2")]),
+ ([("id2",), ("value2",)], [Row(id2=1, value2="value1"),
Row(id2=2, value2="value2")]),
+ ],
+ id="Non-Scalar: Multiple SQL statements in list, no return_last
(no matter), no split statement",
+ ),
+ pytest.param(
+ ["select * from dummy", "select * from dummy2"],
+ True,
+ False,
+ [
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
+ ],
+ [[("id",), ("value",)], [("id2",), ("value2",)]],
+ [
+ ([("id",), ("value",)], [Row(id=1, value="value1"), Row(id=2,
value="value2")]),
+ ([("id2",), ("value2",)], [Row(id2=1, value2="value1"),
Row(id2=2, value2="value2")]),
+ ],
+ id="Non-Scalar: Multiple SQL statements in list, return_last (no
matter), no split statement",
+ ),
+ ],
+)
+def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions, expected_results):
+ """
+ Test the execute function in the case where the SQL query was successful.
+ """
+ with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
+ op = DatabricksSqlOperator(
+ task_id=TASK_ID,
+ sql=sql,
+ do_xcom_push=True,
+ return_last=return_last,
+ split_statements=split_statement,
+ )
db_mock = db_mock_class.return_value
- mock_description = [("id",), ("value",)]
- mock_results = [Row(id=1, value="value1")]
- db_mock.run.return_value = mock_results
- db_mock.last_description = mock_description
- db_mock.scalar_return_last = False
+ db_mock.run.return_value = hook_results
+ db_mock.descriptions = hook_descriptions
execute_results = op.execute(None)
- assert execute_results == (mock_description, mock_results)
+ assert execute_results == expected_results
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
http_path=None,
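The "Scalar" / "Non-Scalar" ids above all encode one decision rule: the results of just one query are returned when sql is a plain string and either return_last is set or the string is not split into statements; a list of statements always yields a list of results. A sketch of that rule, written only to match the parametrized cases (the canonical helper is return_single_query_results in common.sql's hooks/sql.py):

    from __future__ import annotations

    from typing import Iterable

    def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool) -> bool:
        # Single ("scalar") results: sql is one string and we either want only
        # the last result or never split the string into multiple statements.
        return isinstance(sql, str) and (return_last or not split_statements)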
@@ -70,32 +146,108 @@ class TestDatabricksSqlOperator:
parameters=None,
handler=fetch_all_handler,
autocommit=False,
- return_last=True,
- split_statements=False,
+ return_last=return_last,
+ split_statements=split_statement,
)
-
@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
- def test_exec_write_file(self, db_mock_class):
- """
- Test the execute function in case where SQL query was successful and
data is written as CSV
- """
- sql = "select * from dummy"
+
+@pytest.mark.parametrize(
+ "return_last, split_statements, sql, descriptions, hook_results",
+ [
+ pytest.param(
+ True,
+ False,
+ "select * from dummy",
+ [[("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ id="Scalar: return_last True and split_statement False",
+ ),
+ pytest.param(
+ False,
+ True,
+ "select * from dummy",
+ [[("id",), ("value",)]],
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ id="Non-Scalar: return_last False and split_statement True",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from dummy",
+ [[("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ id="Scalar: return_last True and no split_statement True",
+ ),
+ pytest.param(
+ False,
+ False,
+ "select * from dummy",
+ [[("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ id="Scalar: return_last False and split_statement is False",
+ ),
+ pytest.param(
+ False,
+ True,
+ "select * from dummy2; select * from dummy",
+ [[("id2",), ("value2",)], [("id",), ("value",)]],
+ [
+ [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ ],
+ id="Non-Scalar: return_last False and split_statement is True",
+ ),
+ pytest.param(
+ True,
+ True,
+ "select * from dummy2; select * from dummy",
+ [[("id2",), ("value2",)], [("id",), ("value",)]],
+ [Row(id=1, value="value1"), Row(id=2, value="value2")],
+ id="Scalar: return_last True and split_statement is True",
+ ),
+ pytest.param(
+ True,
+ True,
+ ["select * from dummy2", "select * from dummy"],
+ [[("id2",), ("value2",)], [("id",), ("value",)]],
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ id="Non-Scalar: sql is list and return_last is True",
+ ),
+ pytest.param(
+ False,
+ True,
+ ["select * from dummy2", "select * from dummy"],
+ [[("id2",), ("value2",)], [("id",), ("value",)]],
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ id="Non-Scalar: sql is list and return_last is False",
+ ),
+ ],
+)
+def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results):
+ """
+ Test the execute function in the case where the SQL query was successful and the data is written as CSV.
+ """
+ with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
tempfile_path = tempfile.mkstemp()[1]
- op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql,
output_path=tempfile_path)
+ op = DatabricksSqlOperator(
+ task_id=TASK_ID,
+ sql=sql,
+ output_path=tempfile_path,
+ return_last=return_last,
+ split_statements=split_statements,
+ )
db_mock = db_mock_class.return_value
- mock_description = [("id",), ("value",)]
- mock_results = [Row(id=1, value="value1")]
+ mock_results = hook_results
db_mock.run.return_value = mock_results
- db_mock.last_description = mock_description
- db_mock.scalar_return_last = False
+ db_mock.descriptions = descriptions
try:
op.execute(None)
results = [line.strip() for line in open(tempfile_path)]
finally:
os.remove(tempfile_path)
-
- assert results == ["id,value", "1,value1"]
+ # In all cases, only the result of the last query is written to the file
+ assert results == ["id,value", "1,value1", "2,value2"]
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
http_path=None,
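Note how the expected file contents are built: the header row comes from the field names in the last cursor description, and the data rows from the last query's results. Illustratively (this is not the operator's literal code, and it assumes databricks.sql.types.Row supports asDict()):

    import csv

    def write_last_results_as_csv(last_description, last_results, output_path):
        # Header from the description, e.g. [("id",), ("value",)] -> "id,value".
        field_names = [field[0] for field in last_description]
        with open(output_path, "w", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=field_names)
            writer.writeheader()
            for row in last_results:  # one Row per line, e.g. Row(id=1, value="value1")
                writer.writerow(row.asDict())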
@@ -111,198 +271,6 @@ class TestDatabricksSqlOperator:
parameters=None,
handler=fetch_all_handler,
autocommit=False,
- return_last=True,
- split_statements=False,
- )
-
-
-class TestDatabricksSqlCopyIntoOperator:
- def test_copy_with_files(self):
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="JSON",
- table_name="test",
- files=["file1", "file2", "file3"],
- format_options={"dateFormat": "yyyy-MM-dd"},
- task_id=TASK_ID,
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test
-FROM '{COPY_FILE_LOCATION}'
-FILEFORMAT = JSON
-FILES = ('file1','file2','file3')
-FORMAT_OPTIONS ('dateFormat' = 'yyyy-MM-dd')
-""".strip()
- )
-
- def test_copy_with_expression(self):
- expression = "col1, col2"
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="CSV",
- table_name="test",
- task_id=TASK_ID,
- pattern="folder1/file_[a-g].csv",
- expression_list=expression,
- format_options={"header": "true"},
- force_copy=True,
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test
-FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}')
-FILEFORMAT = CSV
-PATTERN = 'folder1/file_[a-g].csv'
-FORMAT_OPTIONS ('header' = 'true')
-COPY_OPTIONS ('force' = 'true')
-""".strip()
- )
-
- def test_copy_with_credential(self):
- expression = "col1, col2"
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="CSV",
- table_name="test",
- task_id=TASK_ID,
- expression_list=expression,
- credential={"AZURE_SAS_TOKEN": "abc"},
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test
-FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL
(AZURE_SAS_TOKEN = 'abc') ))
-FILEFORMAT = CSV
-""".strip()
+ return_last=return_last,
+ split_statements=split_statements,
)
-
- def test_copy_with_target_credential(self):
- expression = "col1, col2"
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="CSV",
- table_name="test",
- task_id=TASK_ID,
- expression_list=expression,
- storage_credential="abc",
- credential={"AZURE_SAS_TOKEN": "abc"},
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test WITH (CREDENTIAL abc)
-FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL
(AZURE_SAS_TOKEN = 'abc') ))
-FILEFORMAT = CSV
-""".strip()
- )
-
- def test_copy_with_encryption(self):
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="CSV",
- table_name="test",
- task_id=TASK_ID,
- encryption={"TYPE": "AWS_SSE_C", "MASTER_KEY": "abc"},
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test
-FROM '{COPY_FILE_LOCATION}' WITH ( ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY
= 'abc'))
-FILEFORMAT = CSV
-""".strip()
- )
-
- def test_copy_with_encryption_and_credential(self):
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="CSV",
- table_name="test",
- task_id=TASK_ID,
- encryption={"TYPE": "AWS_SSE_C", "MASTER_KEY": "abc"},
- credential={"AZURE_SAS_TOKEN": "abc"},
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test
-FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') """
- """ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
-FILEFORMAT = CSV
-""".strip()
- )
-
- def test_copy_with_validate_all(self):
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="JSON",
- table_name="test",
- task_id=TASK_ID,
- validate=True,
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test
-FROM '{COPY_FILE_LOCATION}'
-FILEFORMAT = JSON
-VALIDATE ALL
-""".strip()
- )
-
- def test_copy_with_validate_N_rows(self):
- op = DatabricksCopyIntoOperator(
- file_location=COPY_FILE_LOCATION,
- file_format="JSON",
- table_name="test",
- task_id=TASK_ID,
- validate=10,
- )
- assert (
- op._create_sql_query()
- == f"""COPY INTO test
-FROM '{COPY_FILE_LOCATION}'
-FILEFORMAT = JSON
-VALIDATE 10 ROWS
-""".strip()
- )
-
- def test_incorrect_params_files_patterns(self):
- exception_message = "Only one of 'pattern' or 'files' should be
specified"
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksCopyIntoOperator(
- task_id=TASK_ID,
- file_location=COPY_FILE_LOCATION,
- file_format="JSON",
- table_name="test",
- files=["file1", "file2", "file3"],
- pattern="abc",
- )
-
- def test_incorrect_params_emtpy_table(self):
- exception_message = "table_name shouldn't be empty"
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksCopyIntoOperator(
- task_id=TASK_ID,
- file_location=COPY_FILE_LOCATION,
- file_format="JSON",
- table_name="",
- )
-
- def test_incorrect_params_emtpy_location(self):
- exception_message = "file_location shouldn't be empty"
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksCopyIntoOperator(
- task_id=TASK_ID,
- file_location="",
- file_format="JSON",
- table_name="abc",
- )
-
- def test_incorrect_params_wrong_format(self):
- file_format = "JSONL"
- exception_message = f"file_format '{file_format}' isn't supported"
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksCopyIntoOperator(
- task_id=TASK_ID,
- file_location=COPY_FILE_LOCATION,
- file_format=file_format,
- table_name="abc",
- )
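The conversion above also swaps the @mock.patch decorator, which suited the old class-based tests, for patch(...) used as a context manager, which composes more naturally with module-level @pytest.mark.parametrize functions. The bare pattern, with "package.module.SomeHook" as a placeholder target:

    from unittest.mock import patch

    def test_pattern():
        # Placeholder target; replace with a real import path before running.
        with patch("package.module.SomeHook") as hook_class_mock:
            hook = hook_class_mock.return_value  # the mocked hook instance
            hook.run.return_value = []           # stub whatever the code under test calls
            # ... run the code under test, then assert on hook.run / hook_class_mock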
diff --git a/tests/providers/jdbc/operators/test_jdbc.py
b/tests/providers/jdbc/operators/test_jdbc.py
index e027bdb96d..1b149bbb1f 100644
--- a/tests/providers/jdbc/operators/test_jdbc.py
+++ b/tests/providers/jdbc/operators/test_jdbc.py
@@ -50,5 +50,7 @@ class TestJdbcOperator:
sql=jdbc_operator.sql,
autocommit=jdbc_operator.autocommit,
parameters=jdbc_operator.parameters,
+ handler=None,
+ return_last=True,
split_statements=False,
)
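The two keyword arguments added to this assertion, handler=None and return_last=True, track the extended run() call that DBApiHook-based operators now make. A sketch of that signature as these tests exercise it (the defaults shown are assumptions inferred from the assertions in this commit):

    def run(
        self,
        sql,                   # a single statement or a list of statements
        autocommit: bool = False,
        parameters=None,
        handler=None,          # e.g. fetch_all_handler; None means nothing is fetched
        split_statements: bool = False,
        return_last: bool = True,
    ):
        ...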