This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6d075cb7b8b Update the system test example_comprehend_document_classifier.py to use example files from S3 (#44368)
6d075cb7b8b is described below
commit 6d075cb7b8b9bb36ff5bc0a667992c8d784d70d6
Author: Vincent <[email protected]>
AuthorDate: Tue Nov 26 03:47:27 2024 -0500
Update the system test example_comprehend_document_classifier.py to use example files from S3 (#44368)
---
.../aws/example_comprehend_document_classifier.py | 127 +++++++--------------
1 file changed, 44 insertions(+), 83 deletions(-)
diff --git a/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py b/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
index 4a103a92653..b0bf4120978 100644
--- a/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
+++ b/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
@@ -16,12 +16,10 @@
# under the License.
from __future__ import annotations
-import os
from datetime import datetime
-from airflow import DAG, settings
+from airflow import DAG
from airflow.decorators import task, task_group
-from airflow.models import Connection
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
from airflow.providers.amazon.aws.operators.comprehend import (
@@ -36,31 +34,27 @@ from airflow.providers.amazon.aws.operators.s3 import (
from airflow.providers.amazon.aws.sensors.comprehend import (
ComprehendCreateDocumentClassifierCompletedSensor,
)
-from airflow.providers.amazon.aws.transfers.http_to_s3 import HttpToS3Operator
from airflow.utils.trigger_rule import TriggerRule
from providers.tests.system.amazon.aws.utils import SystemTestContextBuilder
ROLE_ARN_KEY = "ROLE_ARN"
-sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+BUCKET_NAME_KEY = "BUCKET_NAME"
+BUCKET_KEY_DISCHARGE_KEY = "BUCKET_KEY_DISCHARGE"
+BUCKET_KEY_DOCTORS_NOTES = "BUCKET_KEY_DOCTORS_NOTES"
+sys_test_context_task = (
+ SystemTestContextBuilder()
+ .add_variable(ROLE_ARN_KEY)
+ .add_variable(BUCKET_NAME_KEY)
+ .add_variable(BUCKET_KEY_DISCHARGE_KEY)
+ .add_variable(BUCKET_KEY_DOCTORS_NOTES)
+ .build()
+)
DAG_ID = "example_comprehend_document_classifier"
ANNOTATION_BUCKET_KEY = "training-labels/label.csv"
TRAINING_DATA_PREFIX = "training-docs"
-# To create a custom document classifier, we need a minimum of 10 documents for each label.
-# for testing purpose, we will generate 10 copies of each document referenced below.
-PUBLIC_DATA_SOURCES = [
- {
- "fileName": "discharge-summary.pdf",
- "endpoint":
"aws-samples/amazon-comprehend-examples/blob/master/building-custom-classifier/sample-docs/discharge-summary.pdf?raw=true",
- },
- {
- "fileName": "doctors-notes.pdf",
- "endpoint":
"aws-samples/amazon-comprehend-examples/blob/master/building-custom-classifier/sample-docs/doctors-notes.pdf?raw=true",
- },
-]
-
# Annotations file won't allow headers
# label,document name,page number
@@ -119,74 +113,27 @@ def document_classifier_workflow():
)
-@task_group
-def copy_data_to_s3(bucket: str, sources: list[dict], prefix: str, number_of_copies=1):
- """
-
- Copy some sample data to S3 using HttpToS3Operator.
-
- :param bucket: Name of the Amazon S3 bucket to send the data.
- :param prefix: Folder to store the files
- :param number_of_copies: Number of files to create for a document from the sources
- :param sources: Public available data locations
- """
-
- """
- EX: If number_of_copies is 2, sources has file name 'file.pdf', and prefix is 'training-docs'.
- Will generate two copies and upload to s3:
- - training-docs/file-0.pdf
- - training-docs/file-1.pdf
- """
-
- http_to_s3_configs = [
+@task
+def create_kwargs_discharge():
+ return [
{
- "endpoint": source["endpoint"],
- "s3_key":
f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+ "source_bucket_key": str(test_context[BUCKET_KEY_DISCHARGE_KEY]),
+ "dest_bucket_key":
f"{TRAINING_DATA_PREFIX}/discharge-summary-{counter}.pdf",
}
- for source in sources
+ for counter in range(10)
]
- copy_to_s3_configs = [
+
+
+@task
+def create_kwargs_doctors_notes():
+ return [
{
- "source_bucket_key":
f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
- "dest_bucket_key":
f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-{counter}{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+ "source_bucket_key": str(test_context[BUCKET_KEY_DOCTORS_NOTES]),
+ "dest_bucket_key":
f"{TRAINING_DATA_PREFIX}/doctors-notes-{counter}.pdf",
}
- for counter in range(number_of_copies)
- for source in sources
+ for counter in range(10)
]
- @task
- def create_connection(conn_id):
- conn = Connection(
- conn_id=conn_id,
- conn_type="http",
- host="https://github.com/",
- )
- session = settings.Session()
- session.add(conn)
- session.commit()
-
- @task(trigger_rule=TriggerRule.ALL_DONE)
- def delete_connection(conn_id):
- session = settings.Session()
- conn_to_details = session.query(Connection).filter(Connection.conn_id == conn_id).first()
- session.delete(conn_to_details)
- session.commit()
-
- http_to_s3_task = HttpToS3Operator.partial(
- task_id="http_to_s3_task",
- http_conn_id=http_conn_id,
- s3_bucket=bucket,
- ).expand_kwargs(http_to_s3_configs)
-
- s3_copy_task = S3CopyObjectOperator.partial(
- task_id="s3_copy_task",
- source_bucket_name=bucket,
- dest_bucket_name=bucket,
- meta_data_directive="REPLACE",
- ).expand_kwargs(copy_to_s3_configs)
-
- chain(create_connection(http_conn_id), http_to_s3_task, s3_copy_task, delete_connection(http_conn_id))
-
with DAG(
dag_id=DAG_ID,
@@ -199,7 +146,6 @@ with DAG(
env_id = test_context["ENV_ID"]
classifier_name = f"{env_id}-custom-document-classifier"
bucket_name = f"{env_id}-comprehend-document-classifier"
- http_conn_id = f"{env_id}-git"
input_data_configurations = {
"S3Uri": f"s3://{bucket_name}/{ANNOTATION_BUCKET_KEY}",
@@ -219,6 +165,22 @@ with DAG(
bucket_name=bucket_name,
)
+ discharge_kwargs = create_kwargs_discharge()
+ s3_copy_discharge_task = S3CopyObjectOperator.partial(
+ task_id="s3_copy_discharge_task",
+ source_bucket_name=test_context[BUCKET_NAME_KEY],
+ dest_bucket_name=bucket_name,
+ meta_data_directive="REPLACE",
+ ).expand_kwargs(discharge_kwargs)
+
+ doctors_notes_kwargs = create_kwargs_doctors_notes()
+ s3_copy_doctors_notes_task = S3CopyObjectOperator.partial(
+ task_id="s3_copy_doctors_notes_task",
+ source_bucket_name=test_context[BUCKET_NAME_KEY],
+ dest_bucket_name=bucket_name,
+ meta_data_directive="REPLACE",
+ ).expand_kwargs(doctors_notes_kwargs)
+
upload_annotation_file = S3CreateObjectOperator(
task_id="upload_annotation_file",
s3_bucket=bucket_name,
@@ -236,10 +198,9 @@ with DAG(
chain(
test_context,
create_bucket,
+ s3_copy_discharge_task,
+ s3_copy_doctors_notes_task,
upload_annotation_file,
- copy_data_to_s3(
- bucket=bucket_name, sources=PUBLIC_DATA_SOURCES, prefix=TRAINING_DATA_PREFIX, number_of_copies=10
- ),
# TEST BODY
document_classifier_workflow(),
# TEST TEARDOWN