alexott commented on code in PR #27912:
URL: https://github.com/apache/airflow/pull/27912#discussion_r1032463169
##########
airflow/providers/common/sql/hooks/sql.py:
##########
@@ -30,6 +30,25 @@
from airflow.version import version
+def has_scalar_return_value(sql: str | Iterable[str], return_last: bool,
split_statements: bool):
+ """
+ Determines when scalar value should be returned.
+
+ Scalar value should be returned when:
Review Comment:
I was thinking about it — I think that the terminology is a bit confusing. When
we think about `scalar` we think about a single value, but executing a query
always returns an iterator of rows (IIRC).
So in the end we have the following possibilities:
* a list of query results (list of lists of rows) - when we have multiple
queries
* a single query result (list of rows) - when we have one query or when we
use `return_last`
Maybe rename it to something like `single_result`?
##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -225,54 +230,37 @@ def __init__(
self.split_statements = split_statements
self.return_last = return_last
- @overload
- def _process_output(
- self, results: Any, description: Sequence[Sequence] | None,
scalar_results: Literal[True]
- ) -> Any:
- pass
-
- @overload
- def _process_output(
- self, results: list[Any], description: Sequence[Sequence] | None,
scalar_results: Literal[False]
- ) -> Any:
- pass
-
- def _process_output(
- self, results: Any | list[Any], description: Sequence[Sequence] |
None, scalar_results: bool
- ) -> Any:
+ def _process_output(self, results: list[Any], descriptions:
list[Sequence[Sequence] | None]) -> list[Any]:
"""
- Can be overridden by the subclass in case some extra processing is
needed.
+ Processes output before it is returned by the operator.
+
+ It can be overridden by the subclass in case some extra processing is
needed.
The "process_output" method can override the returned output -
augmenting or processing the
output as needed - the output returned will be returned as execute
return value and if
- do_xcom_push is set to True, it will be set as XCom returned
+ do_xcom_push is set to True, it will be set as XCom returned.
:param results: results in the form of list of rows.
- :param description: as returned by ``cur.description`` in the Python
DBAPI
- :param scalar_results: True if result is single scalar value rather
than list of rows
+ :param descriptions: list of descriptions returned by
``cur.description`` in the Python DBAPI
"""
return results
def execute(self, context):
self.log.info("Executing: %s", self.sql)
hook = self.get_db_hook()
- if self.do_xcom_push:
- output = hook.run(
- sql=self.sql,
- autocommit=self.autocommit,
- parameters=self.parameters,
- handler=self.handler,
- split_statements=self.split_statements,
- return_last=self.return_last,
- )
- else:
- output = hook.run(
- sql=self.sql,
- autocommit=self.autocommit,
- parameters=self.parameters,
- split_statements=self.split_statements,
- )
-
- return self._process_output(output, hook.last_description,
hook.scalar_return_last)
+ output = hook.run(
+ sql=self.sql,
+ autocommit=self.autocommit,
+ parameters=self.parameters,
+ handler=self.handler if self.do_xcom_push else None,
+ split_statements=self.split_statements,
+ return_last=self.return_last,
+ )
+ if has_scalar_return_value(self.sql, self.return_last,
self.split_statements):
+ # For simplicity, we pass always list as input to _process_output,
regardless if
+ # scalar is going to be returned, and we return the first element
of the list in this case
+ # from the list returned by _process_output
+ return self._process_output([output], hook.descriptions)[0]
Review Comment:
Why do we return the first element? Usually we're interested in the last
result (the original behavior of the DBSQL operator).
##########
airflow/providers/databricks/hooks/databricks_sql.py:
##########
@@ -163,38 +163,43 @@ def run(
:param return_last: Whether to return result for only last statement
or for all after split
:return: return only result of the LAST SQL expression if handler was
provided.
"""
- self.scalar_return_last = isinstance(sql, str) and return_last
+ self.descriptions = []
if isinstance(sql, str):
if split_statements:
- sql = self.split_sql_string(sql)
+ sql_list = [self.strip_sql_string(s) for s in
self.split_sql_string(sql)]
else:
- sql = [self.strip_sql_string(sql)]
+ sql_list = [self.strip_sql_string(sql)]
+ else:
+ sql_list = [self.strip_sql_string(s) for s in sql]
- if sql:
- self.log.debug("Executing following statements against Databricks
DB: %s", list(sql))
+ if sql_list:
+ self.log.debug("Executing following statements against Databricks
DB: %s", sql_list)
else:
raise ValueError("List of SQL statements is empty")
results = []
- for sql_statement in sql:
+ for sql_statement in sql_list:
# when using AAD tokens, it could expire if previous query run
longer than token lifetime
with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)
with closing(conn.cursor()) as cur:
self._run_command(cur, sql_statement, parameters)
-
if handler is not None:
result = handler(cur)
- results.append(result)
- self.last_description = cur.description
+ if has_scalar_return_value(sql, return_last,
split_statements):
+ results = [result]
+ self.descriptions = [cur.description]
+ else:
+ results.append(result)
+ self.descriptions.append(cur.description)
self._sql_conn = None
if handler is None:
return None
- elif self.scalar_return_last:
- return results[-1]
+ if has_scalar_return_value(sql, return_last, split_statements):
+ return results[0]
Review Comment:
Same here — we were always returning the results of the last query.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]