This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f7f3b675ec Make sure that only valid elasticsearch keys are passed to 
handler (#34119)
f7f3b675ec is described below

commit f7f3b675ecd40e32e458b71b5066864f866a60c8
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Sep 7 02:43:06 2023 +0200

    Make sure that only valid elasticsearch keys are passed to handler (#34119)
    
    The elasticsearch handler got all configuration parameters
    from the "elasticsearch_configs" section, which means that in
    airflow versions pre 2.7 it could get old config keys, which
    renders the new provider unusable.
    
    This PR filters the configuration parameters so that only
    valid parameters are passed to the new handler.
    
    Fixes: #34099
---
 .../providers/elasticsearch/log/es_task_handler.py |  37 ++++++--
 .../elasticsearch/log/test_es_task_handler.py      | 105 ++++++++++++---------
 2 files changed, 92 insertions(+), 50 deletions(-)

diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py 
b/airflow/providers/elasticsearch/log/es_task_handler.py
index d0d723c7dc..944fe88cf9 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import inspect
 import logging
 import sys
 import warnings
@@ -30,6 +31,7 @@ from urllib.parse import quote, urlparse
 import elasticsearch
 import pendulum
 from elasticsearch.exceptions import NotFoundError
+from typing_extensions import Literal
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -56,6 +58,32 @@ EsLogMsgType = List[Tuple[str, str]]
 USE_PER_RUN_LOG_ID = hasattr(DagRun, "get_log_template")
 
 
+VALID_ES_CONFIG_KEYS = 
set(inspect.signature(elasticsearch.Elasticsearch.__init__).parameters.keys())
+# Remove `self` from the valid set of kwargs
+VALID_ES_CONFIG_KEYS.remove("self")
+
+
+def get_es_kwargs_from_config() -> dict[str, Any]:
+    elastic_search_config = conf.getsection("elasticsearch_configs")
+    kwargs_dict = (
+        {key: value for key, value in elastic_search_config.items() if key in 
VALID_ES_CONFIG_KEYS}
+        if elastic_search_config
+        else {}
+    )
+    # For elasticsearch>8 retry_timeout have changed for elasticsearch to 
retry_on_timeout
+    # in Elasticsearch() compared to previous versions.
+    # Read more at: 
https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#module-elasticsearch
+    if (
+        elastic_search_config
+        and "retry_timeout" in elastic_search_config
+        and not kwargs_dict.get("retry_on_timeout")
+    ):
+        retry_timeout = elastic_search_config.get("retry_timeout")
+        if retry_timeout is not None:
+            kwargs_dict["retry_on_timeout"] = retry_timeout
+    return kwargs_dict
+
+
 class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, 
LoggingMixin):
     """
     ElasticsearchTaskHandler is a python log handler that reads logs from 
Elasticsearch.
@@ -95,17 +123,14 @@ class ElasticsearchTaskHandler(FileTaskHandler, 
ExternalLoggingMixin, LoggingMix
         host: str = "http://localhost:9200";,
         frontend: str = "localhost:5601",
         index_patterns: str | None = conf.get("elasticsearch", 
"index_patterns", fallback="_all"),
-        es_kwargs: dict | None = conf.getsection("elasticsearch_configs"),
+        es_kwargs: dict | None | Literal["default_es_kwargs"] = 
"default_es_kwargs",
         *,
         filename_template: str | None = None,
         log_id_template: str | None = None,
     ):
         es_kwargs = es_kwargs or {}
-        # For elasticsearch>8,arguments like retry_timeout have changed for 
elasticsearch to retry_on_timeout
-        # in Elasticsearch() compared to previous versions.
-        # Read more at: 
https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#module-elasticsearch
-        if es_kwargs.get("retry_timeout"):
-            es_kwargs["retry_on_timeout"] = es_kwargs.pop("retry_timeout")
+        if es_kwargs == "default_es_kwargs":
+            es_kwargs = get_es_kwargs_from_config()
         host = self.format_url(host)
         super().__init__(base_log_folder, filename_template)
         self.closed = False
diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py 
b/tests/providers/elasticsearch/log/test_es_task_handler.py
index a8af384b7b..3d3bf4cc43 100644
--- a/tests/providers/elasticsearch/log/test_es_task_handler.py
+++ b/tests/providers/elasticsearch/log/test_es_task_handler.py
@@ -23,6 +23,7 @@ import logging
 import os
 import re
 import shutil
+from pathlib import Path
 from unittest import mock
 from urllib.parse import quote
 
@@ -32,15 +33,24 @@ import pytest
 
 from airflow.configuration import conf
 from airflow.providers.elasticsearch.log.es_response import 
ElasticSearchResponse
-from airflow.providers.elasticsearch.log.es_task_handler import 
ElasticsearchTaskHandler, getattr_nested
+from airflow.providers.elasticsearch.log.es_task_handler import (
+    VALID_ES_CONFIG_KEYS,
+    ElasticsearchTaskHandler,
+    get_es_kwargs_from_config,
+    getattr_nested,
+)
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.timezone import datetime
+from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_dags, clear_db_runs
 
 from .elasticmock import elasticmock
 from .elasticmock.utilities import SearchFailedException
 
+AIRFLOW_SOURCES_ROOT_DIR = Path(__file__).parents[4].resolve()
+ES_PROVIDER_YAML_FILE = AIRFLOW_SOURCES_ROOT_DIR / "airflow" / "providers" / 
"elasticsearch" / "provider.yaml"
+
 
 def get_ti(dag_id, task_id, execution_date, create_task_instance):
     ti = create_task_instance(
@@ -145,49 +155,6 @@ class TestElasticsearchTaskHandler:
         else:
             assert ElasticsearchTaskHandler.format_url(host) == expected
 
-    def test_elasticsearch_constructor_retry_timeout_handling(self):
-        """
-        Test if the ElasticsearchTaskHandler constructor properly handles the 
retry_timeout argument.
-        """
-        # Mock the Elasticsearch client
-        with mock.patch(
-            
"airflow.providers.elasticsearch.log.es_task_handler.elasticsearch.Elasticsearch"
-        ) as mock_es:
-            # Test when 'retry_timeout' is present in es_kwargs
-            es_kwargs = {"retry_timeout": 10}
-            ElasticsearchTaskHandler(
-                base_log_folder="dummy_folder",
-                end_of_log_mark="end_of_log_mark",
-                write_stdout=False,
-                json_format=False,
-                json_fields="fields",
-                host_field="host",
-                offset_field="offset",
-                es_kwargs=es_kwargs,
-            )
-
-            # Check the arguments with which the Elasticsearch client is 
instantiated
-            mock_es.assert_called_once_with("http://localhost:9200";, 
retry_on_timeout=10)
-
-            # Reset the mock for the next test
-            mock_es.reset_mock()
-
-            # Test when 'retry_timeout' is not present in es_kwargs
-            es_kwargs = {}
-            ElasticsearchTaskHandler(
-                base_log_folder="dummy_folder",
-                end_of_log_mark="end_of_log_mark",
-                write_stdout=False,
-                json_format=False,
-                json_fields="fields",
-                host_field="host",
-                offset_field="offset",
-                es_kwargs=es_kwargs,
-            )
-
-            # Check that the Elasticsearch client is instantiated without the 
'retry_on_timeout' argument
-            mock_es.assert_called_once_with("http://localhost:9200";)
-
     def test_client(self):
         assert isinstance(self.es_task_handler.client, 
elasticsearch.Elasticsearch)
         assert self.es_task_handler.index_patterns == "_all"
@@ -691,3 +658,53 @@ def test_safe_attrgetter():
     assert getattr_nested(a, "aa", "heya") == "heya"  # respects non-none 
default
     assert getattr_nested(a, "c", "heya") is None  # respects none value
     assert getattr_nested(a, "aa", None) is None  # respects none default
+
+
+def test_retrieve_config_keys():
+    """
+    Tests that the ElasticsearchTaskHandler retrieves the correct 
configuration keys from the config file.
+    * old_parameters are removed
+    * parameters from config are automatically added
+    * constructor parameters missing from config are also added
+    :return:
+    """
+    with conf_vars(
+        {
+            ("elasticsearch_configs", "use_ssl"): "True",
+            ("elasticsearch_configs", "http_compress"): "False",
+            ("elasticsearch_configs", "timeout"): "10",
+        }
+    ):
+        args_from_config = get_es_kwargs_from_config().keys()
+        # use_ssl is removed from config
+        assert "use_ssl" not in args_from_config
+        # verify_certs comes from default config value
+        assert "verify_certs" in args_from_config
+        # timeout comes from config provided value
+        assert "timeout" in args_from_config
+        # http_compress comes from config value
+        assert "http_compress" in args_from_config
+        assert "self" not in args_from_config
+
+
+def test_retrieve_retry_on_timeout():
+    """
+    Test if retrieve timeout is converted to retry_on_timeout.
+    """
+    with conf_vars(
+        {
+            ("elasticsearch_configs", "retry_timeout"): "True",
+        }
+    ):
+        args_from_config = get_es_kwargs_from_config().keys()
+        # use_ssl is removed from config
+        assert "retry_timeout" not in args_from_config
+        # verify_certs comes from default config value
+        assert "retry_on_timeout" in args_from_config
+
+
+def test_self_not_valid_arg():
+    """
+    Test if self is not a valid argument.
+    """
+    assert "self" not in VALID_ES_CONFIG_KEYS

Reply via email to