This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new 322ce0ccfa5 [v2-10-test] Improve speed of tests by not creating connections at parse time (#45690) (#45826)
322ce0ccfa5 is described below

commit 322ce0ccfa5f5cc4d671a9d470a589636a9c6b70
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Jan 21 17:45:51 2025 +0530

    [v2-10-test] Improve speed of tests by not creating connections at parse time (#45690) (#45826)
    
    The DAG serialization tests load all of the example and system test DAGs, and
    there were two places where these tests opened connections at parse time,
    resulting in a lot of extra test time.
    
    - The SystemTestContextBuilder was trying to fetch values from SSM at parse
      time. This was addressed by adding `functools.cache` to `_fetch_from_ssm`.
    - The Bedrock example DAG was creating the underlying AWS client (the hook's
      `.conn`) at parse time and storing it globally. This was addressed by making
      the Airflow hook the global instead, so `.conn` is only accessed inside the
      tasks. This fix is not _great_, but it does massively help.
    
    Before:
    
    > 111 passed, 1 warning in 439.37s (0:07:19)
    
    After:
    
    > 111 passed, 1 warning in 71.76s (0:01:11)
    (cherry picked from commit 102e853)
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
 .../amazon/aws/system/utils/test_helpers.py        | 11 ++++++--
 .../aws/example_bedrock_retrieve_and_generate.py   | 32 ++++++++++++----------
 .../system/providers/amazon/aws/utils/__init__.py  |  2 ++
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/tests/providers/amazon/aws/system/utils/test_helpers.py b/tests/providers/amazon/aws/system/utils/test_helpers.py
index f48de1788b7..3af3720688a 100644
--- a/tests/providers/amazon/aws/system/utils/test_helpers.py
+++ b/tests/providers/amazon/aws/system/utils/test_helpers.py
@@ -24,7 +24,7 @@ from __future__ import annotations
 import os
 import sys
 from io import StringIO
-from unittest.mock import ANY, patch
+from unittest.mock import patch
 
 import pytest
 from moto import mock_aws
@@ -79,8 +79,15 @@ class TestAmazonSystemTestHelpers:
     ) -> None:
         mock_getenv.return_value = env_value or ssm_value
 
-        result = utils.fetch_variable(ANY, default_value) if default_value else utils.fetch_variable(ANY_STR)
+        utils._fetch_from_ssm.cache_clear()
 
+        result = (
+            utils.fetch_variable("some_key", default_value)
+            if default_value
+            else utils.fetch_variable(ANY_STR)
+        )
+
+        utils._fetch_from_ssm.cache_clear()
         assert result == expected_result
 
     def test_fetch_variable_no_value_found_raises_exception(self):
diff --git a/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py b/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py
index fcebc8c40a0..2b7bce2fecd 100644
--- a/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py
+++ b/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py
@@ -127,7 +127,7 @@ def create_opensearch_policies(bedrock_role_arn: str, collection_name: str, poli
 
     def _create_security_policy(name, policy_type, policy):
         try:
-            aoss_client.create_security_policy(name=name, policy=json.dumps(policy), type=policy_type)
+            aoss_client.conn.create_security_policy(name=name, policy=json.dumps(policy), type=policy_type)
         except ClientError as e:
             if e.response["Error"]["Code"] == "ConflictException":
                 log.info("OpenSearch security policy %s already exists.", name)
@@ -135,7 +135,7 @@ def create_opensearch_policies(bedrock_role_arn: str, collection_name: str, poli
 
     def _create_access_policy(name, policy_type, policy):
         try:
-            aoss_client.create_access_policy(name=name, policy=json.dumps(policy), type=policy_type)
+            aoss_client.conn.create_access_policy(name=name, policy=json.dumps(policy), type=policy_type)
         except ClientError as e:
             if e.response["Error"]["Code"] == "ConflictException":
                 log.info("OpenSearch data access policy %s already exists.", name)
@@ -204,9 +204,9 @@ def create_collection(collection_name: str):
     :param collection_name: The name of the Collection to create.
     """
     log.info("\nCreating collection: %s.", collection_name)
-    return aoss_client.create_collection(name=collection_name, type="VECTORSEARCH")["createCollectionDetail"][
-        "id"
-    ]
+    return aoss_client.conn.create_collection(name=collection_name, type="VECTORSEARCH")[
+        "createCollectionDetail"
+    ]["id"]
 
 
 @task
@@ -317,7 +317,7 @@ def get_collection_arn(collection_id: str):
     """
     return next(
         colxn["arn"]
-        for colxn in aoss_client.list_collections()["collectionSummaries"]
+        for colxn in aoss_client.conn.list_collections()["collectionSummaries"]
         if colxn["id"] == collection_id
     )
 
@@ -336,7 +336,9 @@ def delete_data_source(knowledge_base_id: str, data_source_id: str):
     :param data_source_id: The unique identifier of the data source to delete.
     """
     log.info("Deleting data source %s from Knowledge Base %s.", data_source_id, knowledge_base_id)
-    bedrock_agent_client.delete_data_source(dataSourceId=data_source_id, knowledgeBaseId=knowledge_base_id)
+    bedrock_agent_client.conn.delete_data_source(
+        dataSourceId=data_source_id, knowledgeBaseId=knowledge_base_id
+    )
 
 
 # [END howto_operator_bedrock_delete_data_source]
@@ -355,7 +357,7 @@ def delete_knowledge_base(knowledge_base_id: str):
     :param knowledge_base_id: The unique identifier of the knowledge base to delete.
     """
     log.info("Deleting Knowledge Base %s.", knowledge_base_id)
-    bedrock_agent_client.delete_knowledge_base(knowledgeBaseId=knowledge_base_id)
+    bedrock_agent_client.conn.delete_knowledge_base(knowledgeBaseId=knowledge_base_id)
 
 
 # [END howto_operator_bedrock_delete_knowledge_base]
@@ -393,7 +395,7 @@ def delete_collection(collection_id: str):
     :param collection_id: ID of the collection to be indexed.
     """
     log.info("Deleting collection %s.", collection_id)
-    aoss_client.delete_collection(id=collection_id)
+    aoss_client.conn.delete_collection(id=collection_id)
 
 
 @task(trigger_rule=TriggerRule.ALL_DONE)
@@ -404,7 +406,7 @@ def delete_opensearch_policies(collection_name: str):
     :param collection_name: All policies in the given collection name will be deleted.
     """
 
-    access_policies = aoss_client.list_access_policies(
+    access_policies = aoss_client.conn.list_access_policies(
         type="data", resource=[f"collection/{collection_name}"]
     )["accessPolicySummaries"]
     log.info("Found access policies for %s: %s", collection_name, access_policies)
@@ -412,10 +414,10 @@ def delete_opensearch_policies(collection_name: str):
         raise Exception("No access policies found?")
     for policy in access_policies:
         log.info("Deleting access policy for %s: %s", collection_name, policy["name"])
-        aoss_client.delete_access_policy(name=policy["name"], type="data")
+        aoss_client.conn.delete_access_policy(name=policy["name"], type="data")
 
     for policy_type in ["encryption", "network"]:
-        policies = aoss_client.list_security_policies(
+        policies = aoss_client.conn.list_security_policies(
             type=policy_type, resource=[f"collection/{collection_name}"]
         )["securityPolicySummaries"]
         if not policies:
@@ -423,7 +425,7 @@ def delete_opensearch_policies(collection_name: str):
         log.info("Found %s security policies for %s: %s", policy_type, collection_name, policies)
         for policy in policies:
             log.info("Deleting %s security policy for %s: %s", policy_type, collection_name, policy["name"])
-            aoss_client.delete_security_policy(name=policy["name"], type=policy_type)
+            aoss_client.conn.delete_security_policy(name=policy["name"], type=policy_type)
 
 
 with DAG(
@@ -436,8 +438,8 @@ with DAG(
     test_context = sys_test_context_task()
     env_id = test_context["ENV_ID"]
 
-    aoss_client = OpenSearchServerlessHook(aws_conn_id=None).conn
-    bedrock_agent_client = BedrockAgentHook(aws_conn_id=None).conn
+    aoss_client = OpenSearchServerlessHook(aws_conn_id=None)
+    bedrock_agent_client = BedrockAgentHook(aws_conn_id=None)
 
     region_name = boto3.session.Session().region_name
 
diff --git a/tests/system/providers/amazon/aws/utils/__init__.py b/tests/system/providers/amazon/aws/utils/__init__.py
index 8b4114fc90a..411f92ab7bf 100644
--- a/tests/system/providers/amazon/aws/utils/__init__.py
+++ b/tests/system/providers/amazon/aws/utils/__init__.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import functools
 import inspect
 import json
 import logging
@@ -92,6 +93,7 @@ def _validate_env_id(env_id: str) -> str:
     return env_id.lower()
 
 
+@functools.cache
 def _fetch_from_ssm(key: str, test_name: str | None = None) -> str:
     """
     Test values are stored in the SSM Value as a JSON-encoded dict of key/value pairs.

Reply via email to