This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new 322ce0ccfa5 [v2-10-test] Improve speed of tests by not creating
connections at parse time (#45690) (#45826)
322ce0ccfa5 is described below
commit 322ce0ccfa5f5cc4d671a9d470a589636a9c6b70
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Jan 21 17:45:51 2025 +0530
[v2-10-test] Improve speed of tests by not creating connections at parse
time (#45690) (#45826)
The DAG serialization tests load all of the example and system test DAGs,
and
there were two places that these tests opened connections at parse time,
resulting in loads of extra test time.
- The SystemTestContextBuilder was trying to fetch things from SSM. This was
addressed by adding a functools.cache on the function
- The Bedrock example dag was setting/caching the underlying conn object
globally. This was addressed by making the Airflow connection a global,
rather than the Bedrock conn. This fix is not _great_, but it does
massively
help
Before:
> 111 passed, 1 warning in 439.37s (0:07:19)
After:
> 111 passed, 1 warning in 71.76s (0:01:11)
(cherry picked from commit 102e853)
Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
.../amazon/aws/system/utils/test_helpers.py | 11 ++++++--
.../aws/example_bedrock_retrieve_and_generate.py | 32 ++++++++++++----------
.../system/providers/amazon/aws/utils/__init__.py | 2 ++
3 files changed, 28 insertions(+), 17 deletions(-)
diff --git a/tests/providers/amazon/aws/system/utils/test_helpers.py
b/tests/providers/amazon/aws/system/utils/test_helpers.py
index f48de1788b7..3af3720688a 100644
--- a/tests/providers/amazon/aws/system/utils/test_helpers.py
+++ b/tests/providers/amazon/aws/system/utils/test_helpers.py
@@ -24,7 +24,7 @@ from __future__ import annotations
import os
import sys
from io import StringIO
-from unittest.mock import ANY, patch
+from unittest.mock import patch
import pytest
from moto import mock_aws
@@ -79,8 +79,15 @@ class TestAmazonSystemTestHelpers:
) -> None:
mock_getenv.return_value = env_value or ssm_value
- result = utils.fetch_variable(ANY, default_value) if default_value
else utils.fetch_variable(ANY_STR)
+ utils._fetch_from_ssm.cache_clear()
+ result = (
+ utils.fetch_variable("some_key", default_value)
+ if default_value
+ else utils.fetch_variable(ANY_STR)
+ )
+
+ utils._fetch_from_ssm.cache_clear()
assert result == expected_result
def test_fetch_variable_no_value_found_raises_exception(self):
diff --git
a/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py
b/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py
index fcebc8c40a0..2b7bce2fecd 100644
--- a/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py
+++ b/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py
@@ -127,7 +127,7 @@ def create_opensearch_policies(bedrock_role_arn: str,
collection_name: str, poli
def _create_security_policy(name, policy_type, policy):
try:
- aoss_client.create_security_policy(name=name,
policy=json.dumps(policy), type=policy_type)
+ aoss_client.conn.create_security_policy(name=name,
policy=json.dumps(policy), type=policy_type)
except ClientError as e:
if e.response["Error"]["Code"] == "ConflictException":
log.info("OpenSearch security policy %s already exists.", name)
@@ -135,7 +135,7 @@ def create_opensearch_policies(bedrock_role_arn: str,
collection_name: str, poli
def _create_access_policy(name, policy_type, policy):
try:
- aoss_client.create_access_policy(name=name,
policy=json.dumps(policy), type=policy_type)
+ aoss_client.conn.create_access_policy(name=name,
policy=json.dumps(policy), type=policy_type)
except ClientError as e:
if e.response["Error"]["Code"] == "ConflictException":
log.info("OpenSearch data access policy %s already exists.",
name)
@@ -204,9 +204,9 @@ def create_collection(collection_name: str):
:param collection_name: The name of the Collection to create.
"""
log.info("\nCreating collection: %s.", collection_name)
- return aoss_client.create_collection(name=collection_name,
type="VECTORSEARCH")["createCollectionDetail"][
- "id"
- ]
+ return aoss_client.conn.create_collection(name=collection_name,
type="VECTORSEARCH")[
+ "createCollectionDetail"
+ ]["id"]
@task
@@ -317,7 +317,7 @@ def get_collection_arn(collection_id: str):
"""
return next(
colxn["arn"]
- for colxn in aoss_client.list_collections()["collectionSummaries"]
+ for colxn in aoss_client.conn.list_collections()["collectionSummaries"]
if colxn["id"] == collection_id
)
@@ -336,7 +336,9 @@ def delete_data_source(knowledge_base_id: str,
data_source_id: str):
:param data_source_id: The unique identifier of the data source to delete.
"""
log.info("Deleting data source %s from Knowledge Base %s.",
data_source_id, knowledge_base_id)
- bedrock_agent_client.delete_data_source(dataSourceId=data_source_id,
knowledgeBaseId=knowledge_base_id)
+ bedrock_agent_client.conn.delete_data_source(
+ dataSourceId=data_source_id, knowledgeBaseId=knowledge_base_id
+ )
# [END howto_operator_bedrock_delete_data_source]
@@ -355,7 +357,7 @@ def delete_knowledge_base(knowledge_base_id: str):
:param knowledge_base_id: The unique identifier of the knowledge base to
delete.
"""
log.info("Deleting Knowledge Base %s.", knowledge_base_id)
-
bedrock_agent_client.delete_knowledge_base(knowledgeBaseId=knowledge_base_id)
+
bedrock_agent_client.conn.delete_knowledge_base(knowledgeBaseId=knowledge_base_id)
# [END howto_operator_bedrock_delete_knowledge_base]
@@ -393,7 +395,7 @@ def delete_collection(collection_id: str):
:param collection_id: ID of the collection to be indexed.
"""
log.info("Deleting collection %s.", collection_id)
- aoss_client.delete_collection(id=collection_id)
+ aoss_client.conn.delete_collection(id=collection_id)
@task(trigger_rule=TriggerRule.ALL_DONE)
@@ -404,7 +406,7 @@ def delete_opensearch_policies(collection_name: str):
:param collection_name: All policies in the given collection name will be
deleted.
"""
- access_policies = aoss_client.list_access_policies(
+ access_policies = aoss_client.conn.list_access_policies(
type="data", resource=[f"collection/{collection_name}"]
)["accessPolicySummaries"]
log.info("Found access policies for %s: %s", collection_name,
access_policies)
@@ -412,10 +414,10 @@ def delete_opensearch_policies(collection_name: str):
raise Exception("No access policies found?")
for policy in access_policies:
log.info("Deleting access policy for %s: %s", collection_name,
policy["name"])
- aoss_client.delete_access_policy(name=policy["name"], type="data")
+ aoss_client.conn.delete_access_policy(name=policy["name"], type="data")
for policy_type in ["encryption", "network"]:
- policies = aoss_client.list_security_policies(
+ policies = aoss_client.conn.list_security_policies(
type=policy_type, resource=[f"collection/{collection_name}"]
)["securityPolicySummaries"]
if not policies:
@@ -423,7 +425,7 @@ def delete_opensearch_policies(collection_name: str):
log.info("Found %s security policies for %s: %s", policy_type,
collection_name, policies)
for policy in policies:
log.info("Deleting %s security policy for %s: %s", policy_type,
collection_name, policy["name"])
- aoss_client.delete_security_policy(name=policy["name"],
type=policy_type)
+ aoss_client.conn.delete_security_policy(name=policy["name"],
type=policy_type)
with DAG(
@@ -436,8 +438,8 @@ with DAG(
test_context = sys_test_context_task()
env_id = test_context["ENV_ID"]
- aoss_client = OpenSearchServerlessHook(aws_conn_id=None).conn
- bedrock_agent_client = BedrockAgentHook(aws_conn_id=None).conn
+ aoss_client = OpenSearchServerlessHook(aws_conn_id=None)
+ bedrock_agent_client = BedrockAgentHook(aws_conn_id=None)
region_name = boto3.session.Session().region_name
diff --git a/tests/system/providers/amazon/aws/utils/__init__.py
b/tests/system/providers/amazon/aws/utils/__init__.py
index 8b4114fc90a..411f92ab7bf 100644
--- a/tests/system/providers/amazon/aws/utils/__init__.py
+++ b/tests/system/providers/amazon/aws/utils/__init__.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import functools
import inspect
import json
import logging
@@ -92,6 +93,7 @@ def _validate_env_id(env_id: str) -> str:
return env_id.lower()
[email protected]
def _fetch_from_ssm(key: str, test_name: str | None = None) -> str:
"""
Test values are stored in the SSM Value as a JSON-encoded dict of
key/value pairs.