This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5983506df3 Add operator to invoke Azure-Synapse pipeline (#35091)
5983506df3 is described below
commit 5983506df370325f7b23a182798341d17d091a32
Author: ambika-garg <[email protected]>
AuthorDate: Thu Nov 16 04:51:42 2023 -0500
Add operator to invoke Azure-Synapse pipeline (#35091)
* Update to resolve rebase conflicts and pass pre-commit hooks
* Feature: Add Azure Synapse Pipeline run operator in Microsoft Provider
* Add a hook to interact with Azure Synapse Analytics
* Add an operator to trigger a Synapse pipeline from a DAG, plus an operator link
* Add unit tests for operator and hook
* Update provider.yaml to support new operator, operator link and hook
* Update provider_dependencies to install azure-synapse-artifacts
* Add spellings to resolve build docs test
* Fix: Pytest Tests
* Add Mock Synapse Workspace URL
* Set Default wait_for_termination to False
* Rename files as per standards
* Move AzureSynapsePipelineHook class to synapse.py so it is easier to find
* Fix all imports for the class
* Remove the file from provider.yaml
* Delete the synapse_pipeline.py file
---------
Co-authored-by: Ambika Garg <[email protected]>
---
airflow/providers/microsoft/azure/hooks/synapse.py | 191 +++++++++++++++++-
.../providers/microsoft/azure/operators/synapse.py | 187 +++++++++++++++++-
airflow/providers/microsoft/azure/provider.yaml | 8 +-
docs/spelling_wordlist.txt | 2 +
generated/provider_dependencies.json | 1 +
.../microsoft/azure/hooks/test_synapse_pipeline.py | 159 +++++++++++++++
.../microsoft/azure/operators/test_synapse.py | 218 ++++++++++++++++++++-
.../azure/example_synapse_run_pipeline.py | 59 ++++++
8 files changed, 813 insertions(+), 12 deletions(-)
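Note for readers skimming the change: the snippet below is a minimal, illustrative sketch of how the new operator is intended to be used; it mirrors the example DAG added under tests/system in this commit. The connection id, pipeline name, and workspace endpoint are placeholders, not values required by the provider.

    from __future__ import annotations

    from datetime import datetime

    from airflow.models import DAG
    from airflow.providers.microsoft.azure.operators.synapse import AzureSynapseRunPipelineOperator

    with DAG(
        dag_id="synapse_pipeline_demo",  # illustrative DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ) as dag:
        # Connection id, pipeline name, and endpoint below are placeholders.
        run_pipeline = AzureSynapseRunPipelineOperator(
            task_id="run_pipeline",
            azure_synapse_conn_id="azure_synapse_connection",
            pipeline_name="Pipeline 1",
            azure_synapse_workspace_dev_endpoint="https://myworkspace.dev.azuresynapse.net",
            wait_for_termination=True,  # block until the run reaches a terminal status
        )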
diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py b/airflow/providers/microsoft/azure/hooks/synapse.py
index e284194376..d48109d694 100644
--- a/airflow/providers/microsoft/azure/hooks/synapse.py
+++ b/airflow/providers/microsoft/azure/hooks/synapse.py
@@ -19,10 +19,12 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any, Union
+from azure.core.exceptions import ServiceRequestError
from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.synapse.artifacts import ArtifactsClient
from azure.synapse.spark import SparkClient
-from airflow.exceptions import AirflowTaskTimeout
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
add_managed_identity_connection_widgets,
@@ -31,6 +33,7 @@ from airflow.providers.microsoft.azure.utils import (
)
if TYPE_CHECKING:
+ from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun
from azure.synapse.spark.models import SparkBatchJobOptions
Credentials = Union[ClientSecretCredential, DefaultAzureCredential]
@@ -217,3 +220,189 @@ class AzureSynapseHook(BaseHook):
:param job_id: The synapse spark job identifier.
"""
self.get_conn().spark_batch.cancel_spark_batch_job(job_id)
+
+
+class AzureSynapsePipelineRunStatus:
+ """Azure Synapse pipeline operation statuses."""
+
+ QUEUED = "Queued"
+ IN_PROGRESS = "InProgress"
+ SUCCEEDED = "Succeeded"
+ FAILED = "Failed"
+ CANCELING = "Canceling"
+ CANCELLED = "Cancelled"
+ TERMINAL_STATUSES = {CANCELLED, FAILED, SUCCEEDED}
+ INTERMEDIATE_STATES = {QUEUED, IN_PROGRESS, CANCELING}
+ FAILURE_STATES = {FAILED, CANCELLED}
+
+
+class AzureSynapsePipelineRunException(AirflowException):
+ """An exception that indicates a pipeline run failed to complete."""
+
+
+class AzureSynapsePipelineHook(BaseHook):
+ """
+ A hook to interact with Azure Synapse Pipeline.
+
+ :param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`.
+ :param azure_synapse_workspace_dev_endpoint: The Azure Synapse Workspace development endpoint.
+ """
+
+ conn_type: str = "azure_synapse_pipeline"
+ conn_name_attr: str = "azure_synapse_conn_id"
+ default_conn_name: str = "azure_synapse_connection"
+ hook_name: str = "Azure Synapse Pipeline"
+
+ @staticmethod
+ def get_connection_form_widgets() -> dict[str, Any]:
+ """Returns connection widgets to add to connection form."""
+ from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+ from flask_babel import lazy_gettext
+ from wtforms import StringField
+
+ return {
+ "tenantId": StringField(lazy_gettext("Tenant ID"),
widget=BS3TextFieldWidget()),
+ }
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Returns custom field behaviour."""
+ return {
+ "hidden_fields": ["schema", "port", "extra"],
+ "relabeling": {"login": "Client ID", "password": "Secret", "host":
"Synapse Workspace URL"},
+ }
+
+ def __init__(
+ self, azure_synapse_workspace_dev_endpoint: str, azure_synapse_conn_id: str = default_conn_name
+ ):
+ self._conn = None
+ self.conn_id = azure_synapse_conn_id
+ self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint
+ super().__init__()
+
+ def _get_field(self, extras, name):
+ return get_field(
+ conn_id=self.conn_id,
+ conn_type=self.conn_type,
+ extras=extras,
+ field_name=name,
+ )
+
+ def get_conn(self) -> ArtifactsClient:
+ if self._conn is not None:
+ return self._conn
+
+ conn = self.get_connection(self.conn_id)
+ extras = conn.extra_dejson
+ tenant = self._get_field(extras, "tenantId")
+
+ credential: Credentials
+ if conn.login is not None and conn.password is not None:
+ if not tenant:
+ raise ValueError("A Tenant ID is required when authenticating
with Client ID and Secret.")
+
+ credential = ClientSecretCredential(
+ client_id=conn.login, client_secret=conn.password, tenant_id=tenant
+ )
+ else:
+ credential = DefaultAzureCredential()
+ self._conn = self._create_client(credential, self.azure_synapse_workspace_dev_endpoint)
+
+ if self._conn is not None:
+ return self._conn
+ else:
+ raise ValueError("Failed to create ArtifactsClient")
+
+ @staticmethod
+ def _create_client(credential: Credentials, endpoint: str):
+ return ArtifactsClient(credential=credential, endpoint=endpoint)
+
+ def run_pipeline(self, pipeline_name: str, **config: Any) -> CreateRunResponse:
+ """
+ Run a Synapse pipeline.
+
+ :param pipeline_name: The pipeline name.
+ :param config: Extra parameters for the Synapse Artifact Client.
+ :return: The pipeline run Id.
+ """
+ return self.get_conn().pipeline.create_pipeline_run(pipeline_name, **config)
+
+ def get_pipeline_run(self, run_id: str) -> PipelineRun:
+ """
+ Get the pipeline run.
+
+ :param run_id: The pipeline run identifier.
+ :return: The pipeline run.
+ """
+ return self.get_conn().pipeline_run.get_pipeline_run(run_id=run_id)
+
+ def get_pipeline_run_status(self, run_id: str) -> str:
+ """
+ Get a pipeline run's current status.
+
+ :param run_id: The pipeline run identifier.
+
+ :return: The status of the pipeline run.
+ """
+ pipeline_run_status = self.get_pipeline_run(
+ run_id=run_id,
+ ).status
+
+ return str(pipeline_run_status)
+
+ def refresh_conn(self) -> ArtifactsClient:
+ self._conn = None
+ return self.get_conn()
+
+ def wait_for_pipeline_run_status(
+ self,
+ run_id: str,
+ expected_statuses: str | set[str],
+ check_interval: int = 60,
+ timeout: int = 60 * 60 * 24 * 7,
+ ) -> bool:
+ """
+ Waits for a pipeline run to match an expected status.
+
+ :param run_id: The pipeline run identifier.
+ :param expected_statuses: The desired status(es) to check against a pipeline run's current status.
+ :param check_interval: Time in seconds to check on a pipeline run's status.
+ :param timeout: Time in seconds to wait for a pipeline to reach a terminal status or the expected status.
+
+ :return: Boolean indicating if the pipeline run has reached the ``expected_status``.
+ """
+ pipeline_run_status = self.get_pipeline_run_status(run_id=run_id)
+ executed_after_token_refresh = True
+ start_time = time.monotonic()
+
+ while (
+ pipeline_run_status not in AzureSynapsePipelineRunStatus.TERMINAL_STATUSES
+ and pipeline_run_status not in expected_statuses
+ ):
+ if start_time + timeout < time.monotonic():
+ raise AzureSynapsePipelineRunException(
+ f"Pipeline run {run_id} has not reached a terminal status
after {timeout} seconds."
+ )
+
+ # Wait to check the status of the pipeline run based on the ``check_interval`` configured.
+ time.sleep(check_interval)
+
+ try:
+ pipeline_run_status = self.get_pipeline_run_status(run_id=run_id)
+ executed_after_token_refresh = True
+ except ServiceRequestError:
+ if executed_after_token_refresh:
+ self.refresh_conn()
+ else:
+ raise
+
+ return pipeline_run_status in expected_statuses
+
+ def cancel_run_pipeline(self, run_id: str) -> None:
+ """
+ Cancel the pipeline run.
+
+ :param run_id: The pipeline run identifier.
+ """
+ self.get_conn().pipeline_run.cancel_pipeline_run(run_id)
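For orientation, here is a hedged sketch of exercising the new hook directly (outside the operator), based only on the methods added above; the connection id, endpoint, pipeline name, and polling values are illustrative placeholders.

    from airflow.providers.microsoft.azure.hooks.synapse import (
        AzureSynapsePipelineHook,
        AzureSynapsePipelineRunStatus,
    )

    hook = AzureSynapsePipelineHook(
        azure_synapse_conn_id="azure_synapse_connection",  # placeholder connection id
        azure_synapse_workspace_dev_endpoint="https://myworkspace.dev.azuresynapse.net",  # placeholder
    )

    # run_pipeline() returns the service's CreateRunResponse; its run_id identifies the new run.
    response = hook.run_pipeline(pipeline_name="Pipeline 1")
    run_id = vars(response)["run_id"]

    # Poll until the run succeeds, fails, or the (arbitrary) one-hour timeout elapses.
    succeeded = hook.wait_for_pipeline_run_status(
        run_id=run_id,
        expected_statuses=AzureSynapsePipelineRunStatus.SUCCEEDED,
        check_interval=30,
        timeout=60 * 60,
    )
    if not succeeded:
        hook.cancel_run_pipeline(run_id=run_id)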
diff --git a/airflow/providers/microsoft/azure/operators/synapse.py b/airflow/providers/microsoft/azure/operators/synapse.py
index e7fde11528..f7a23d5f09 100644
--- a/airflow/providers/microsoft/azure/operators/synapse.py
+++ b/airflow/providers/microsoft/azure/operators/synapse.py
@@ -17,14 +17,24 @@
from __future__ import annotations
from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
+from urllib.parse import urlencode
-from airflow.models import BaseOperator
-from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, AzureSynapseSparkBatchRunStatus
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator, BaseOperatorLink, XCom
+from airflow.providers.microsoft.azure.hooks.synapse import (
+ AzureSynapseHook,
+ AzureSynapsePipelineHook,
+ AzureSynapsePipelineRunException,
+ AzureSynapsePipelineRunStatus,
+ AzureSynapseSparkBatchRunStatus,
+)
if TYPE_CHECKING:
from azure.synapse.spark.models import SparkBatchJobOptions
+ from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context
@@ -108,3 +118,174 @@ class AzureSynapseRunSparkBatchOperator(BaseOperator):
job_id=self.job_id,
)
self.log.info("Job run %s has been cancelled successfully.",
self.job_id)
+
+
+class AzureSynapsePipelineRunLink(BaseOperatorLink):
+ """Constructs a link to monitor a pipeline run in Azure Synapse."""
+
+ name = "Monitor Pipeline Run"
+
+ def get_fields_from_url(self, workspace_url):
+ """
+ Extracts the workspace_name, subscription_id and resource_group from the Synapse workspace url.
+
+ :param workspace_url: The workspace url.
+ """
+ import re
+ from urllib.parse import unquote, urlparse
+
+ pattern = r"https://web\.azuresynapse\.net\?workspace=(.*)"
+ match = re.search(pattern, workspace_url)
+
+ if not match:
+ raise ValueError("Invalid workspace URL format")
+
+ extracted_text = match.group(1)
+ parsed_url = urlparse(extracted_text)
+ path = unquote(parsed_url.path)
+ path_segments = path.split("/")
+ if len(path_segments) < 5:
+ raise
+
+ return {
+ "workspace_name": path_segments[-1],
+ "subscription_id": path_segments[2],
+ "resource_group": path_segments[4],
+ }
+
+ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
+ run_id = XCom.get_value(key="run_id", ti_key=ti_key) or ""
+ conn_id = operator.azure_synapse_conn_id # type: ignore
+ conn = BaseHook.get_connection(conn_id)
+ self.synapse_workspace_url = conn.host
+
+ fields = self.get_fields_from_url(self.synapse_workspace_url)
+
+ params = {
+ "workspace": f"/subscriptions/{fields['subscription_id']}"
+ f"/resourceGroups/{fields['resource_group']}/providers/Microsoft.Synapse"
+ f"/workspaces/{fields['workspace_name']}",
+ }
+ encoded_params = urlencode(params)
+ base_url = f"https://ms.web.azuresynapse.net/en/monitoring/pipelineruns/{run_id}?"
+
+ return base_url + encoded_params
+
+
+class AzureSynapseRunPipelineOperator(BaseOperator):
+ """
+ Executes a Synapse Pipeline.
+
+ :param pipeline_name: The name of the pipeline to execute.
+ :param azure_synapse_conn_id: The Airflow connection ID for Azure Synapse.
+ :param azure_synapse_workspace_dev_endpoint: The Azure Synapse workspace development endpoint.
+ :param wait_for_termination: Flag to wait on a pipeline run's termination.
+ :param reference_pipeline_run_id: The pipeline run identifier. If this run ID is specified the parameters
+ of the specified run will be used to create a new run.
+ :param is_recovery: Recovery mode flag. If recovery mode is set to `True`, the specified referenced
+ pipeline run and the new run will be grouped under the same ``groupId``.
+ :param start_activity_name: In recovery mode, the rerun will start from this activity. If not specified,
+ all activities will run.
+ :param parameters: Parameters of the pipeline run. These parameters are referenced in a pipeline via
+ ``@pipeline().parameters.parameterName`` and will be used only if the ``reference_pipeline_run_id`` is
+ not specified.
+ :param timeout: Time in seconds to wait for a pipeline to reach a terminal status for non-asynchronous
+ waits. Used only if ``wait_for_termination`` is True.
+ :param check_interval: Time in seconds to check on a pipeline run's status for non-asynchronous waits.
+ Used only if ``wait_for_termination`` is True.
+
+ """
+
+ template_fields: Sequence[str] = ("azure_synapse_conn_id",)
+
+ operator_extra_links = (AzureSynapsePipelineRunLink(),)
+
+ def __init__(
+ self,
+ pipeline_name: str,
+ azure_synapse_conn_id: str,
+ azure_synapse_workspace_dev_endpoint: str,
+ wait_for_termination: bool = True,
+ reference_pipeline_run_id: str | None = None,
+ is_recovery: bool | None = None,
+ start_activity_name: str | None = None,
+ parameters: dict[str, Any] | None = None,
+ timeout: int = 60 * 60 * 24 * 7,
+ check_interval: int = 60,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.azure_synapse_conn_id = azure_synapse_conn_id
+ self.pipeline_name = pipeline_name
+ self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint
+ self.wait_for_termination = wait_for_termination
+ self.reference_pipeline_run_id = reference_pipeline_run_id
+ self.is_recovery = is_recovery
+ self.start_activity_name = start_activity_name
+ self.parameters = parameters
+ self.timeout = timeout
+ self.check_interval = check_interval
+
+ @cached_property
+ def hook(self):
+ """Create and return an AzureSynapsePipelineHook (cached)."""
+ return AzureSynapsePipelineHook(
+ azure_synapse_conn_id=self.azure_synapse_conn_id,
+ azure_synapse_workspace_dev_endpoint=self.azure_synapse_workspace_dev_endpoint,
+ )
+
+ def execute(self, context) -> None:
+ self.log.info("Executing the %s pipeline.", self.pipeline_name)
+ response = self.hook.run_pipeline(
+ pipeline_name=self.pipeline_name,
+ reference_pipeline_run_id=self.reference_pipeline_run_id,
+ is_recovery=self.is_recovery,
+ start_activity_name=self.start_activity_name,
+ parameters=self.parameters,
+ )
+ self.run_id = vars(response)["run_id"]
+ # Push the ``run_id`` value to XCom regardless of what happens during execution. This allows for
+ # retrieval of the executed pipeline's ``run_id`` for downstream tasks, especially if performing an
+ # asynchronous wait.
+ context["ti"].xcom_push(key="run_id", value=self.run_id)
+
+ if self.wait_for_termination:
+ self.log.info("Waiting for pipeline run %s to terminate.",
self.run_id)
+
+ if self.hook.wait_for_pipeline_run_status(
+ run_id=self.run_id,
+ expected_statuses=AzureSynapsePipelineRunStatus.SUCCEEDED,
+ check_interval=self.check_interval,
+ timeout=self.timeout,
+ ):
+ self.log.info("Pipeline run %s has completed successfully.",
self.run_id)
+ else:
+ raise AzureSynapsePipelineRunException(
+ f"Pipeline run {self.run_id} has failed or has been
cancelled."
+ )
+
+ def execute_complete(self, event: dict[str, str]) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+
+ Relies on the trigger to throw an exception; otherwise it assumes execution was successful.
+ """
+ if event:
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info(event["message"])
+
+ def on_kill(self) -> None:
+ if self.run_id:
+ self.hook.cancel_run_pipeline(run_id=self.run_id)
+
+ # Check to ensure the pipeline run was cancelled as expected.
+ if self.hook.wait_for_pipeline_run_status(
+ run_id=self.run_id,
+ expected_statuses=AzureSynapsePipelineRunStatus.CANCELLED,
+ check_interval=self.check_interval,
+ timeout=self.timeout,
+ ):
+ self.log.info("Pipeline run %s has been cancelled
successfully.", self.run_id)
+ else:
+ raise AzureSynapsePipelineRunException(f"Pipeline run
{self.run_id} was not cancelled.")
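As a reference for the operator link added above, here is a small hedged sketch of the workspace URL shape that get_fields_from_url expects in the connection host; the subscription id, resource group, and workspace name are illustrative placeholders.

    from airflow.providers.microsoft.azure.operators.synapse import AzureSynapsePipelineRunLink

    # Illustrative value of Connection.host for the Synapse connection.
    workspace_url = (
        "https://web.azuresynapse.net?workspace=%2fsubscriptions%2f00000000-0000-0000-0000-000000000000"
        "%2fresourceGroups%2fmy-rg%2fproviders%2fMicrosoft.Synapse%2fworkspaces%2fmy-workspace"
    )

    fields = AzureSynapsePipelineRunLink().get_fields_from_url(workspace_url)
    # fields -> {"workspace_name": "my-workspace",
    #            "subscription_id": "00000000-0000-0000-0000-000000000000",
    #            "resource_group": "my-rg"}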
diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml
index 975feb276a..53bccda6c7 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -82,6 +82,7 @@ dependencies:
- azure-storage-file-share
- azure-servicebus>=7.6.1
- azure-synapse-spark
+ - azure-synapse-artifacts>=0.17.0
- adal>=1.2.7
- azure-storage-file-datalake>=12.9.1
- azure-kusto-data>=4.1.0
@@ -279,15 +280,13 @@ connection-types:
connection-type: azure_fileshare
- hook-class-name: airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook
connection-type: azure_container_volume
- - hook-class-name: >-
- airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook
+ - hook-class-name: airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook
connection-type: azure_container_instance
- hook-class-name: airflow.providers.microsoft.azure.hooks.wasb.WasbHook
connection-type: wasb
- hook-class-name: airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook
connection-type: azure_data_factory
- - hook-class-name: >-
- airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook
+ - hook-class-name: airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook
connection-type: azure_container_registry
- hook-class-name: airflow.providers.microsoft.azure.hooks.asb.BaseAzureServiceBusHook
connection-type: azure_service_bus
@@ -304,6 +303,7 @@ logging:
extra-links:
- airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryPipelineRunLink
+ - airflow.providers.microsoft.azure.operators.synapse.AzureSynapsePipelineRunLink
config:
azure_remote_logging:
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 3979ae3c59..b787191fd0 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -315,6 +315,7 @@ cpus
crd
createDisposition
CreateQueryOperator
+CreateRunResponse
creationTimestamp
credssp
Cron
@@ -1144,6 +1145,7 @@ pinecone
pinodb
Pinot
pinot
+PipelineRun
pkill
plaintext
platformVersion
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 0b234b4e12..058078a191 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -586,6 +586,7 @@
"azure-storage-blob>=12.14.0",
"azure-storage-file-datalake>=12.9.1",
"azure-storage-file-share",
+ "azure-synapse-artifacts>=0.17.0",
"azure-synapse-spark"
],
"cross-providers-deps": [
diff --git a/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py b/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py
new file mode 100644
index 0000000000..d0309bd0a6
--- /dev/null
+++ b/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py
@@ -0,0 +1,159 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.synapse.artifacts import ArtifactsClient
+
+from airflow.models.connection import Connection
+from airflow.providers.microsoft.azure.hooks.synapse import (
+ AzureSynapsePipelineHook,
+ AzureSynapsePipelineRunException,
+ AzureSynapsePipelineRunStatus,
+)
+
+DEFAULT_CONNECTION_CLIENT_SECRET = "azure_synapse_test_client_secret"
+DEFAULT_CONNECTION_DEFAULT_CREDENTIAL = "azure_synapse_test_default_credential"
+
+SYNAPSE_WORKSPACE_URL = "synapse_workspace_url"
+AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT = "azure_synapse_workspace_dev_endpoint"
+PIPELINE_NAME = "pipeline_name"
+RUN_ID = "run_id"
+
+
+@pytest.fixture(autouse=True)
+def setup_connections(create_mock_connections):
+ create_mock_connections(
+ # connection_client_secret
+ Connection(
+ conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+ conn_type="azure_synapse_pipeline",
+ host=SYNAPSE_WORKSPACE_URL,
+ login="clientId",
+ password="clientSecret",
+ extra={"tenantId": "tenantId"},
+ ),
+ # connection_default_credential
+ Connection(
+ conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+ conn_type="azure_synapse_pipeline",
+ host=SYNAPSE_WORKSPACE_URL,
+ extra={},
+ ),
+ # connection_missing_tenant_id
+ Connection(
+ conn_id="azure_synapse_missing_tenant_id",
+ conn_type="azure_synapse_pipeline",
+ host=SYNAPSE_WORKSPACE_URL,
+ login="clientId",
+ password="clientSecret",
+ extra={},
+ ),
+ )
+
+
+@pytest.fixture
+def hook():
+ client = AzureSynapsePipelineHook(
+ azure_synapse_conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+ azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+ )
+ client._conn = MagicMock(spec=["pipeline_run", "pipeline"])
+
+ return client
+
+
+@pytest.mark.parametrize(
+ ("connection_id", "credential_type"),
+ [
+ (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential),
+ (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential),
+ ],
+)
+def test_get_connection_by_credential_client_secret(connection_id: str, credential_type: type):
+ hook = AzureSynapsePipelineHook(
+ azure_synapse_conn_id=connection_id,
+ azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+ )
+
+ with patch.object(hook, "_create_client") as mock_create_client:
+ mock_create_client.return_value = MagicMock()
+ connection = hook.get_conn()
+ assert connection is not None
+ mock_create_client.assert_called_once()
+ assert isinstance(mock_create_client.call_args.args[0], credential_type)
+ assert mock_create_client.call_args.args[1] == AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT
+
+
+def test_run_pipeline(hook: AzureSynapsePipelineHook):
+ hook.run_pipeline(PIPELINE_NAME)
+
+ if hook._conn is not None and isinstance(hook._conn, ArtifactsClient):
+ hook._conn.pipeline.create_pipeline_run.assert_called_with(PIPELINE_NAME)
+
+
+def test_get_pipeline_run(hook: AzureSynapsePipelineHook):
+ hook.get_pipeline_run(run_id=RUN_ID)
+
+ if hook._conn is not None and isinstance(hook._conn, ArtifactsClient):
+ hook._conn.pipeline_run.get_pipeline_run.assert_called_with(run_id=RUN_ID)
+
+
+def test_cancel_run_pipeline(hook: AzureSynapsePipelineHook):
+ hook.cancel_run_pipeline(RUN_ID)
+
+ if hook._conn is not None and isinstance(hook._conn, ArtifactsClient):
+ hook._conn.pipeline_run.cancel_pipeline_run.assert_called_with(RUN_ID)
+
+
+_wait_for_pipeline_run_status_test_args = [
+ (AzureSynapsePipelineRunStatus.SUCCEEDED, AzureSynapsePipelineRunStatus.SUCCEEDED, True),
+ (AzureSynapsePipelineRunStatus.FAILED, AzureSynapsePipelineRunStatus.SUCCEEDED, False),
+ (AzureSynapsePipelineRunStatus.CANCELLED, AzureSynapsePipelineRunStatus.SUCCEEDED, False),
+ (AzureSynapsePipelineRunStatus.IN_PROGRESS, AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"),
+ (AzureSynapsePipelineRunStatus.QUEUED, AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"),
+ (AzureSynapsePipelineRunStatus.CANCELING, AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"),
+ (AzureSynapsePipelineRunStatus.SUCCEEDED, AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True),
+ (AzureSynapsePipelineRunStatus.FAILED, AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True),
+ (AzureSynapsePipelineRunStatus.CANCELLED, AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True),
+]
+
+
+@pytest.mark.parametrize(
+ argnames=("pipeline_run_status", "expected_status", "expected_output"),
+ argvalues=_wait_for_pipeline_run_status_test_args,
+ ids=[
+ f"run_status_{argval[0]}_expected_{argval[1]}"
+ if isinstance(argval[1], str)
+ else f"run_status_{argval[0]}_expected_AnyTerminalStatus"
+ for argval in _wait_for_pipeline_run_status_test_args
+ ],
+)
+def test_wait_for_pipeline_run_status(hook, pipeline_run_status, expected_status, expected_output):
+ config = {"run_id": RUN_ID, "timeout": 3, "check_interval": 1, "expected_statuses": expected_status}
+
+ with patch.object(AzureSynapsePipelineHook, "get_pipeline_run") as mock_pipeline_run:
+ mock_pipeline_run.return_value.status = pipeline_run_status
+
+ if expected_output != "timeout":
+ assert hook.wait_for_pipeline_run_status(**config) == expected_output
+ else:
+ with pytest.raises(AzureSynapsePipelineRunException):
+ hook.wait_for_pipeline_run_status(**config)
diff --git a/tests/providers/microsoft/azure/operators/test_synapse.py b/tests/providers/microsoft/azure/operators/test_synapse.py
index 233e1c57fd..14ffd783ba 100644
--- a/tests/providers/microsoft/azure/operators/test_synapse.py
+++ b/tests/providers/microsoft/azure/operators/test_synapse.py
@@ -17,24 +17,42 @@
from __future__ import annotations
from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
import pytest
from airflow.models import Connection
-from airflow.providers.microsoft.azure.operators.synapse import AzureSynapseRunSparkBatchOperator
+from airflow.providers.microsoft.azure.hooks.synapse import (
+ AzureSynapsePipelineHook,
+ AzureSynapsePipelineRunException,
+ AzureSynapsePipelineRunStatus,
+)
+from airflow.providers.microsoft.azure.operators.synapse import (
+ AzureSynapsePipelineRunLink,
+ AzureSynapseRunPipelineOperator,
+ AzureSynapseRunSparkBatchOperator,
+)
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2021, 1, 1)
-SUBSCRIPTION_ID = "my-subscription-id"
+SUBSCRIPTION_ID = "subscription_id"
+TENANT_ID = "tenant_id"
TASK_ID = "run_spark_op"
+AZURE_SYNAPSE_PIPELINE_TASK_ID = "run_pipeline_op"
AZURE_SYNAPSE_CONN_ID = "azure_synapse_test"
CONN_EXTRAS = {
"synapse__subscriptionId": SUBSCRIPTION_ID,
"synapse__tenantId": "my-tenant-id",
"synapse__spark_pool": "my-spark-pool",
}
+SYNAPSE_PIPELINE_CONN_EXTRAS = {"tenantId": TENANT_ID}
JOB_RUN_RESPONSE = {"id": 123}
+PIPELINE_NAME = "Pipeline 1"
+AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT = "azure_synapse_workspace_dev_endpoint"
+RESOURCE_GROUP = "op-resource-group"
+WORKSPACE_NAME = "workspace-test"
+AZURE_SYNAPSE_WORKSPACE_URL = f"https://web.azuresynapse.net?workspace=%2fsubscriptions%{SUBSCRIPTION_ID}%2fresourceGroups%2f{RESOURCE_GROUP}%2fproviders%2fMicrosoft.Synapse%2fworkspaces%2f{WORKSPACE_NAME}"
+PIPELINE_RUN_RESPONSE = {"run_id": "run_id"}
class TestAzureSynapseRunSparkBatchOperator:
@@ -53,7 +71,7 @@ class TestAzureSynapseRunSparkBatchOperator:
create_mock_connection(
Connection(
conn_id=AZURE_SYNAPSE_CONN_ID,
- conn_type="azure_synapse",
+ conn_type="azure_synapse_pipeline",
host="https://synapsetest.net",
login="client-id",
password="client-secret",
@@ -102,3 +120,195 @@ class TestAzureSynapseRunSparkBatchOperator:
op.execute(context=self.mock_context)
op.on_kill()
mock_cancel_job_run.assert_called_once_with(job_id=JOB_RUN_RESPONSE["id"])
+
+
+class TestAzureSynapseRunPipelineOperator:
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, create_mock_connection):
+ self.mock_ti = MagicMock()
+ self.mock_context = {"ti": self.mock_ti}
+ self.config = {
+ "task_id": AZURE_SYNAPSE_PIPELINE_TASK_ID,
+ "azure_synapse_conn_id": AZURE_SYNAPSE_CONN_ID,
+ "pipeline_name": PIPELINE_NAME,
+ "azure_synapse_workspace_dev_endpoint":
AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+ "check_interval": 1,
+ "timeout": 3,
+ }
+
+ create_mock_connection(
+ Connection(
+ conn_id=AZURE_SYNAPSE_CONN_ID,
+ conn_type="azure_synapse_pipeline",
+ host=AZURE_SYNAPSE_WORKSPACE_URL,
+ login="client_id",
+ password="client_secret",
+ extra=SYNAPSE_PIPELINE_CONN_EXTRAS,
+ )
+ )
+
+ @staticmethod
+ def create_pipeline_run(status: str):
+ """Helper function to create a mock pipeline run with a given
execution status."""
+
+ run = MagicMock()
+ run.status = status
+
+ return run
+
+ @patch.object(AzureSynapsePipelineHook, "run_pipeline", return_value=MagicMock(**PIPELINE_RUN_RESPONSE))
+ @pytest.mark.parametrize(
+ "pipeline_run_status,expected_output",
+ [
+ (AzureSynapsePipelineRunStatus.SUCCEEDED, None),
+ (AzureSynapsePipelineRunStatus.FAILED, "exception"),
+ (AzureSynapsePipelineRunStatus.CANCELLED, "exception"),
+ (AzureSynapsePipelineRunStatus.IN_PROGRESS, "timeout"),
+ (AzureSynapsePipelineRunStatus.QUEUED, "timeout"),
+ (AzureSynapsePipelineRunStatus.CANCELING, "timeout"),
+ ],
+ )
+ def test_execute_wait_for_termination(self, mock_run_pipeline, pipeline_run_status, expected_output):
+ # Initialize the operator with mock config, (**) unpacks the config dict.
+ operator = AzureSynapseRunPipelineOperator(**self.config)
+
+ assert operator.azure_synapse_conn_id == self.config["azure_synapse_conn_id"]
+ assert operator.pipeline_name == self.config["pipeline_name"]
+ assert (
+ operator.azure_synapse_workspace_dev_endpoint
+ == self.config["azure_synapse_workspace_dev_endpoint"]
+ )
+ assert operator.check_interval == self.config["check_interval"]
+ assert operator.timeout == self.config["timeout"]
+ assert operator.wait_for_termination
+
+ with patch.object(AzureSynapsePipelineHook, "get_pipeline_run") as mock_get_pipeline_run:
+ mock_get_pipeline_run.return_value = TestAzureSynapseRunPipelineOperator.create_pipeline_run(
+ pipeline_run_status
+ )
+
+ if not expected_output:
+ # A successful operator execution should not return any values.
+ assert not operator.execute(context=self.mock_context)
+ elif expected_output == "exception":
+ # The operator should fail if the pipeline run fails or is canceled.
+ with pytest.raises(
+ AzureSynapsePipelineRunException,
+ match=f"Pipeline run {PIPELINE_RUN_RESPONSE['run_id']} has
failed or has been cancelled.",
+ ):
+ operator.execute(context=self.mock_context)
+ else:
+ # Demonstrating the operator timing out after surpassing the configured timeout value.
+ with pytest.raises(
+ AzureSynapsePipelineRunException,
+ match=(
+ f"Pipeline run {PIPELINE_RUN_RESPONSE['run_id']} has
not reached a terminal status "
+ f"after {self.config['timeout']} seconds."
+ ),
+ ):
+ operator.execute(context=self.mock_context)
+
+ # Check the ``run_id`` attr is assigned after executing the pipeline.
+ assert operator.run_id == PIPELINE_RUN_RESPONSE["run_id"]
+
+ # Check to ensure an `XCom` is pushed regardless of pipeline run result.
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]
+ )
+
+ # Check that mock_run_pipeline was called with the expected set of arguments.
+ mock_run_pipeline.assert_called_once_with(
+ pipeline_name=self.config["pipeline_name"],
+ reference_pipeline_run_id=None,
+ is_recovery=None,
+ start_activity_name=None,
+ parameters=None,
+ )
+
+ if pipeline_run_status in AzureSynapsePipelineRunStatus.TERMINAL_STATUSES:
+ mock_get_pipeline_run.assert_called_once_with(run_id=mock_run_pipeline.return_value.run_id)
+ else:
+ # When the pipeline run status is not in a terminal status or "Succeeded", the operator will
+ # continue to call ``get_pipeline_run()`` until a ``timeout`` number of seconds has passed
+ # (3 seconds for this test). Therefore, there should be 4 calls of this function: one
+ # initially and 3 for each check done at a 1 second interval.
+ assert mock_get_pipeline_run.call_count == 4
+
+ mock_get_pipeline_run.assert_called_with(run_id=mock_run_pipeline.return_value.run_id)
+
+ @patch.object(AzureSynapsePipelineHook, "run_pipeline", return_value=MagicMock(**PIPELINE_RUN_RESPONSE))
+ def test_execute_no_wait_for_termination(self, mock_run_pipeline):
+ operator = AzureSynapseRunPipelineOperator(wait_for_termination=False, **self.config)
+
+ assert operator.azure_synapse_conn_id == self.config["azure_synapse_conn_id"]
+ assert operator.pipeline_name == self.config["pipeline_name"]
+ assert (
+ operator.azure_synapse_workspace_dev_endpoint
+ == self.config["azure_synapse_workspace_dev_endpoint"]
+ )
+ assert operator.check_interval == self.config["check_interval"]
+ assert operator.timeout == self.config["timeout"]
+ assert not operator.wait_for_termination
+
+ with patch.object(
+ AzureSynapsePipelineHook, "get_pipeline_run", autospec=True
+ ) as mock_get_pipeline_run:
+ operator.execute(context=self.mock_context)
+
+ # Check the ``run_id`` attr is assigned after executing the pipeline.
+ assert operator.run_id == PIPELINE_RUN_RESPONSE["run_id"]
+
+ # Check to ensure an `XCom` is pushed regardless of pipeline run result.
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]
+ )
+
+ mock_run_pipeline.assert_called_once_with(
+ pipeline_name=self.config["pipeline_name"],
+ reference_pipeline_run_id=None,
+ is_recovery=None,
+ start_activity_name=None,
+ parameters=None,
+ )
+
+ # Checking the pipeline run status should _not_ be called when ``wait_for_termination`` is False.
+ mock_get_pipeline_run.assert_not_called()
+
+ @pytest.mark.db_test
+ def test_run_pipeline_operator_link(self,
create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AzureSynapseRunPipelineOperator,
+ dag_id="test_synapse_run_pipeline_op_link",
+ execution_date=DEFAULT_DATE,
+ task_id=AZURE_SYNAPSE_PIPELINE_TASK_ID,
+ azure_synapse_conn_id=AZURE_SYNAPSE_CONN_ID,
+ pipeline_name=PIPELINE_NAME,
+ azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+ )
+
+ ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"])
+
+ url = ti.task.get_extra_links(ti, "Monitor Pipeline Run")
+
+ EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = (
+ "https://ms.web.azuresynapse.net/en/monitoring/pipelineruns/{run_id}"
+ "?workspace=%2Fsubscriptions%2F{subscription_id}%2F"
+ "resourceGroups%2F{resource_group}%2Fproviders%2FMicrosoft.Synapse"
+ "%2Fworkspaces%2F{workspace_name}"
+ )
+
+ conn = AzureSynapsePipelineHook.get_connection(AZURE_SYNAPSE_CONN_ID)
+ conn_synapse_workspace_url = conn.host
+
+ # Extract the workspace_name, subscription_id and resource_group from the Synapse workspace url.
+ pipeline_run_object = AzureSynapsePipelineRunLink()
+ fields = pipeline_run_object.get_fields_from_url(workspace_url=conn_synapse_workspace_url)
+
+ assert url == (
+ EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK.format(
+ run_id=PIPELINE_RUN_RESPONSE["run_id"],
+ subscription_id=fields["subscription_id"],
+ resource_group=fields["resource_group"],
+ workspace_name=fields["workspace_name"],
+ )
+ )
diff --git a/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py b/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py
new file mode 100644
index 0000000000..be11a665c4
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models import DAG
+from airflow.providers.microsoft.azure.operators.synapse import AzureSynapseRunPipelineOperator
+
+try:
+ from airflow.operators.empty import EmptyOperator
+except ModuleNotFoundError:
+ from airflow.operators.dummy import DummyOperator as EmptyOperator  # type: ignore
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+
+with DAG(
+ dag_id="example_synapse_run_pipeline",
+ start_date=datetime(2021, 8, 13),
+ schedule="@daily",
+ catchup=False,
+ tags=["synapse", "example"],
+) as dag:
+ begin = EmptyOperator(task_id="begin")
+
+ run_pipeline1 = AzureSynapseRunPipelineOperator(
+ task_id="run_pipeline1",
+ azure_synapse_conn_id="azure_synapse_connection",
+ pipeline_name="Pipeline 1",
+ azure_synapse_workspace_dev_endpoint="azure_synapse_workspace_dev_endpoint",
+ )
+
+ begin >> run_pipeline1
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)