This is an automated email from the ASF dual-hosted git repository.

csringhofer pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit d0127c2401fb5acccb33d3be87dabb79c976edc7
Author: Xuebin Su <[email protected]>
AuthorDate: Thu Dec 12 15:39:21 2024 +0800

    IMPALA-13566: Expose query cancellation status to UDFs
    
    Previously, the evaluation of a UDF was not interruptible. While the
    impalad was evaluating a UDF on a batch, the client had to wait until
    the whole batch had been processed before it could get the results or
    cancel the query. This wait could be long if the UDF took a long time
    to run, e.g. the sleep() UDF.
    
    This patch mitigates the issue by exposing the query cancellation
    status in RuntimeState to UDFs through
    FunctionContext::IsQueryCancelled(), so that a UDF can poll the status
    itself and return early. This can significantly reduce the waiting
    time when cancelling a long-running query.
    
    As an example, this patch makes the sleep() UDF interruptible: it
    sleeps in chunks of at most SLEEP_UNINTERRUPTIBLE_INTERVAL_MS (200ms)
    and checks the query cancellation status before each chunk. As a
    result, cancelling such a query only waits for about 200ms.
    
    Testing:
    - Added a new test case in tests/query_test/test_cancellation.py to
      ensure that we can interrupt sleep().
    - Added a new connection class called MinimalHS2Connection in
      tests/common/impala_connection.py to support manipulating one
      operation from multiple connections for the new test case.
    
    Change-Id: I9430167f7e46bbdf66153abb4645541cd8cf0142
    Reviewed-on: http://gerrit.cloudera.org:8080/22280
    Tested-by: Impala Public Jenkins <[email protected]>
    Reviewed-by: Csaba Ringhofer <[email protected]>
---
 be/src/exprs/utility-functions-ir.cc  |  21 +++++-
 be/src/udf/udf.cc                     |   9 +++
 be/src/udf/udf.h                      |   3 +
 tests/common/impala_connection.py     | 127 ++++++++++++++++++++++++++++++++++
 tests/query_test/test_cancellation.py |  27 ++++++++
 5 files changed, 185 insertions(+), 2 deletions(-)

diff --git a/be/src/exprs/utility-functions-ir.cc b/be/src/exprs/utility-functions-ir.cc
index 2d6ecfaec..470e84096 100644
--- a/be/src/exprs/utility-functions-ir.cc
+++ b/be/src/exprs/utility-functions-ir.cc
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <chrono>
+
 #include "exprs/utility-functions.h"
 #include <gutil/strings/substitute.h>
 #include "common/compiler-util.h"
@@ -33,6 +35,9 @@
 #include "common/names.h"
 
 using namespace strings;
+using namespace std;
+
+const chrono::milliseconds SLEEP_UNINTERRUPTIBLE_INTERVAL_MS{200};
 
 namespace impala {
 
@@ -164,9 +169,21 @@ StringVal UtilityFunctions::Uuid(FunctionContext* ctx) {
   return GenUuid(ctx);
 }
 
-BooleanVal UtilityFunctions::Sleep(FunctionContext* ctx, const IntVal& milliseconds ) {
+BooleanVal UtilityFunctions::Sleep(FunctionContext* ctx, const IntVal& milliseconds) {
   if (milliseconds.is_null) return BooleanVal::null();
-  SleepForMs(milliseconds.val);
+  for (auto remaining_sleep_duration = chrono::milliseconds(milliseconds.val);
+      remaining_sleep_duration > chrono::milliseconds(0);) {
+    if (ctx->IsQueryCancelled()) {
+      return BooleanVal(true);
+    }
+    auto start = chrono::steady_clock::now();
+    auto expected_sleep_duration =
+        min(remaining_sleep_duration, SLEEP_UNINTERRUPTIBLE_INTERVAL_MS);
+    SleepForMs(expected_sleep_duration.count());
+    auto actual_sleep_duration =
+        chrono::duration_cast<chrono::milliseconds>(chrono::steady_clock::now() - start);
+    remaining_sleep_duration -= actual_sleep_duration;
+  }
   return BooleanVal(true);
 }

diff --git a/be/src/udf/udf.cc b/be/src/udf/udf.cc
index 51f0ec933..b62bbb750 100644
--- a/be/src/udf/udf.cc
+++ b/be/src/udf/udf.cc
@@ -104,6 +104,11 @@ class RuntimeState {
     return false;
   }
 
+  bool is_cancelled() const {
+    assert(false);
+    return false;
+  }
+
   bool LogError(const std::string& error) {
     assert(false);
     return false;
@@ -323,6 +328,10 @@ FunctionContext::UniqueId FunctionContext::query_id() const {
   return id;
 }
 
+bool FunctionContext::IsQueryCancelled() const {
+  return impl_->state_->is_cancelled();
+}
+
 bool FunctionContext::has_error() const {
   return !impl_->error_msg_.empty();
 }
diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h
index c96f4731e..a8a867544 100644
--- a/be/src/udf/udf.h
+++ b/be/src/udf/udf.h
@@ -165,6 +165,9 @@ class FunctionContext {
   /// Returns the query_id for the current query.
   UniqueId query_id() const;
 
+  /// Returns true if the query has been cancelled. Long-running UDFs can poll this.
+  bool IsQueryCancelled() const;
+
   /// Sets an error for this UDF. The error message is copied and the copy is owned by
   /// this object.
   ///
diff --git a/tests/common/impala_connection.py b/tests/common/impala_connection.py
index b4285ec7d..caafe7d42 100644
--- a/tests/common/impala_connection.py
+++ b/tests/common/impala_connection.py
@@ -22,6 +22,7 @@
 from __future__ import absolute_import, division, print_function
 import abc
 from future.utils import with_metaclass
+import getpass
 import logging
 import re
 import time
@@ -29,6 +30,7 @@ import time
 from beeswaxd.BeeswaxService import QueryState
 import impala.dbapi as impyla
 import impala.error as impyla_error
+import impala.hiveserver2 as hs2
 import tests.common
 from RuntimeProfile.ttypes import TRuntimeProfileFormat
 from tests.beeswax.impala_beeswax import (
@@ -958,3 +960,128 @@ def create_connection(host_port, use_kerberos=False, protocol=BEESWAX,
 def create_ldap_connection(host_port, user, password, use_ssl=False):
   return BeeswaxConnection(host_port=host_port, user=user, password=password,
                            use_ssl=use_ssl)
+
+
+class MinimalHS2OperationHandle(OperationHandle):
+  def __str__(self):
+    return op_handle_to_query_id(self.get_handle())
+
+
+class MinimalHS2Connection(ImpalaConnection):
+  """
+  Connection to Impala using the HiveServer2 (HS2) protocol.
+
+  This class does not use Impyla's DB-API cursors. Instead, it is built directly on the
+  HS2 RPC layer to support manipulating one operation from multiple connections
+  concurrently.
+
+  This class is designed to be minimalistic to facilitate testing. Each method is mapped
+  to only one Thrift RPC.
+  """
+  def __init__(self, host_port, user=None):
+    self.__host_port = host_port
+    host, port = host_port.split(":")
+    self.__conn = hs2.connect(host, port, auth_mechanism='NOSASL')
+    self.__user = user if user is not None else getpass.getuser()
+    self.__session = self.__conn.open_session(self.__user)
+
+  def connect(self):
+    pass  # Do nothing
+
+  def close(self):
+    LOG.info("-- closing connection to: %s" % self.__host_port)
+    try:
+      self.__session.close()
+    finally:
+      self.__conn.close()
+
+  def execute(self, sql_stmt):  # noqa: U100
+    raise NotImplementedError()
+
+  def execute_async(self, sql_stmt):
+    hs2_operation = self.__session.execute(sql_stmt)
+    operation_handle = MinimalHS2OperationHandle(hs2_operation.handle, sql_stmt)
+    LOG.info("Started query {0}".format(operation_handle))
+    return operation_handle
+
+  def __get_operation(self, operation_handle):
+    return hs2.Operation(self.__session, operation_handle.get_handle())
+
+  def fetch(self, sql_stmt, operation_handle, max_rows=-1):  # noqa: U100
+    """
+    Fetch the results of the query. It will block the current connection if the results
+    are not available yet.
+    """
+    LOG.info("-- fetching results from: {0}".format(operation_handle))
+    return self.__get_operation(operation_handle).fetch(max_rows=max_rows)
+
+  def fetch_error(self, operation_handle):
+    """
+    Fetch the query's results and return the raised exception; fail if fetch succeeds.
+    """
+    try:
+      self.fetch(None, operation_handle)
+    except Exception as exc:
+      return exc
+    raise AssertionError("Failed to catch the error of the query.")
+
+  def get_state(self, operation_handle):
+    return self.__get_operation(operation_handle).get_status()
+
+  def wait_for(self, operation_handle, timeout_s=60):
+    """
+    Wait until the query is in a terminal state.
+    """
+    start_time = time.time()
+    while True:
+      operation_state = self.get_state(operation_handle)
+      if operation_state not in ("PENDING_STATE", "INITIALIZED_STATE", "RUNNING_STATE"):
+        return operation_state
+      if time.time() - start_time > timeout_s:
+        raise Exception("Timed out waiting for the query")
+      time.sleep(0.1)
+
+  def cancel(self, operation_handle):
+    LOG.info("-- canceling operation: {0}".format(operation_handle))
+    return self.__get_operation(operation_handle).cancel()
+
+  def close_query(self, operation_handle):
+    LOG.info("-- closing query for operation handle: {0}".format(operation_handle))
+    return self.__get_operation(operation_handle).close()
+
+  def state_is_finished(self, operation_handle):  # noqa: U100
+    raise NotImplementedError()
+
+  def get_log(self, operation_handle):
+    return self.__get_operation(operation_handle).get_log()
+
+  def set_configuration_option(self, name, value):  # noqa: U100
+    raise NotImplementedError()
+
+  def clear_configuration(self):
+    raise NotImplementedError()
+
+  def get_default_configuration(self):
+    raise NotImplementedError()
+
+  def get_host_port(self):
+    return self.__host_port
+
+  def get_test_protocol(self):
+    return HS2
+
+  def handle_id(self, operation_handle):  # noqa: U100
+    return str(operation_handle)
+
+  def get_admission_result(self, operation_handle):  # noqa: U100
+    raise NotImplementedError()
+
+  def get_impala_exec_state(self, operation_handle):  # noqa: U100
+    raise NotImplementedError()
+
+  def get_runtime_profile(self, operation_handle,  # noqa: U100
+                          profile_format=TRuntimeProfileFormat.STRING):  # noqa: U100
+    raise NotImplementedError()
+
+  def wait_for_admission_control(self, operation_handle, timeout_s=60):  # noqa: U100
+    raise NotImplementedError()
diff --git a/tests/query_test/test_cancellation.py b/tests/query_test/test_cancellation.py
index 6c6570d97..f6ce4c75b 100644
--- a/tests/query_test/test_cancellation.py
+++ b/tests/query_test/test_cancellation.py
@@ -21,13 +21,16 @@
 from __future__ import absolute_import, division, print_function
 from builtins import range
 import pytest
+from threading import Thread
 from time import sleep
 from RuntimeProfile.ttypes import TRuntimeProfileFormat
 from tests.common.test_dimensions import add_mandatory_exec_option
 from tests.common.test_vector import ImpalaTestDimension
 from tests.common.impala_test_suite import ImpalaTestSuite
+from tests.common.impala_test_suite import IMPALAD_HS2_HOST_PORT
 from tests.util.cancel_util import cancel_query_and_validate_state
 from tests.verifiers.metric_verifier import MetricVerifier
+from tests.common.impala_connection import MinimalHS2Connection
 
 # PRIMARY KEY for lineitem
 LINEITEM_PK = 'l_orderkey, l_partkey, l_suppkey, l_linenumber'
@@ -219,6 +222,30 @@ class TestCancellation(ImpalaTestSuite):
       assert k == 'Plan' or '\n\n' not in v, \
         "Profile contains repeating newlines: %s %s" % (k, v)
 
+  def test_interrupt_sleep(self):
+    query = "SELECT sleep(100000) FROM functional.alltypes;"
+    with MinimalHS2Connection(IMPALAD_HS2_HOST_PORT) as cancel_client:
+      query_handle = cancel_client.execute_async(query)
+      assert cancel_client.wait_for(query_handle) == "FINISHED_STATE"
+
+      class FetchInAnotherConnection:
+        def __call__(self):
+          with MinimalHS2Connection(IMPALAD_HS2_HOST_PORT) as fetch_client:
+            self.error = fetch_client.fetch_error(query_handle)
+
+      fetch_in_another_connection = FetchInAnotherConnection()
+      fetch_thread = Thread(target=fetch_in_another_connection)
+      fetch_thread.start()
+      sleep(1)  # Wait for another thread to start fetching.
+      cancel_client.cancel(query_handle)
+      # Timeout for join() needs to be longer than SLEEP_UNINTERRUPTIBLE_INTERVAL_MS.
+      fetch_thread.join(3)
+      assert not fetch_thread.is_alive(), "Failed to terminate the fetching thread."
+      assert "Cancelled" in str(fetch_in_another_connection.error)
+      assert "Invalid or unknown query handle" in str(
+          cancel_client.fetch_error(query_handle))
+
+
   def teardown_method(self, method):
     # For some reason it takes a little while for the query to get completely torn down
     # when the debug action is WAIT, causing TestValidateMetrics.test_metrics_are_zero to

Reply via email to