samhita-alla commented on a change in pull request #22643: URL: https://github.com/apache/airflow/pull/22643#discussion_r839471836
########## File path: airflow/providers/flyte/operators/flyte.py ########## @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AirflowFlyteOperator(BaseOperator): + """ + Launch Flyte executions from within Airflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AirflowFlyteOperator` + + :param flyte_conn_id: Required. The connection to Flyte setup, containing metadata. + :param project: Optional. The project to connect to. + :param domain: Optional. The domain to connect to. + :param launchplan_name: Optional. The name of the launchplan to trigger. + :param task_name: Optional. The name of the task to trigger. + :param max_parallelism: Optional. The maximum number of parallel executions to allow. + :param raw_data_prefix: Optional. The prefix to use for raw data. + :param assumable_iam_role: Optional. 
The IAM role to assume. + :param kubernetes_service_account: Optional. The Kubernetes service account to use. + :param labels: Optional. Custom labels to be applied to the execution resource. + :param annotations: Optional. Custom annotations to be applied to the execution resource. + :param version: Optional. The version of the launchplan/task to trigger. + :param inputs: Optional. The inputs to the launchplan/task. + :param timeout: Optional. The timeout to wait for the execution to finish. + :param poll_interval: Optional. The interval between checks to poll the execution. + :param asynchronous: Optional. Whether to wait for the execution to finish or not. + """ + + template_fields: Sequence[str] = ("flyte_conn_id",) # mypy fix + + def __init__( + self, + flyte_conn_id: str, + project: Optional[str] = None, + domain: Optional[str] = None, + launchplan_name: Optional[str] = None, + task_name: Optional[str] = None, + max_parallelism: Optional[int] = None, + raw_data_prefix: Optional[str] = None, + assumable_iam_role: Optional[str] = None, + kubernetes_service_account: Optional[str] = None, + labels: Dict[str, str] = {}, + annotations: Dict[str, str] = {}, + version: Optional[str] = None, + inputs: Dict[str, Any] = {}, + timeout: Optional[timedelta] = None, + poll_interval: timedelta = timedelta(seconds=30), + asynchronous: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.flyte_conn_id = flyte_conn_id + self.project = project + self.domain = domain + self.launchplan_name = launchplan_name + self.task_name = task_name + self.max_parallelism = max_parallelism + self.raw_data_prefix = raw_data_prefix + self.assumable_iam_role = assumable_iam_role + self.kubernetes_service_account = kubernetes_service_account + self.labels = labels + self.annotations = annotations + self.version = version + self.inputs = inputs + self.timeout = timeout + self.poll_interval = poll_interval + self.asynchronous = asynchronous + self.execution_name: str = "" + + 
if (not (self.task_name or self.launchplan_name)) or (self.task_name and self.launchplan_name): + raise AirflowException("Either task_name or launchplan_name is required.") + + def execute(self, context: "Context") -> str: + """Trigger an execution and wait for it to finish.""" + + # create a deterministic execution name + task_id = re.sub(r"[\W_]+", "", context["task"].task_id)[:5] + self.execution_name = task_id + re.sub( + r"[\W_]+", + "", + context["dag_run"].run_id.split("__")[-1].lower(), + )[: (20 - len(task_id))] Review comment: I'm currently generating a deterministic and unique name for each Flyte execution. It's a combination of the `task_id` and the `run_id`. I'm using both because with the `task_id` alone, I wouldn't be able to create unique execution names whenever a task runs more than once, and with the `run_id` alone, I wouldn't be able to create unique execution names within the same DAG, because the `run_id` is the same for all tasks in a DAG run; hence I came up with this logic. The execution name cannot exceed 20 characters (a restriction imposed by Flyte), so I'm truncating both strings: the sanitized `task_id` is capped at 5 characters, and the sanitized `run_id` fills the remaining space. Please let me know if there's a better way to create unique execution names that still works when the same task is repeated multiple times within the same DAG (with different `task_id`s, of course) or is run multiple times. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
