This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push: new ac943c9 [AIRFLOW-3964][AIP-17] Consolidate and de-dup sensor tasks using Smart Sensor (#5499) ac943c9 is described below commit ac943c9e18f75259d531dbda8c51e650f57faa4c Author: Yingbo Wang <ybw...@gmail.com> AuthorDate: Tue Sep 8 14:47:59 2020 -0700 [AIRFLOW-3964][AIP-17] Consolidate and de-dup sensor tasks using Smart Sensor (#5499) Co-authored-by: Yingbo Wang <yingbo.w...@airbnb.com> --- airflow/config_templates/config.yml | 33 + airflow/config_templates/default_airflow.cfg | 15 + airflow/exceptions.py | 7 + airflow/jobs/scheduler_job.py | 7 +- .../e38be357a868_update_schema_for_smart_sensor.py | 92 +++ airflow/models/__init__.py | 1 + airflow/models/baseoperator.py | 7 + airflow/models/dagbag.py | 12 +- airflow/models/sensorinstance.py | 166 +++++ airflow/models/taskinstance.py | 121 +++- .../apache/hive/sensors/metastore_partition.py | 1 + .../apache/hive/sensors/named_hive_partition.py | 10 + .../providers/elasticsearch/log/es_task_handler.py | 40 +- airflow/sensors/base_sensor_operator.py | 74 +- airflow/sensors/smart_sensor_operator.py | 764 +++++++++++++++++++++ airflow/smart_sensor_dags/__init__.py | 17 + airflow/smart_sensor_dags/smart_sensor_group.py | 63 ++ airflow/utils/file.py | 16 +- airflow/utils/log/file_processor_handler.py | 15 +- airflow/utils/log/file_task_handler.py | 22 +- airflow/utils/log/log_reader.py | 3 +- airflow/utils/state.py | 25 +- airflow/www/static/css/graph.css | 4 + airflow/www/static/css/tree.css | 4 + airflow/www/templates/airflow/ti_log.html | 28 +- docs/img/smart_sensor_architecture.png | Bin 0 -> 80325 bytes docs/img/smart_sensor_single_task_execute_flow.png | Bin 0 -> 75462 bytes docs/index.rst | 1 + docs/logging-monitoring/metrics.rst | 5 + docs/operators-and-hooks-ref.rst | 4 + docs/smart-sensor.rst | 86 +++ tests/api_connexion/endpoints/test_log_endpoint.py | 14 +- tests/jobs/test_scheduler_job.py | 22 +- tests/models/test_dagbag.py | 8 +- tests/models/test_sensorinstance.py | 46 ++ .../amazon/aws/log/test_cloudwatch_task_handler.py | 9 +- .../amazon/aws/log/test_s3_task_handler.py | 11 +- .../elasticsearch/log/test_es_task_handler.py | 20 +- .../microsoft/azure/log/test_wasb_task_handler.py | 10 +- tests/sensors/test_smart_sensor_operator.py | 326 +++++++++ tests/test_config_templates.py | 3 +- tests/utils/log/test_log_reader.py | 58 +- tests/utils/test_log_handlers.py | 2 +- tests/www/test_views.py | 22 +- 44 files changed, 2062 insertions(+), 132 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 10b8987..25b24da 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -2337,3 +2337,36 @@ to identify the task. Should be supplied in the format: ``key = value`` options: [] +- name: smart_sensor + description: ~ + options: + - name: use_smart_sensor + description: | + When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to + smart sensor task. + version_added: ~ + type: boolean + example: ~ + default: "False" + - name: shard_code_upper_limit + description: | + `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated + by `hashcode % shard_code_upper_limit`. + version_added: ~ + type: int + example: ~ + default: "10000" + - name: shards + description: | + The number of running smart sensor processes for each service. 
+ version_added: ~ + type: int + example: ~ + default: "5" + - name: sensors_enabled + description: | + comma separated sensor classes support in smart_sensor. + version_added: ~ + type: string + example: ~ + default: "NamedHivePartitionSensor" diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index d5c5262..24ba5a1 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -1179,3 +1179,18 @@ worker_resources = # The worker pods will be given these static labels, as well as some additional dynamic labels # to identify the task. # Should be supplied in the format: ``key = value`` + +[smart_sensor] +# When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to +# smart sensor task. +use_smart_sensor = False + +# `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated +# by `hashcode % shard_code_upper_limit`. +shard_code_upper_limit = 10000 + +# The number of running smart sensor processes for each service. +shards = 5 + +# comma separated sensor classes support in smart_sensor. +sensors_enabled = NamedHivePartitionSensor diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 22d20c5..0406ce7 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -63,6 +63,13 @@ class AirflowRescheduleException(AirflowException): self.reschedule_date = reschedule_date +class AirflowSmartSensorException(AirflowException): + """ + Raise after the task register itself in the smart sensor service + It should exit without failing a task + """ + + class InvalidStatsNameException(AirflowException): """Raise when name of the stats is invalid""" diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 422f5a9..cb43f8d 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -862,7 +862,7 @@ class DagFileProcessor(LoggingMixin): self.log.info("Processing file %s for tasks to queue", file_path) try: - dagbag = DagBag(file_path, include_examples=False) + dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False) except Exception: # pylint: disable=broad-except self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr('dag_file_refresh_error', 1, 1) @@ -1743,7 +1743,10 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes # NONE so we don't try to re-run it. self._change_state_for_tis_without_dagrun( simple_dag_bag=simple_dag_bag, - old_states=[State.QUEUED, State.SCHEDULED, State.UP_FOR_RESCHEDULE], + old_states=[State.QUEUED, + State.SCHEDULED, + State.UP_FOR_RESCHEDULE, + State.SENSING], new_state=State.NONE ) self._execute_task_instances(simple_dag_bag) diff --git a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py new file mode 100644 index 0000000..27227ae --- /dev/null +++ b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add sensor_instance table + +Revision ID: e38be357a868 +Revises: 939bb1e647c8 +Create Date: 2019-06-07 04:03:17.003939 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy import func +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision = 'e38be357a868' +down_revision = 'da3f683c3a5a' +branch_labels = None +depends_on = None + + +def mssql_timestamp(): # noqa: D103 + return sa.DateTime() + + +def mysql_timestamp(): # noqa: D103 + return mysql.TIMESTAMP(fsp=6) + + +def sa_timestamp(): # noqa: D103 + return sa.TIMESTAMP(timezone=True) + + +def upgrade(): # noqa: D103 + + conn = op.get_bind() + if conn.dialect.name == 'mysql': + timestamp = mysql_timestamp + elif conn.dialect.name == 'mssql': + timestamp = mssql_timestamp + else: + timestamp = sa_timestamp + + op.create_table( + 'sensor_instance', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('task_id', sa.String(length=250), nullable=False), + sa.Column('dag_id', sa.String(length=250), nullable=False), + sa.Column('execution_date', timestamp(), nullable=False), + sa.Column('state', sa.String(length=20), nullable=True), + sa.Column('try_number', sa.Integer(), nullable=True), + sa.Column('start_date', timestamp(), nullable=True), + sa.Column('operator', sa.String(length=1000), nullable=False), + sa.Column('op_classpath', sa.String(length=1000), nullable=False), + sa.Column('hashcode', sa.BigInteger(), nullable=False), + sa.Column('shardcode', sa.Integer(), nullable=False), + sa.Column('poke_context', sa.Text(), nullable=False), + sa.Column('execution_context', sa.Text(), nullable=True), + sa.Column('created_at', timestamp(), default=func.now(), nullable=False), + sa.Column('updated_at', timestamp(), default=func.now(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index( + 'ti_primary_key', + 'sensor_instance', + ['dag_id', 'task_id', 'execution_date'], + unique=True + ) + op.create_index('si_hashcode', 'sensor_instance', ['hashcode'], unique=False) + op.create_index('si_shardcode', 'sensor_instance', ['shardcode'], unique=False) + op.create_index('si_state_shard', 'sensor_instance', ['state', 'shardcode'], unique=False) + op.create_index('si_updated_at', 'sensor_instance', ['updated_at'], unique=False) + + +def downgrade(): # noqa: D103 + op.drop_table('sensor_instance') diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index 94a4b2b..4a53845 100644 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -27,6 +27,7 @@ from airflow.models.errors import ImportError # pylint: disable=redefined-built from airflow.models.log import Log from airflow.models.pool import Pool from airflow.models.renderedtifields import RenderedTaskInstanceFields +from airflow.models.sensorinstance import SensorInstance # noqa: F401 from airflow.models.skipmixin import SkipMixin from airflow.models.slamiss import SlaMiss from airflow.models.taskfail import TaskFail diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 352ebcd..27013f2 100644 --- a/airflow/models/baseoperator.py +++ 
b/airflow/models/baseoperator.py @@ -1334,6 +1334,13 @@ class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta): return cls.__serialized_fields + def is_smart_sensor_compatible(self): + """ + Return if this operator can use smart service. Default False. + + """ + return False + def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]): r""" diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index a9def70..4acc4c7 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -70,6 +70,9 @@ class DagBag(BaseDagBag, LoggingMixin): :param include_examples: whether to include the examples that ship with airflow or not :type include_examples: bool + :param include_smart_sensor: whether to include the smart sensor native + DAGs that create the smart sensor operators for whole cluster + :type include_smart_sensor: bool :param read_dags_from_db: Read DAGs from DB if store_serialized_dags is ``True``. If ``False`` DAGs are read from python files. This property is not used when determining whether or not to write Serialized DAGs, that is done by checking @@ -84,6 +87,7 @@ class DagBag(BaseDagBag, LoggingMixin): self, dag_folder: Optional[str] = None, include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'), + include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'), safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'), read_dags_from_db: bool = False, store_serialized_dags: Optional[bool] = None, @@ -113,6 +117,7 @@ class DagBag(BaseDagBag, LoggingMixin): self.collect_dags( dag_folder=dag_folder, include_examples=include_examples, + include_smart_sensor=include_smart_sensor, safe_mode=safe_mode) def size(self) -> int: @@ -391,6 +396,7 @@ class DagBag(BaseDagBag, LoggingMixin): dag_folder=None, only_if_updated=True, include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'), + include_smart_sensor=conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'), safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE')): """ Given a file path or a folder, this method looks for python modules, @@ -414,8 +420,10 @@ class DagBag(BaseDagBag, LoggingMixin): stats = [] dag_folder = correct_maybe_zipped(dag_folder) - for filepath in list_py_file_paths(dag_folder, safe_mode=safe_mode, - include_examples=include_examples): + for filepath in list_py_file_paths(dag_folder, + safe_mode=safe_mode, + include_examples=include_examples, + include_smart_sensor=include_smart_sensor): try: file_parse_start_dttm = timezone.utcnow() found_dags = self.process_file( diff --git a/airflow/models/sensorinstance.py b/airflow/models/sensorinstance.py new file mode 100644 index 0000000..88aba4f --- /dev/null +++ b/airflow/models/sensorinstance.py @@ -0,0 +1,166 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json + +from sqlalchemy import BigInteger, Column, Index, Integer, String, Text + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.models.base import ID_LEN, Base +from airflow.utils import timezone +from airflow.utils.session import provide_session +from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.state import State + + +class SensorInstance(Base): + """ + SensorInstance support the smart sensor service. It stores the sensor task states + and context that required for poking include poke context and execution context. + In sensor_instance table we also save the sensor operator classpath so that inside + smart sensor there is no need to import the dagbag and create task object for each + sensor task. + + SensorInstance include another set of columns to support the smart sensor shard on + large number of sensor instance. The key idea is to generate the hash code from the + poke context and use it to map to a shorter shard code which can be used as an index. + Every smart sensor process takes care of tasks whose `shardcode` are in a certain range. + + """ + + __tablename__ = "sensor_instance" + + id = Column(Integer, primary_key=True) + task_id = Column(String(ID_LEN), nullable=False) + dag_id = Column(String(ID_LEN), nullable=False) + execution_date = Column(UtcDateTime, nullable=False) + state = Column(String(20)) + _try_number = Column('try_number', Integer, default=0) + start_date = Column(UtcDateTime) + operator = Column(String(1000), nullable=False) + op_classpath = Column(String(1000), nullable=False) + hashcode = Column(BigInteger, nullable=False) + shardcode = Column(Integer, nullable=False) + poke_context = Column(Text, nullable=False) + execution_context = Column(Text) + created_at = Column(UtcDateTime, default=timezone.utcnow(), nullable=False) + updated_at = Column(UtcDateTime, + default=timezone.utcnow(), + onupdate=timezone.utcnow(), + nullable=False) + + __table_args__ = ( + Index('ti_primary_key', dag_id, task_id, execution_date, unique=True), + + Index('si_hashcode', hashcode), + Index('si_shardcode', shardcode), + Index('si_state_shard', state, shardcode), + Index('si_updated_at', updated_at), + ) + + def __init__(self, ti): + self.dag_id = ti.dag_id + self.task_id = ti.task_id + self.execution_date = ti.execution_date + + @staticmethod + def get_classpath(obj): + """ + Get the object dotted class path. Used for getting operator classpath. + + :param obj: + :type obj: + :return: The class path of input object + :rtype: str + """ + module_name, class_name = obj.__module__, obj.__class__.__name__ + + return module_name + "." + class_name + + @classmethod + @provide_session + def register(cls, ti, poke_context, execution_context, session=None): + """ + Register task instance ti for a sensor in sensor_instance table. Persist the + context used for a sensor and set the sensor_instance table state to sensing. + + :param ti: The task instance for the sensor to be registered. + :type: ti: + :param poke_context: Context used for sensor poke function. + :type poke_context: dict + :param execution_context: Context used for execute sensor such as timeout + setting and email configuration. + :type execution_context: dict + :param session: SQLAlchemy ORM Session + :type session: Session + :return: True if the ti was registered successfully. 
+ :rtype: Boolean + """ + if poke_context is None: + raise AirflowException('poke_context should not be None') + + encoded_poke = json.dumps(poke_context) + encoded_execution_context = json.dumps(execution_context) + + sensor = session.query(SensorInstance).filter( + SensorInstance.dag_id == ti.dag_id, + SensorInstance.task_id == ti.task_id, + SensorInstance.execution_date == ti.execution_date + ).with_for_update().first() + + if sensor is None: + sensor = SensorInstance(ti=ti) + + sensor.operator = ti.operator + sensor.op_classpath = SensorInstance.get_classpath(ti.task) + sensor.poke_context = encoded_poke + sensor.execution_context = encoded_execution_context + + sensor.hashcode = hash(encoded_poke) + sensor.shardcode = sensor.hashcode % conf.getint('smart_sensor', 'shard_code_upper_limit') + sensor.try_number = ti.try_number + + sensor.state = State.SENSING + sensor.start_date = timezone.utcnow() + session.add(sensor) + session.commit() + + return True + + @property + def try_number(self): + """ + Return the try number that this task number will be when it is actually + run. + If the TI is currently running, this will match the column in the + database, in all other cases this will be incremented. + """ + # This is designed so that task logs end up in the right file. + if self.state in State.running(): + return self._try_number + return self._try_number + 1 + + @try_number.setter + def try_number(self, value): + self._try_number = value + + def __repr__(self): + return "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} " \ + "execution_context: {self.execution_context} state: {self.state}>".format( + self=self) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index cb5a1bc..b3220ae 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -29,6 +29,7 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union from urllib.parse import quote import dill +import jinja2 import lazy_object_proxy import pendulum from jinja2 import TemplateAssertionError, UndefinedError @@ -41,7 +42,7 @@ from airflow import settings from airflow.configuration import conf from airflow.exceptions import ( AirflowException, AirflowFailException, AirflowRescheduleException, AirflowSkipException, - AirflowTaskTimeout, + AirflowSmartSensorException, AirflowTaskTimeout, ) from airflow.models.base import COLLATION_ARGS, ID_LEN, Base from airflow.models.log import Log @@ -281,7 +282,8 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 database, in all other cases this will be incremented. """ # This is designed so that task logs end up in the right file. 
- if self.state == State.RUNNING: + # TODO: whether we need sensing here or not (in sensor and task_instance state machine) + if self.state in State.running(): return self._try_number return self._try_number + 1 @@ -1072,6 +1074,9 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 self._prepare_and_execute_task_with_callbacks(context, task) self.refresh_from_db(lock_for_update=True) self.state = State.SUCCESS + except AirflowSmartSensorException as e: + self.log.info(e) + return except AirflowSkipException as e: # Recording SKIP # log only if exception has any arguments to prevent log flooding @@ -1172,6 +1177,20 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 # Run on_execute callback self._run_execute_callback(context, task) + if task_copy.is_smart_sensor_compatible(): + # Try to register it in the smart sensor service. + registered = False + try: + registered = task_copy.register_in_sensor_service(self, context) + except Exception as e: + self.log.warning("Failed to register in sensor service." + "Continue to run task in non smart sensor mode.") + self.log.exception(e, exc_info=True) + + if registered: + # Will raise AirflowSmartSensorException to avoid long running execution. + self._update_ti_state_for_sensing() + # Execute the task with set_current_context(context): result = self._execute_task(context, task_copy) @@ -1187,6 +1206,16 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 Stats.incr('operator_successes_{}'.format(self.task.__class__.__name__), 1, 1) Stats.incr('ti_successes') + @provide_session + def _update_ti_state_for_sensing(self, session=None): + self.log.info('Submitting %s to sensor service', self) + self.state = State.SENSING + self.start_date = timezone.utcnow() + session.merge(self) + session.commit() + # Raise exception for sensing state + raise AirflowSmartSensorException("Task successfully registered in smart sensor.") + def _run_success_callback(self, context, task): """Functions that need to be run if Task is successful""" # Success callback @@ -1580,19 +1609,12 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 self.task.render_template_fields(context) - def email_alert(self, exception): - """Send Email Alert with exception trace""" + def get_email_subject_content(self, exception): + """Get the email subject content for exceptions.""" + # For a ti from DB (without ti.task), return the default value + # Reuse it for smart sensor to send default email alert + use_default = not hasattr(self, 'task') exception_html = str(exception).replace('\n', '<br>') - jinja_context = self.get_template_context() - # This function is called after changing the state - # from State.RUNNING so use prev_attempted_tries. 
- jinja_context.update(dict( - exception=exception, - exception_html=exception_html, - try_number=self.prev_attempted_tries, - max_tries=self.max_tries)) - - jinja_env = self.task.get_template_env() default_subject = 'Airflow alert: {{ti}}' # For reporting purposes, we report based on 1-indexed, @@ -1607,29 +1629,62 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' ) - def render(key, content): - if conf.has_option('email', key): - path = conf.get('email', key) - with open(path) as file: - content = file.read() + default_html_content_err = ( + 'Try {{try_number}} out of {{max_tries + 1}}<br>' + 'Exception:<br>Failed attempt to attach error logs<br>' + 'Log: <a href="{{ti.log_url}}">Link</a><br>' + 'Host: {{ti.hostname}}<br>' + 'Log file: {{ti.log_filepath}}<br>' + 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' + ) + + if use_default: + jinja_context = {'ti': self} + # This function is called after changing the state + # from State.RUNNING so need to subtract 1 from self.try_number. + jinja_context.update(dict( + exception=exception, + exception_html=exception_html, + try_number=self.try_number - 1, + max_tries=self.max_tries)) + + jinja_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), + autoescape=True) + subject = jinja_env.from_string(default_subject).render(**jinja_context) + html_content = jinja_env.from_string(default_html_content).render(**jinja_context) + html_content_err = jinja_env.from_string(default_html_content_err).render(**jinja_context) + + else: + jinja_context = self.get_template_context() + + jinja_context.update(dict( + exception=exception, + exception_html=exception_html, + try_number=self.try_number - 1, + max_tries=self.max_tries)) - return jinja_env.from_string(content).render(**jinja_context) + jinja_env = self.task.get_template_env() - subject = render('subject_template', default_subject) - html_content = render('html_content_template', default_html_content) - # noinspection PyBroadException + def render(key, content): + if conf.has_option('email', key): + path = conf.get('email', key) + with open(path) as f: + content = f.read() + return jinja_env.from_string(content).render(**jinja_context) + + subject = render('subject_template', default_subject) + html_content = render('html_content_template', default_html_content) + html_content_err = render('html_content_template', default_html_content_err) + + return subject, html_content, html_content_err + + def email_alert(self, exception): + """Send alert email with exception information.""" + subject, html_content, html_content_err = self.get_email_subject_content(exception) try: send_email(self.task.email, subject, html_content) - except Exception: # pylint: disable=broad-except - default_html_content_err = ( - 'Try {{try_number}} out of {{max_tries + 1}}<br>' - 'Exception:<br>Failed attempt to attach error logs<br>' - 'Log: <a href="{{ti.log_url}}">Link</a><br>' - 'Host: {{ti.hostname}}<br>' - 'Log file: {{ti.log_filepath}}<br>' - 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' - ) - html_content_err = render('html_content_template', default_html_content_err) + except Exception: send_email(self.task.email, subject, html_content_err) def set_duration(self) -> None: diff --git a/airflow/providers/apache/hive/sensors/metastore_partition.py b/airflow/providers/apache/hive/sensors/metastore_partition.py index 31376ad..0955291 100644 --- 
a/airflow/providers/apache/hive/sensors/metastore_partition.py +++ b/airflow/providers/apache/hive/sensors/metastore_partition.py @@ -44,6 +44,7 @@ class MetastorePartitionSensor(SqlSensor): template_fields = ('partition_name', 'table', 'schema') ui_color = '#8da7be' + poke_context_fields = ('partition_name', 'table', 'schema', 'mysql_conn_id') @apply_defaults def __init__( diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py index 23d9466..b8cb600 100644 --- a/airflow/providers/apache/hive/sensors/named_hive_partition.py +++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py @@ -40,6 +40,7 @@ class NamedHivePartitionSensor(BaseSensorOperator): template_fields = ('partition_names',) ui_color = '#8d99ae' + poke_context_fields = ('partition_names', 'metastore_conn_id') @apply_defaults def __init__( @@ -104,3 +105,12 @@ class NamedHivePartitionSensor(BaseSensorOperator): self.next_index_to_poke = 0 return True + + def is_smart_sensor_compatible(self): + result = ( + not self.soft_fail + and not self.hook + and len(self.partition_names) <= 30 + and super().is_smart_sensor_compatible() + ) + return result diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py index 93f9a1b..d3d07c1 100644 --- a/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/airflow/providers/elasticsearch/log/es_task_handler.py @@ -18,9 +18,10 @@ import logging import sys +from collections import defaultdict from datetime import datetime from time import time -from typing import Optional, Tuple +from typing import List, Optional, Tuple from urllib.parse import quote # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test. @@ -36,6 +37,9 @@ from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.json_formatter import JSONFormatter from airflow.utils.log.logging_mixin import LoggingMixin +# Elasticsearch hosted log type +EsLogMsgType = List[Tuple[str, str]] + class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin): """ @@ -118,7 +122,24 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin): """ return execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f") - def _read(self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None) -> Tuple[str, dict]: + @staticmethod + def _group_logs_by_host(logs): + grouped_logs = defaultdict(list) + for log in logs: + key = getattr(log, 'host', 'default_host') + grouped_logs[key].append(log) + + # return items sorted by timestamp. + result = sorted(grouped_logs.items(), key=lambda kv: getattr(kv[1][0], 'message', '_')) + + return result + + def _read_grouped_logs(self): + return True + + def _read( + self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None + ) -> Tuple[EsLogMsgType, dict]: """ Endpoint for streaming log. @@ -126,7 +147,7 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin): :param try_number: try_number of the task instance :param metadata: log metadata, can be used for steaming log reading and auto-tailing. - :return: a list of log documents and metadata. + :return: a list of tuple with host and log documents, metadata. 
""" if not metadata: metadata = {'offset': 0} @@ -137,6 +158,7 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin): log_id = self._render_log_id(ti, try_number) logs = self.es_read(log_id, offset, metadata) + logs_by_host = self._group_logs_by_host(logs) next_offset = offset if not logs else logs[-1].offset @@ -147,7 +169,10 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin): # end_of_log_mark may contain characters like '\n' which is needed to # have the log uploaded but will not be stored in elasticsearch. - metadata['end_of_log'] = False if not logs else logs[-1].message == self.end_of_log_mark.strip() + loading_hosts = [ + item[0] for item in logs_by_host if item[-1][-1].message != self.end_of_log_mark.strip() + ] + metadata['end_of_log'] = False if not logs else len(loading_hosts) == 0 cur_ts = pendulum.now() # Assume end of log after not receiving new log for 5 min, @@ -167,8 +192,11 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin): # If we hit the end of the log, remove the actual end_of_log message # to prevent it from showing in the UI. - i = len(logs) if not metadata['end_of_log'] else len(logs) - 1 - message = '\n'.join([log.message for log in logs[0:i]]) + def concat_logs(lines): + log_range = (len(lines) - 1) if lines[-1].message == self.end_of_log_mark.strip() else len(lines) + return '\n'.join([lines[i].message for i in range(log_range)]) + + message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host] return message, metadata diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py index c4bfd43..528c04c 100644 --- a/airflow/sensors/base_sensor_operator.py +++ b/airflow/sensors/base_sensor_operator.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. +import datetime import hashlib import os from datetime import timedelta @@ -26,7 +27,7 @@ from airflow.configuration import conf from airflow.exceptions import ( AirflowException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, ) -from airflow.models import BaseOperator +from airflow.models import BaseOperator, SensorInstance from airflow.models.skipmixin import SkipMixin from airflow.models.taskreschedule import TaskReschedule from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep @@ -68,6 +69,15 @@ class BaseSensorOperator(BaseOperator, SkipMixin): ui_color = '#e6f1f2' # type: str valid_modes = ['poke', 'reschedule'] # type: Iterable[str] + # As the poke context in smart sensor defines the poking job signature only, + # The execution_fields defines other execution details + # for this tasks such as the customer defined timeout, the email and the alert + # setup. Smart sensor serialize these attributes into a different DB column so + # that smart sensor service is able to handle corresponding execution details + # without breaking the sensor poking logic with dedup. 
+ execution_fields = ('poke_interval', 'retries', 'execution_timeout', 'timeout', + 'email', 'email_on_retry', 'email_on_failure',) + @apply_defaults def __init__(self, *, poke_interval: float = 60, @@ -83,6 +93,9 @@ class BaseSensorOperator(BaseOperator, SkipMixin): self.mode = mode self.exponential_backoff = exponential_backoff self._validate_input_values() + self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor') + self.sensors_support_sensor_service = set( + map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(','))) def _validate_input_values(self) -> None: if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: @@ -106,6 +119,65 @@ class BaseSensorOperator(BaseOperator, SkipMixin): """ raise AirflowException('Override me.') + def is_smart_sensor_compatible(self): + check_list = [not self.sensor_service_enabled, + self.on_success_callback, + self.on_retry_callback, + self.on_failure_callback] + for status in check_list: + if status: + return False + + operator = self.__class__.__name__ + return operator in self.sensors_support_sensor_service + + def register_in_sensor_service(self, ti, context): + """ + Register ti in smart sensor service + + :param ti: Task instance object. + :param context: TaskInstance template context from the ti. + :return: boolean + """ + poke_context = self.get_poke_context(context) + execution_context = self.get_execution_context(context) + + return SensorInstance.register(ti, poke_context, execution_context) + + def get_poke_context(self, context): + """ + Return a dictionary with all attributes in poke_context_fields. The + poke_context with operator class can be used to identify a unique + sensor job. + + :param context: TaskInstance template context. + :return: A dictionary with key in poke_context_fields. + """ + if not context: + self.log.info("Function get_poke_context doesn't have a context input.") + + poke_context_fields = getattr(self.__class__, "poke_context_fields", None) + result = {key: getattr(self, key, None) for key in poke_context_fields} + return result + + def get_execution_context(self, context): + """ + Return a dictionary with all attributes in execution_fields. The + execution_context include execution requirement for each sensor task + such as timeout setup, email_alert setup. + + :param context: TaskInstance template context. + :return: A dictionary with key in execution_fields. + """ + if not context: + self.log.info("Function get_execution_context doesn't have a context input.") + execution_fields = self.__class__.execution_fields + + result = {key: getattr(self, key, None) for key in execution_fields} + if result['execution_timeout'] and isinstance(result['execution_timeout'], datetime.timedelta): + result['execution_timeout'] = result['execution_timeout'].total_seconds() + return result + def execute(self, context: Dict) -> Any: started_at = timezone.utcnow() try_number = 1 diff --git a/airflow/sensors/smart_sensor_operator.py b/airflow/sensors/smart_sensor_operator.py new file mode 100644 index 0000000..d6b6df2 --- /dev/null +++ b/airflow/sensors/smart_sensor_operator.py @@ -0,0 +1,764 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import datetime +import json +import logging +import time +import traceback +from logging.config import DictConfigurator # type: ignore +from time import sleep + +from sqlalchemy import and_, or_, tuple_ + +from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.models import BaseOperator, SensorInstance, SkipMixin, TaskInstance +from airflow.settings import LOGGING_CLASS_PATH +from airflow.stats import Stats +from airflow.utils import helpers, timezone +from airflow.utils.decorators import apply_defaults +from airflow.utils.email import send_email +from airflow.utils.log.logging_mixin import set_context +from airflow.utils.module_loading import import_string +from airflow.utils.net import get_hostname +from airflow.utils.session import provide_session +from airflow.utils.state import PokeState, State +from airflow.utils.timeout import timeout + +config = import_string(LOGGING_CLASS_PATH) +handler_config = config['handlers']['task'] +try: + formatter_config = config['formatters'][handler_config['formatter']] +except Exception as err: # pylint: disable=broad-except + formatter_config = None + print(err) +dictConfigurator = DictConfigurator(config) + + +class SensorWork: + """ + This class stores a sensor work with decoded context value. It is only used + inside of smart sensor. Create a sensor work based on sensor instance record. + A sensor work object has the following attributes: + `dag_id`: sensor_instance dag_id. + `task_id`: sensor_instance task_id. + `execution_date`: sensor_instance execution_date. + `try_number`: sensor_instance try_number + `poke_context`: Decoded poke_context for the sensor task. + `execution_context`: Decoded execution_context. + `hashcode`: This is the signature of poking job. + `operator`: The sensor operator class. + `op_classpath`: The sensor operator class path + `encoded_poke_context`: The raw data from sensor_instance poke_context column. + `log`: The sensor work logger which will mock the corresponding task instance log. + + :param si: The sensor_instance ORM object. 
+ """ + def __init__(self, si): + self.dag_id = si.dag_id + self.task_id = si.task_id + self.execution_date = si.execution_date + self.try_number = si.try_number + + self.poke_context = json.loads(si.poke_context) if si.poke_context else {} + self.execution_context = json.loads(si.execution_context) if si.execution_context else {} + try: + self.log = self._get_sensor_logger(si) + except Exception as e: # pylint: disable=broad-except + self.log = None + print(e) + self.hashcode = si.hashcode + self.start_date = si.start_date + self.operator = si.operator + self.op_classpath = si.op_classpath + self.encoded_poke_context = si.poke_context + + def __eq__(self, other): + if not isinstance(other, SensorWork): + return NotImplemented + + return self.dag_id == other.dag_id and \ + self.task_id == other.task_id and \ + self.execution_date == other.execution_date and \ + self.try_number == other.try_number + + @staticmethod + def create_new_task_handler(): + """ + Create task log handler for a sensor work. + :return: log handler + """ + handler_config_copy = {k: handler_config[k] for k in handler_config} + formatter_config_copy = {k: formatter_config[k] for k in formatter_config} + handler = dictConfigurator.configure_handler(handler_config_copy) + formatter = dictConfigurator.configure_formatter(formatter_config_copy) + handler.setFormatter(formatter) + return handler + + def _get_sensor_logger(self, si): + """ + Return logger for a sensor instance object. + """ + # The created log_id is used inside of smart sensor as the key to fetch + # the corresponding in memory log handler. + si.raw = False # Otherwise set_context will fail + log_id = "-".join([si.dag_id, + si.task_id, + si.execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f"), + str(si.try_number)]) + logger = logging.getLogger('airflow.task' + '.' + log_id) + + if len(logger.handlers) == 0: + handler = self.create_new_task_handler() + logger.addHandler(handler) + set_context(logger, si) + + line_break = ("-" * 120) + logger.info(line_break) + logger.info("Processing sensor task %s in smart sensor service on host: %s", + self.ti_key, get_hostname()) + logger.info(line_break) + return logger + + def close_sensor_logger(self): + """ + Close log handler for a sensor work. + """ + for handler in self.log.handlers: + try: + handler.close() + except Exception as e: # pylint: disable=broad-except + print(e) + + @property + def ti_key(self): + """ + Key for the task instance that maps to the sensor work. + """ + return self.dag_id, self.task_id, self.execution_date + + @property + def cache_key(self): + """ + Key used to query in smart sensor for cached sensor work. + """ + return self.operator, self.encoded_poke_context + + +class CachedPokeWork: + """ + Wrapper class for the poke work inside smart sensor. It saves + the sensor_task used to poke and recent poke result state. + state: poke state. + sensor_task: The cached object for executing the poke function. + last_poke_time: The latest time this cached work being called. + to_flush: If we should flush the cached work. + """ + def __init__(self): + self.state = None + self.sensor_task = None + self.last_poke_time = None + self.to_flush = False + + def set_state(self, state): + """ + Set state for cached poke work. + :param state: The sensor_instance state. + """ + self.state = state + self.last_poke_time = timezone.utcnow() + + def clear_state(self): + """ + Clear state for cached poke work. 
+ """ + self.state = None + + def set_to_flush(self): + """ + Mark this poke work to be popped from cached dict after current loop. + """ + self.to_flush = True + + def is_expired(self): + """ + The cached task object expires if there is no poke for 20 minutes. + :return: Boolean + """ + return self.to_flush or (timezone.utcnow() - self.last_poke_time).total_seconds() > 1200 + + +class SensorExceptionInfo: + """ + Hold sensor exception information and the type of exception. For possible transient + infra failure, give the task more chance to retry before fail it. + """ + def __init__(self, + exception_info, + is_infra_failure=False, + infra_failure_retry_window=datetime.timedelta(minutes=130)): + self._exception_info = exception_info + self._is_infra_failure = is_infra_failure + self._infra_failure_retry_window = infra_failure_retry_window + + self._infra_failure_timeout = None + self.set_infra_failure_timeout() + self.fail_current_run = self.should_fail_current_run() + + def set_latest_exception(self, exception_info, is_infra_failure=False): + """ + This function set the latest exception information for sensor exception. If the exception + implies an infra failure, this function will check the recorded infra failure timeout + which was set at the first infra failure exception arrives. There is a 6 hours window + for retry without failing current run. + + :param exception_info: Details of the exception information. + :param is_infra_failure: If current exception was caused by transient infra failure. + There is a retry window _infra_failure_retry_window that the smart sensor will + retry poke function without failing current task run. + """ + self._exception_info = exception_info + self._is_infra_failure = is_infra_failure + + self.set_infra_failure_timeout() + self.fail_current_run = self.should_fail_current_run() + + def set_infra_failure_timeout(self): + """ + Set the time point when the sensor should be failed if it kept getting infra + failure. + :return: + """ + # Only set the infra_failure_timeout if there is no existing one + if not self._is_infra_failure: + self._infra_failure_timeout = None + elif self._infra_failure_timeout is None: + self._infra_failure_timeout = timezone.utcnow() + self._infra_failure_retry_window + + def should_fail_current_run(self): + """ + :return: Should the sensor fail + :type: boolean + """ + return not self.is_infra_failure or timezone.utcnow() > self._infra_failure_timeout + + @property + def exception_info(self): + """ + :return: exception msg. + """ + return self._exception_info + + @property + def is_infra_failure(self): + """ + + :return: If the exception is an infra failure + :type: boolean + """ + return self._is_infra_failure + + def is_expired(self): + """ + :return: If current exception need to be kept. + :type: boolean + """ + if not self._is_infra_failure: + return True + return timezone.utcnow() > self._infra_failure_timeout + datetime.timedelta(minutes=30) + + +class SmartSensorOperator(BaseOperator, SkipMixin): + """ + Smart sensor operators are derived from this class. + + Smart Sensor operators keep refresh a dictionary by visiting DB. + Taking qualified active sensor tasks. Different from sensor operator, + Smart sensor operators poke for all sensor tasks in the dictionary at + a time interval. 
When a criteria is met or fail by time out, it update + all sensor task state in task_instance table + + :param soft_fail: Set to true to mark the task as SKIPPED on failure + :type soft_fail: bool + :param poke_interval: Time in seconds that the job should wait in + between each tries. + :type poke_interval: int + :param smart_sensor_timeout: Time, in seconds before the internal sensor + job times out if poke_timeout is not defined. + :type smart_sensor_timeout: int + :param shard_min: shard code lower bound (inclusive) + :type shard_min: int + :param shard_max: shard code upper bound (exclusive) + :type shard_max: int + :param poke_exception_cache_ttl: Time, in seconds before the current + exception expires and being cleaned up. + :type poke_exception_cache_ttl: int + :param poke_timeout: Time, in seconds before the task times out and fails. + :type poke_timeout: int + """ + ui_color = '#e6f1f2' + + @apply_defaults + def __init__(self, + poke_interval=180, + smart_sensor_timeout=60 * 60 * 24 * 7, + soft_fail=False, + shard_min=0, + shard_max=100000, + poke_exception_cache_ttl=600, + poke_timeout=6, + *args, + **kwargs): + super().__init__(*args, **kwargs) + # super(SmartSensorOperator, self).__init__(*args, **kwargs) + self.poke_interval = poke_interval + self.soft_fail = soft_fail + self.timeout = smart_sensor_timeout + self._validate_input_values() + self.hostname = "" + + self.sensor_works = [] + self.cached_dedup_works = {} + self.cached_sensor_exceptions = {} + + self.max_tis_per_query = 50 + self.shard_min = shard_min + self.shard_max = shard_max + self.poke_exception_cache_ttl = poke_exception_cache_ttl + self.poke_timeout = poke_timeout + + def _validate_input_values(self): + if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: + raise AirflowException( + "The poke_interval must be a non-negative number") + if not isinstance(self.timeout, (int, float)) or self.timeout < 0: + raise AirflowException( + "The timeout must be a non-negative number") + + @provide_session + def _load_sensor_works(self, session=None): + """ + Refresh sensor instances need to be handled by this operator. Create smart sensor + internal object based on the information persisted in the sensor_instance table. + + """ + SI = SensorInstance + start_query_time = time.time() + query = session.query(SI) \ + .filter(SI.state == State.SENSING)\ + .filter(SI.shardcode < self.shard_max, + SI.shardcode >= self.shard_min) + tis = query.all() + + self.log.info("Performance query %s tis, time: %s", len(tis), time.time() - start_query_time) + + # Query without checking dagrun state might keep some failed dag_run tasks alive. + # Join with DagRun table will be very slow based on the number of sensor tasks we + # need to handle. We query all smart tasks in this operator + # and expect scheduler correct the states in _change_state_for_tis_without_dagrun() + + sensor_works = [] + for ti in tis: + try: + sensor_works.append(SensorWork(ti)) + except Exception as e: # pylint: disable=broad-except + self.log.exception("Exception at creating sensor work for ti %s", ti.key) + self.log.exception(e, exc_info=True) + + self.log.info("%d tasks detected.", len(sensor_works)) + + new_sensor_works = [x for x in sensor_works if x not in self.sensor_works] + + self._update_ti_hostname(new_sensor_works) + + self.sensor_works = sensor_works + + @provide_session + def _update_ti_hostname(self, sensor_works, session=None): + """ + Update task instance hostname for new sensor works. 
+ + :param sensor_works: Smart sensor internal object for a sensor task. + :param session: The sqlalchemy session. + """ + TI = TaskInstance + ti_keys = [(x.dag_id, x.task_id, x.execution_date) for x in sensor_works] + + def update_ti_hostname_with_count(count, ti_keys): + # Using or_ instead of in_ here to prevent from full table scan. + tis = session.query(TI) \ + .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key + for ti_key in ti_keys)) \ + .all() + + for ti in tis: + ti.hostname = self.hostname + session.commit() + + return count + len(ti_keys) + + count = helpers.reduce_in_chunks(update_ti_hostname_with_count, ti_keys, 0, self.max_tis_per_query) + if count: + self.log.info("Updated hostname on %s tis.", count) + + @provide_session + def _mark_multi_state(self, operator, poke_hash, encoded_poke_context, state, session=None): + """ + Mark state for multiple tasks in the task_instance table to a new state if they have + the same signature as the poke_hash. + + :param operator: The sensor's operator class name. + :param poke_hash: The hash code generated from sensor's poke context. + :param encoded_poke_context: The raw encoded poke_context. + :param state: Set multiple sensor tasks to this state. + :param session: The sqlalchemy session. + """ + def mark_state(ti, sensor_instance): + ti.state = state + sensor_instance.state = state + if state in State.finished(): + ti.end_date = end_date + ti.set_duration() + + SI = SensorInstance + TI = TaskInstance + + count_marked = 0 + try: + query_result = session.query(TI, SI)\ + .join(TI, and_(TI.dag_id == SI.dag_id, + TI.task_id == SI.task_id, + TI.execution_date == SI.execution_date)) \ + .filter(SI.state == State.SENSING) \ + .filter(SI.hashcode == poke_hash) \ + .filter(SI.operator == operator) \ + .with_for_update().all() + + end_date = timezone.utcnow() + for ti, sensor_instance in query_result: + if sensor_instance.poke_context != encoded_poke_context: + continue + + ti.hostname = self.hostname + if ti.state == State.SENSING: + mark_state(ti=ti, sensor_instance=sensor_instance) + count_marked += 1 + else: + # ti.state != State.SENSING + sensor_instance.state = ti.state + + session.commit() + + except Exception as e: # pylint: disable=broad-except + self.log.warning("Exception _mark_multi_state in smart sensor for hashcode %s", + str(poke_hash)) + self.log.exception(e, exc_info=True) + self.log.info("Marked %s tasks out of %s to state %s", count_marked, len(query_result), state) + + @provide_session + def _retry_or_fail_task(self, sensor_work, error, session=None): + """ + Change single task state for sensor task. For final state, set the end_date. + Since smart sensor take care all retries in one process. Failed sensor tasks + logically experienced all retries and the try_number should be set to max_tries. + + :param sensor_work: The sensor_work with exception. + :type sensor_work: SensorWork + :param error: The error message for this sensor_work. + :type error: str. + :param session: The sqlalchemy session. 
+ """ + def email_alert(task_instance, error_info): + try: + subject, html_content, _ = task_instance.get_email_subject_content(error_info) + email = sensor_work.execution_context.get('email') + + send_email(email, subject, html_content) + except Exception as e: # pylint: disable=broad-except + sensor_work.log.warning("Exception alerting email.") + sensor_work.log.exception(e, exc_info=True) + + def handle_failure(sensor_work, ti): + if sensor_work.execution_context.get('retries', None) and \ + ti.try_number <= ti.max_tries: + # retry + ti.state = State.UP_FOR_RETRY + if sensor_work.execution_context.get('email_on_retry', None) and \ + sensor_work.execution_context.get('email', None): + sensor_work.log.info("%s sending email alert for retry", sensor_work.ti_key) + email_alert(ti, error) + else: + ti.state = State.FAILED + if sensor_work.execution_context.get('email_on_failure', None) and \ + sensor_work.execution_context.get('email', None): + sensor_work.log.info("%s sending email alert for failure", sensor_work.ti_key) + email_alert(ti, error) + + try: + dag_id, task_id, execution_date = sensor_work.ti_key + TI = TaskInstance + SI = SensorInstance + sensor_instance = session.query(SI).filter( + SI.dag_id == dag_id, + SI.task_id == task_id, + SI.execution_date == execution_date) \ + .with_for_update() \ + .first() + + if sensor_instance.hashcode != sensor_work.hashcode: + # Return without setting state + return + + ti = session.query(TI).filter( + TI.dag_id == dag_id, + TI.task_id == task_id, + TI.execution_date == execution_date) \ + .with_for_update() \ + .first() + + if ti: + if ti.state == State.SENSING: + ti.hostname = self.hostname + handle_failure(sensor_work, ti) + + sensor_instance.state = State.FAILED + ti.end_date = timezone.utcnow() + ti.set_duration() + else: + sensor_instance.state = ti.state + session.merge(sensor_instance) + session.merge(ti) + session.commit() + + sensor_work.log.info("Task %s got an error: %s. Set the state to failed. Exit.", + str(sensor_work.ti_key), error) + sensor_work.close_sensor_logger() + + except AirflowException as e: + sensor_work.log.warning("Exception on failing %s", sensor_work.ti_key) + sensor_work.log.exception(e, exc_info=True) + + def _check_and_handle_ti_timeout(self, sensor_work): + """ + Check if a sensor task in smart sensor is timeout. Could be either sensor operator timeout + or general operator execution_timeout. + + :param sensor_work: SensorWork + """ + task_timeout = sensor_work.execution_context.get('timeout', self.timeout) + task_execution_timeout = sensor_work.execution_context.get('execution_timeout', None) + if task_execution_timeout: + task_timeout = min(task_timeout, task_execution_timeout) + + if (timezone.utcnow() - sensor_work.start_date).total_seconds() > task_timeout: + error = "Sensor Timeout" + sensor_work.log.exception(error) + self._retry_or_fail_task(sensor_work, error) + + def _handle_poke_exception(self, sensor_work): + """ + Fail task if accumulated exceptions exceeds retries. + + :param sensor_work: SensorWork + """ + sensor_exception = self.cached_sensor_exceptions.get(sensor_work.cache_key) + error = sensor_exception.exception_info + sensor_work.log.exception("Handling poke exception: %s", error) + + if sensor_exception.fail_current_run: + if sensor_exception.is_infra_failure: + sensor_work.log.exception("Task %s failed by infra failure in smart sensor.", + sensor_work.ti_key) + # There is a risk for sensor object cached in smart sensor keep throwing + # exception and cause an infra failure. 
To make sure the sensor tasks after + # retry will not fall into same object and have endless infra failure, + # we mark the sensor task after an infra failure so that it can be popped + # before next poke loop. + cache_key = sensor_work.cache_key + self.cached_dedup_works[cache_key].set_to_flush() + else: + sensor_work.log.exception("Task %s failed by exceptions.", sensor_work.ti_key) + self._retry_or_fail_task(sensor_work, error) + else: + sensor_work.log.info("Exception detected, retrying without failing current run.") + self._check_and_handle_ti_timeout(sensor_work) + + def _process_sensor_work_with_cached_state(self, sensor_work, state): + if state == PokeState.LANDED: + sensor_work.log.info("Task %s succeeded", str(sensor_work.ti_key)) + sensor_work.close_sensor_logger() + + if state == PokeState.NOT_LANDED: + # Handle timeout if connection valid but not landed yet + self._check_and_handle_ti_timeout(sensor_work) + elif state == PokeState.POKE_EXCEPTION: + self._handle_poke_exception(sensor_work) + + def _execute_sensor_work(self, sensor_work): + ti_key = sensor_work.ti_key + log = sensor_work.log or self.log + log.info("Sensing ti: %s", str(ti_key)) + log.info("Poking with arguments: %s", sensor_work.encoded_poke_context) + + cache_key = sensor_work.cache_key + if cache_key not in self.cached_dedup_works: + # create an empty cached_work for a new cache_key + self.cached_dedup_works[cache_key] = CachedPokeWork() + + cached_work = self.cached_dedup_works[cache_key] + + if cached_work.state is not None: + # Have a valid cached state, don't poke twice in certain time interval + self._process_sensor_work_with_cached_state(sensor_work, cached_work.state) + return + + try: + with timeout(seconds=self.poke_timeout): + if self.poke(sensor_work): + # Got a landed signal, mark all tasks waiting for this partition + cached_work.set_state(PokeState.LANDED) + + self._mark_multi_state(sensor_work.operator, + sensor_work.hashcode, + sensor_work.encoded_poke_context, + State.SUCCESS) + + log.info("Task %s succeeded", str(ti_key)) + sensor_work.close_sensor_logger() + else: + # Not landed yet. Handle possible timeout + cached_work.set_state(PokeState.NOT_LANDED) + self._check_and_handle_ti_timeout(sensor_work) + + self.cached_sensor_exceptions.pop(cache_key, None) + except Exception as e: # pylint: disable=broad-except + # The retry_infra_failure decorator inside hive_hooks will raise exception with + # is_infra_failure == True. Long poking timeout here is also considered an infra + # failure. Other exceptions should fail. + is_infra_failure = getattr(e, 'is_infra_failure', False) or isinstance(e, AirflowTaskTimeout) + exception_info = traceback.format_exc() + cached_work.set_state(PokeState.POKE_EXCEPTION) + + if cache_key in self.cached_sensor_exceptions: + self.cached_sensor_exceptions[cache_key].set_latest_exception( + exception_info, + is_infra_failure=is_infra_failure) + else: + self.cached_sensor_exceptions[cache_key] = SensorExceptionInfo( + exception_info, + is_infra_failure=is_infra_failure) + + self._handle_poke_exception(sensor_work) + + def flush_cached_sensor_poke_results(self): + """ + Flush outdated cached sensor states saved in previous loop. 
+ + """ + for key, cached_work in self.cached_dedup_works.items(): + if cached_work.is_expired(): + self.cached_dedup_works.pop(key, None) + else: + cached_work.state = None + + for ti_key, sensor_exception in self.cached_sensor_exceptions.items(): + if sensor_exception.fail_current_run or sensor_exception.is_expired(): + self.cached_sensor_exceptions.pop(ti_key, None) + + def poke(self, sensor_work): + """ + Function that the sensors defined while deriving this class should + override. + + """ + cached_work = self.cached_dedup_works[sensor_work.cache_key] + if not cached_work.sensor_task: + init_args = dict(list(sensor_work.poke_context.items()) + + [('task_id', sensor_work.task_id)]) + operator_class = import_string(sensor_work.op_classpath) + cached_work.sensor_task = operator_class(**init_args) + + return cached_work.sensor_task.poke(sensor_work.poke_context) + + def _emit_loop_stats(self): + try: + count_poke = 0 + count_poke_success = 0 + count_poke_exception = 0 + count_exception_failures = 0 + count_infra_failure = 0 + for cached_work in self.cached_dedup_works.values(): + if cached_work.state is None: + continue + count_poke += 1 + if cached_work.state == PokeState.LANDED: + count_poke_success += 1 + elif cached_work.state == PokeState.POKE_EXCEPTION: + count_poke_exception += 1 + for cached_exception in self.cached_sensor_exceptions.values(): + if cached_exception.is_infra_failure and cached_exception.fail_current_run: + count_infra_failure += 1 + if cached_exception.fail_current_run: + count_exception_failures += 1 + + Stats.gauge("smart_sensor_operator.poked_tasks", count_poke) + Stats.gauge("smart_sensor_operator.poked_success", count_poke_success) + Stats.gauge("smart_sensor_operator.poked_exception", count_poke_exception) + Stats.gauge("smart_sensor_operator.exception_failures", count_exception_failures) + Stats.gauge("smart_sensor_operator.infra_failures", count_infra_failure) + except Exception as e: # pylint: disable=broad-except + self.log.exception("Exception at getting loop stats %s") + self.log.exception(e, exc_info=True) + + def execute(self, context): + started_at = timezone.utcnow() + + self.hostname = get_hostname() + while True: + poke_start_time = timezone.utcnow() + + self.flush_cached_sensor_poke_results() + + self._load_sensor_works() + self.log.info("Loaded %s sensor_works", len(self.sensor_works)) + Stats.gauge("smart_sensor_operator.loaded_tasks", len(self.sensor_works)) + + for sensor_work in self.sensor_works: + self._execute_sensor_work(sensor_work) + + duration = (timezone.utcnow() - poke_start_time).total_seconds() + + self.log.info("Taking %s to execute %s tasks.", duration, len(self.sensor_works)) + + Stats.timing("smart_sensor_operator.loop_duration", duration) + Stats.gauge("smart_sensor_operator.executed_tasks", len(self.sensor_works)) + self._emit_loop_stats() + + if duration < self.poke_interval: + sleep(self.poke_interval - duration) + if (timezone.utcnow() - started_at).total_seconds() > self.timeout: + self.log.info("Time is out for smart senosr.") + return + + def on_kill(self): + pass + + +if __name__ == '__main__': + SmartSensorOperator(task_id='test').execute({}) diff --git a/airflow/smart_sensor_dags/__init__.py b/airflow/smart_sensor_dags/__init__.py new file mode 100644 index 0000000..217e5db --- /dev/null +++ b/airflow/smart_sensor_dags/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/smart_sensor_dags/smart_sensor_group.py b/airflow/smart_sensor_dags/smart_sensor_group.py new file mode 100644 index 0000000..0187fa9 --- /dev/null +++ b/airflow/smart_sensor_dags/smart_sensor_group.py @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Smart sensor DAGs managing all smart sensor tasks.""" +from builtins import range +from datetime import timedelta + +from airflow.configuration import conf +from airflow.models import DAG +from airflow.sensors.smart_sensor_operator import SmartSensorOperator +from airflow.utils.dates import days_ago + +args = { + 'owner': 'airflow', +} + +num_smart_sensor_shard = conf.getint("smart_sensor", "shards") +shard_code_upper_limit = conf.getint('smart_sensor', 'shard_code_upper_limit') + +for i in range(num_smart_sensor_shard): + shard_min = (i * shard_code_upper_limit) / num_smart_sensor_shard + shard_max = ((i + 1) * shard_code_upper_limit) / num_smart_sensor_shard + + dag_id = 'smart_sensor_group_shard_{}'.format(i) + dag = DAG( + dag_id=dag_id, + default_args=args, + schedule_interval=timedelta(minutes=5), + concurrency=1, + max_active_runs=1, + catchup=False, + dagrun_timeout=timedelta(hours=24), + start_date=days_ago(2), + ) + + SmartSensorOperator( + task_id='smart_sensor_task', + dag=dag, + retries=100, + retry_delay=timedelta(seconds=10), + priority_weight=999, + shard_min=shard_min, + shard_max=shard_max, + poke_timeout=10, + smart_sensor_timeout=timedelta(hours=24).total_seconds(), + ) + + globals()[dag_id] = dag diff --git a/airflow/utils/file.py b/airflow/utils/file.py index d7e32ce..f2a5b9f 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -132,7 +132,9 @@ def find_path_from_directory( def list_py_file_paths(directory: str, safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True), - include_examples: Optional[bool] = None): + include_examples: Optional[bool] = None, + include_smart_sensor: Optional[bool] = + conf.getboolean('smart_sensor', 'use_smart_sensor')): """ Traverse a directory and look for Python files. 
@@ -145,6 +147,8 @@ def list_py_file_paths(directory: str, :type safe_mode: bool :param include_examples: include example DAGs :type include_examples: bool + :param include_smart_sensor: include smart sensor native control DAGs + :type include_smart_sensor: bool :return: a list of paths to Python files in the specified directory :rtype: list[unicode] """ @@ -152,15 +156,19 @@ def list_py_file_paths(directory: str, include_examples = conf.getboolean('core', 'LOAD_EXAMPLES') file_paths: List[str] = [] if directory is None: - return [] + file_paths = [] elif os.path.isfile(directory): - return [directory] + file_paths = [directory] elif os.path.isdir(directory): find_dag_file_paths(directory, file_paths, safe_mode) if include_examples: from airflow import example_dags example_dag_folder = example_dags.__path__[0] # type: ignore - file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False)) + file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False, False)) + if include_smart_sensor: + from airflow import smart_sensor_dags + smart_sensor_dag_folder = smart_sensor_dags.__path__[0] # type: ignore + file_paths.extend(list_py_file_paths(smart_sensor_dag_folder, safe_mode, False, False)) return file_paths diff --git a/airflow/utils/log/file_processor_handler.py b/airflow/utils/log/file_processor_handler.py index a552c11..d98ec85 100644 --- a/airflow/utils/log/file_processor_handler.py +++ b/airflow/utils/log/file_processor_handler.py @@ -76,7 +76,20 @@ class FileProcessorHandler(logging.Handler): self.handler.close() def _render_filename(self, filename): - filename = os.path.relpath(filename, self.dag_dir) + + # The Airflow log path used to be generated from the relative path + # `os.path.relpath(filename, self.dag_dir)`; however, the `smart_sensor_group` DAGs and + # all other DAGs shipped in the airflow source code are not located in the DAG dir like + # regular DAGs, so that relative path could point outside of the log dir. The change here + # makes sure the log path for DAGs shipped with airflow is always inside the log dir, and + # to differentiate them from regular DAGs their logs are placed under `log_dir/native_dags`.
+ import airflow + airflow_directory = airflow.__path__[0] + if filename.startswith(airflow_directory): + filename = os.path.join("native_dags", os.path.relpath(filename, airflow_directory)) + else: + filename = os.path.relpath(filename, self.dag_dir) ctx = {} ctx['filename'] = filename diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 281a1a1..482f7ea 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -71,8 +71,15 @@ class FileTaskHandler(logging.Handler): def _render_filename(self, ti, try_number): if self.filename_jinja_template: - jinja_context = ti.get_template_context() - jinja_context['try_number'] = try_number + if hasattr(ti, 'task'): + jinja_context = ti.get_template_context() + jinja_context['try_number'] = try_number + else: + jinja_context = { + 'ti': ti, + 'ts': ti.execution_date.isoformat(), + 'try_number': try_number, + } return self.filename_jinja_template.render(**jinja_context) return self.filename_template.format(dag_id=ti.dag_id, @@ -80,6 +87,9 @@ class FileTaskHandler(logging.Handler): execution_date=ti.execution_date.isoformat(), try_number=try_number) + def _read_grouped_logs(self): + return False + def _read(self, ti, try_number, metadata=None): # pylint: disable=unused-argument """ Template method that contains custom logic of reading @@ -168,7 +178,7 @@ class FileTaskHandler(logging.Handler): it returns all logs separated by try_number :param metadata: log metadata, can be used for streaming log reading and auto-tailing. - :return: a list of logs + :return: a list, one entry per try number, of lists of (hostname, log string) tuples """ # Task instance increments its try number when it starts to run. # So the log for a particular task try will only show up when @@ -180,7 +190,7 @@ class FileTaskHandler(logging.Handler): try_numbers = list(range(1, next_try)) elif try_number < 1: logs = [ - 'Error fetching the logs. Try number {} is invalid.'.format(try_number), + [('default_host', 'Error fetching the logs. Try number {} is invalid.'.format(try_number))], ] return logs else: @@ -190,7 +200,9 @@ class FileTaskHandler(logging.Handler): metadata_array = [{}] * len(try_numbers) for i, try_number_element in enumerate(try_numbers): log, metadata = self._read(task_instance, try_number_element, metadata) - logs[i] += log + # es_task_handler returns logs grouped by host;
wrap other handler returning log string + # with default/ empty host so that UI can render the response in the same way + logs[i] = log if self._read_grouped_logs() else [(task_instance.hostname, log)] metadata_array[i] = metadata return logs, metadata_array diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py index dbfe476..b54e901 100644 --- a/airflow/utils/log/log_reader.py +++ b/airflow/utils/log/log_reader.py @@ -84,7 +84,8 @@ class TaskLogReader: metadata.pop('offset', None) while 'end_of_log' not in metadata or not metadata['end_of_log']: logs, metadata = self.read_log_chunks(ti, current_try_number, metadata) - yield "\n".join(logs) + "\n" + for host, log in logs[0]: + yield "\n".join([host, log]) + "\n" @cached_property def log_handler(self): diff --git a/airflow/utils/state.py b/airflow/utils/state.py index b2abbff..4999cd8 100644 --- a/airflow/utils/state.py +++ b/airflow/utils/state.py @@ -43,6 +43,7 @@ class State: UP_FOR_RESCHEDULE = "up_for_reschedule" UPSTREAM_FAILED = "upstream_failed" SKIPPED = "skipped" + SENSING = "sensing" task_states = ( SUCCESS, @@ -55,6 +56,7 @@ class State: QUEUED, NONE, SCHEDULED, + SENSING, ) dag_states = ( @@ -76,6 +78,7 @@ class State: REMOVED: 'lightgrey', SCHEDULED: 'tan', NONE: 'lightblue', + SENSING: 'lightseagreen', } state_color.update(STATE_COLORS) # type: ignore @@ -97,6 +100,16 @@ class State: return 'black' @classmethod + def running(cls): + """ + A list of states indicating that a task is being executed. + """ + return [ + cls.RUNNING, + cls.SENSING + ] + + @classmethod def finished(cls): """ A list of states indicating that a task started and completed a @@ -120,7 +133,17 @@ class State: cls.SCHEDULED, cls.QUEUED, cls.RUNNING, + cls.SENSING, cls.SHUTDOWN, cls.UP_FOR_RETRY, - cls.UP_FOR_RESCHEDULE + cls.UP_FOR_RESCHEDULE, ] + + +class PokeState: + """ + Static class with poke states constants used in smart operator. 
+ """ + LANDED = 'landed' + NOT_LANDED = 'not_landed' + POKE_EXCEPTION = 'poke_exception' diff --git a/airflow/www/static/css/graph.css b/airflow/www/static/css/graph.css index cd9bedf..7f1b818 100644 --- a/airflow/www/static/css/graph.css +++ b/airflow/www/static/css/graph.css @@ -63,6 +63,10 @@ g.node.up_for_reschedule rect { stroke: turquoise; } +g.node.sensing rect { + stroke: lightseagreen; +} + g.node.queued rect { stroke: grey; } diff --git a/airflow/www/static/css/tree.css b/airflow/www/static/css/tree.css index b26f6ef..e69444a 100644 --- a/airflow/www/static/css/tree.css +++ b/airflow/www/static/css/tree.css @@ -83,6 +83,10 @@ rect.skipped { fill: pink; } +rect.sensing { + fill: lightseagreen; +} + .tooltip.in { opacity: 1; filter: alpha(opacity=100); diff --git a/airflow/www/templates/airflow/ti_log.html b/airflow/www/templates/airflow/ti_log.html index 1dcf6d3..7b9af69 100644 --- a/airflow/www/templates/airflow/ti_log.html +++ b/airflow/www/templates/airflow/ti_log.html @@ -44,7 +44,7 @@ <div role="tabpanel" class="tab-pane {{ 'active' if loop.last else '' }}" id="{{ loop.index }}"> <img id="loading-{{ loop.index }}" style="margin-top:0%; margin-left:50%; height:50px; width:50px; position: absolute;" alt="spinner" src="{{ url_for('static', filename='loading.gif') }}"> - <pre><code id="try-{{ loop.index }}" class="{{ 'wrap' if wrapped else '' }}">{{ log }}</code></pre> + <div id="log-group-{{ loop.index }}"></div> </div> {% endfor %} </div> @@ -122,16 +122,30 @@ if(auto_tailing && checkAutoTailingCondition()) { var should_scroll = true; } - // The message may contain HTML, so either have to escape it or write it as text. - var escaped_message = escapeHtml(res.message); // Detect urls var url_regex = /http(s)?:\/\/[\w\.\-]+(\.?:[\w\.\-]+)*([\/?#][\w\-\._~:/?#[\]@!\$&'\(\)\*\+,;=\.%]+)?/g; - var linkified_message = escaped_message.replace(url_regex, function(url) { - return "<a href=\"" + url + "\" target=\"_blank\">" + url + "</a>"; - }); - document.getElementById(`try-${try_number}`).innerHTML += linkified_message + "<br/>"; + res.message.forEach(function(item, index){ + var log_block_element_id = "try-" + try_number + "-" + item[0]; + var log_block = document.getElementById(log_block_element_id); + if (!log_block) { + log_div_block = document.createElement('div'); + log_pre_block = document.createElement('pre'); + log_div_block.appendChild(log_pre_block); + log_pre_block.innerHTML = "<code id=\"" + log_block_element_id + "\" ></code>"; + document.getElementById("log-group-" + try_number).appendChild(log_div_block); + log_block = document.getElementById(log_block_element_id); + } + + // The message may contain HTML, so either have to escape it or write it as text. + var escaped_message = escapeHtml(item[1]); + var linkified_message = escaped_message.replace(url_regex, function(url) { + return "<a href=\"" + url + "\" target=\"_blank\">" + url + "</a>"; + }); + log_block.innerHTML += linkified_message + "\n"; + }) + // Auto scroll window to the end if current window location is near the end. 
if(should_scroll) { scrollBottom(); diff --git a/docs/img/smart_sensor_architecture.png b/docs/img/smart_sensor_architecture.png new file mode 100644 index 0000000..4fdf3e9 Binary files /dev/null and b/docs/img/smart_sensor_architecture.png differ diff --git a/docs/img/smart_sensor_single_task_execute_flow.png b/docs/img/smart_sensor_single_task_execute_flow.png new file mode 100644 index 0000000..c3ec2e0 Binary files /dev/null and b/docs/img/smart_sensor_single_task_execute_flow.png differ diff --git a/docs/index.rst b/docs/index.rst index f532d9a..45ea799 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -95,6 +95,7 @@ Content lineage dag-serialization modules_management + smart-sensor changelog best-practices faq diff --git a/docs/logging-monitoring/metrics.rst b/docs/logging-monitoring/metrics.rst index 491f2fa..afe7801 100644 --- a/docs/logging-monitoring/metrics.rst +++ b/docs/logging-monitoring/metrics.rst @@ -112,6 +112,11 @@ Name Description ``pool.queued_slots.<pool_name>`` Number of queued slots in the pool ``pool.running_slots.<pool_name>`` Number of running slots in the pool ``pool.starving_tasks.<pool_name>`` Number of starving tasks in the pool +``smart_sensor_operator.poked_tasks`` Number of tasks poked by the smart sensor in the previous poking loop +``smart_sensor_operator.poked_success`` Number of newly succeeded tasks poked by the smart sensor in the previous poking loop +``smart_sensor_operator.poked_exception`` Number of exceptions in the previous smart sensor poking loop +``smart_sensor_operator.exception_failures`` Number of failures caused by exception in the previous smart sensor poking loop +``smart_sensor_operator.infra_failures`` Number of infrastructure failures in the previous smart sensor poking loop =================================================== ======================================================================== Timers diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst index 4f7c9e7..9b2f03f 100644 --- a/docs/operators-and-hooks-ref.rst +++ b/docs/operators-and-hooks-ref.rst @@ -131,6 +131,10 @@ Fundamentals * - :mod:`airflow.hooks.filesystem` - + * - :mod:`airflow.sensors.smart_sensor_operator` + - + + .. _Apache: ASF: Apache Software Foundation diff --git a/docs/smart-sensor.rst b/docs/smart-sensor.rst new file mode 100644 index 0000000..2664f8b --- /dev/null +++ b/docs/smart-sensor.rst @@ -0,0 +1,86 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + + +Smart Sensor +============ + +The smart sensor is a service (run by a builtin DAG) which greatly reduces airflow’s infrastructure +cost by consolidating some of the airflow long running light weight tasks. + +.. 
image:: img/smart_sensor_architecture.png + +Instead of using one process for each task, the main idea of the smart sensor service is to improve the +efficiency of these long-running tasks by using centralized processes to execute them in batches. + +To do that, a task runs in two steps: the first step serializes the task information +into the database; the second step uses a few centralized processes to execute the serialized +tasks in batches. + +In this way, we only need a handful of running processes. + +.. image:: img/smart_sensor_single_task_execute_flow.png + +The smart sensor service is supported in a new mode called “smart sensor mode”. In smart sensor mode, +instead of holding a long running process for each sensor and poking periodically, a sensor only +stores its poke context in the sensor_instance table and then exits with a ‘sensing’ state. + +When the smart sensor mode is enabled, a special set of builtin smart sensor DAGs +(named smart_sensor_group_shard_xxx) is created by the system. These DAGs contain a ``SmartSensorOperator`` +task and manage the smart sensor jobs for the airflow cluster. The SmartSensorOperator task can fetch +hundreds of ‘sensing’ instances from the sensor_instance table and poke on their behalf in batches. +Users don’t need to change their existing DAGs. + +Enable/Disable Smart Sensor +--------------------------- + +Updating from an older version might need a schema change. If there is no ``sensor_instance`` table +in the DB, please make sure to run ``airflow db upgrade``. + +Add the following settings to ``airflow.cfg``: + +.. code-block:: + + [smart_sensor] + use_smart_sensor = true + shard_code_upper_limit = 10000 + + # Users can change the following config based on their requirements + shards = 5 + sensors_enabled = NamedHivePartitionSensor, MetastorePartitionSensor + +* ``use_smart_sensor``: This config indicates whether the smart sensor is enabled. +* ``shards``: This config indicates the number of concurrently running smart sensor jobs for + the airflow cluster. +* ``sensors_enabled``: This config is a list of sensor class names that will use the smart sensor. + Users keep using these class names (e.g. HivePartitionSensor) in their DAGs and do not control + whether the smart sensor is used, unless they exclude their tasks explicitly. + +Enabling/disabling the smart sensor service is a system-level configuration change. +It is transparent to individual users. Existing DAGs don't need to be changed for +enabling/disabling the smart sensor. Rotating centralized smart sensor tasks will not +cause any user’s sensor task to fail. + +Support new operators in the smart sensor service +------------------------------------------------- + +* Define ``poke_context_fields`` as a class attribute in the sensor. ``poke_context_fields`` + includes all key names used for initializing a sensor object (see the example below). +* In ``airflow.cfg``, add the new operator's classname to ``[smart_sensor] sensors_enabled``. + All supported sensors' class names should be comma-separated.
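+For illustration, a minimal custom sensor prepared for the smart sensor service could look like
+the sketch below (``FileLandedSensor`` and its ``filepath`` argument are only an example, not a
+sensor shipped with Airflow):
+
+.. code-block:: python
+
+    import os
+
+    from airflow.sensors.base_sensor_operator import BaseSensorOperator
+
+
+    class FileLandedSensor(BaseSensorOperator):
+        # Key names the smart sensor serializes and later uses to re-create the sensor
+        # before poking on its behalf; poke() must rely only on these fields.
+        poke_context_fields = ('filepath',)
+
+        def __init__(self, filepath, **kwargs):
+            super().__init__(**kwargs)
+            self.filepath = filepath
+
+        def poke(self, context):
+            # Succeed once the expected file has landed on local disk.
+            return os.path.exists(self.filepath)
+
+After adding ``FileLandedSensor`` to ``[smart_sensor] sensors_enabled``, eligible task instances
+register themselves in the sensor_instance table and the shard DAGs poke them in batches instead
+of each task holding its own worker slot.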
diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index 2921a90..ff41609 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -158,7 +158,8 @@ class TestGetLog(unittest.TestCase): self.log_dir, self.DAG_ID, self.TASK_ID, self.default_time.replace(":", ".") ) self.assertEqual( - response.json['content'], f"*** Reading local file: {expected_filename}\nLog for testing." + response.json['content'], + f"[('', '*** Reading local file: {expected_filename}\\nLog for testing.')]", ) info = serializer.loads(response.json['continuation_token']) self.assertEqual(info, {'end_of_log': True}) @@ -182,7 +183,8 @@ class TestGetLog(unittest.TestCase): ) self.assertEqual(200, response.status_code) self.assertEqual( - response.data.decode('utf-8'), f"*** Reading local file: {expected_filename}\nLog for testing.\n" + response.data.decode('utf-8'), + f"\n*** Reading local file: {expected_filename}\nLog for testing.\n", ) @provide_session @@ -204,10 +206,10 @@ class TestGetLog(unittest.TestCase): def test_get_logs_with_metadata_as_download_large_file(self, session): self._create_dagrun(session) with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as read_mock: - first_return = (['1st line'], [{}]) - second_return = (['2nd line'], [{'end_of_log': False}]) - third_return = (['3rd line'], [{'end_of_log': True}]) - fourth_return = (['should never be read'], [{'end_of_log': True}]) + first_return = ([[('', '1st line')]], [{}]) + second_return = ([[('', '2nd line')]], [{'end_of_log': False}]) + third_return = ([[('', '3rd line')]], [{'end_of_log': True}]) + fourth_return = ([[('', 'should never be read')]], [{'end_of_log': True}]) read_mock.side_effect = [first_return, second_return, third_return, fourth_return] response = self.client.get( diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 00325fe..030ff3b 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -33,6 +33,7 @@ from mock import MagicMock, patch from parameterized import parameterized import airflow.example_dags +import airflow.smart_sensor_dags from airflow import settings from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -197,7 +198,9 @@ class TestDagFileProcessor(unittest.TestCase): executor = MockExecutor(do_update=True, parallelism=3) with create_session() as session: - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py")) + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + include_smart_sensor=False) dag = self.create_test_dag() dag.clear() dagbag.bag_dag(dag=dag, root_dag=dag) @@ -1309,7 +1312,9 @@ class TestDagFileProcessorQueriesCount(unittest.TestCase): }), conf_vars({ ('scheduler', 'use_job_schedule'): 'True', }): - dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) + dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, + include_examples=False, + include_smart_sensor=False) processor = DagFileProcessor([], mock.MagicMock()) for expected_query_count in expected_query_counts: with assert_queries_count(expected_query_count): @@ -3298,6 +3303,19 @@ class TestSchedulerJob(unittest.TestCase): detected_files.add(file_path) self.assertEqual(detected_files, expected_files) + smart_sensor_dag_folder = airflow.smart_sensor_dags.__path__[0] + for root, _, files in os.walk(smart_sensor_dag_folder): + for 
file_name in files: + if (file_name.endswith('.py') or file_name.endswith('.zip')) and \ + file_name not in ['__init__.py']: + expected_files.add(os.path.join(root, file_name)) + detected_files.clear() + for file_path in list_py_file_paths(TEST_DAG_FOLDER, + include_examples=True, + include_smart_sensor=True): + detected_files.add(file_path) + self.assertEqual(detected_files, expected_files) + def test_reset_orphaned_tasks_nothing(self): """Try with nothing. """ scheduler = SchedulerJob() diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 9173ad1..c2fe045 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -728,7 +728,9 @@ class TestDagBag(unittest.TestCase): """ dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py") - dagbag = DagBag(dag_folder=dag_file) + dagbag = DagBag(dag_folder=dag_file, + include_smart_sensor=False, + include_examples=False) self.assertEqual(set(), set(dagbag.dag_ids)) expected_import_errors = { dag_file: ( @@ -747,7 +749,9 @@ class TestDagBag(unittest.TestCase): dag_file = os.path.join(TEST_DAGS_FOLDER, "test_with_non_default_owner.py") - dagbag = DagBag(dag_folder=dag_file) + dagbag = DagBag(dag_folder=dag_file, + include_examples=False, + include_smart_sensor=False) self.assertEqual({"test_with_non_default_owner"}, set(dagbag.dag_ids)) self.assertEqual({}, dagbag.import_errors) diff --git a/tests/models/test_sensorinstance.py b/tests/models/test_sensorinstance.py new file mode 100644 index 0000000..168d97f --- /dev/null +++ b/tests/models/test_sensorinstance.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from airflow.models import SensorInstance +from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor +from airflow.sensors.python import PythonSensor + + +class SensorInstanceTest(unittest.TestCase): + + def test_get_classpath(self): + # Test the classpath in/out airflow + obj1 = NamedHivePartitionSensor( + partition_names=['test_partition'], + task_id='meta_partition_test_1') + obj1_classpath = SensorInstance.get_classpath(obj1) + obj1_importpath = "airflow.providers.apache.hive." 
\ + "sensors.named_hive_partition.NamedHivePartitionSensor" + + self.assertEqual(obj1_classpath, obj1_importpath) + + def test_callable(): + return + obj3 = PythonSensor(python_callable=test_callable, + task_id='python_sensor_test') + obj3_classpath = SensorInstance.get_classpath(obj3) + obj3_importpath = "airflow.sensors.python.PythonSensor" + + self.assertEqual(obj3_classpath, obj3_importpath) diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py index 7509be2..9eedd6b 100644 --- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py @@ -127,7 +127,10 @@ class TestCloudwatchTaskHandler(unittest.TestCase): ) self.assertEqual( self.cloudwatch_task_handler.read(self.ti), - ([expected.format(self.remote_log_group, self.remote_log_stream)], [{'end_of_log': True}]), + ( + [[('', expected.format(self.remote_log_group, self.remote_log_stream))]], + [{'end_of_log': True}], + ), ) def test_read_wrong_log_stream(self): @@ -149,7 +152,7 @@ class TestCloudwatchTaskHandler(unittest.TestCase): self.assertEqual( self.cloudwatch_task_handler.read(self.ti), ( - [msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg)], + [[('', msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg))]], [{'end_of_log': True}], ), ) @@ -173,7 +176,7 @@ class TestCloudwatchTaskHandler(unittest.TestCase): self.assertEqual( self.cloudwatch_task_handler.read(self.ti), ( - [msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg)], + [[('', msg_template.format(self.remote_log_group, self.remote_log_stream, error_msg))]], [{'end_of_log': True}], ), ) diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py index 4c737e0..2c0eebf 100644 --- a/tests/providers/amazon/aws/log/test_s3_task_handler.py +++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py @@ -128,20 +128,19 @@ class TestS3TaskHandler(unittest.TestCase): def test_read(self): self.conn.put_object(Bucket='bucket', Key=self.remote_log_key, Body=b'Log line\n') + log, metadata = self.s3_task_handler.read(self.ti) self.assertEqual( - self.s3_task_handler.read(self.ti), - ( - ['*** Reading remote log from s3://bucket/remote/log/location/1.log.\nLog line\n\n'], - [{'end_of_log': True}], - ), + log[0][0][-1], + '*** Reading remote log from s3://bucket/remote/log/location/1.log.\n' 'Log line\n\n', ) + self.assertEqual(metadata, [{'end_of_log': True}]) def test_read_when_s3_log_missing(self): log, metadata = self.s3_task_handler.read(self.ti) self.assertEqual(1, len(log)) self.assertEqual(len(log), len(metadata)) - self.assertIn('*** Log file does not exist:', log[0]) + self.assertIn('*** Log file does not exist:', log[0][0][-1]) self.assertEqual({'end_of_log': True}, metadata[0]) def test_read_raises_return_error(self): diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py index deec540..5423df3 100644 --- a/tests/providers/elasticsearch/log/test_es_task_handler.py +++ b/tests/providers/elasticsearch/log/test_es_task_handler.py @@ -114,7 +114,8 @@ class TestElasticsearchTaskHandler(unittest.TestCase): self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual(self.test_message, logs[0]) + self.assertEqual(len(logs[0]), 1) + self.assertEqual(self.test_message, 
logs[0][0][-1]) self.assertFalse(metadatas[0]['end_of_log']) self.assertEqual('1', metadatas[0]['offset']) self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) > ts) @@ -134,7 +135,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase): ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual(self.test_message, logs[0]) + self.assertEqual(self.test_message, logs[0][0][-1]) self.assertNotEqual(another_test_message, logs[0]) self.assertFalse(metadatas[0]['end_of_log']) @@ -145,7 +146,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase): logs, metadatas = self.es_task_handler.read(self.ti, 1) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual(self.test_message, logs[0]) + self.assertEqual(self.test_message, logs[0][0][-1]) self.assertFalse(metadatas[0]['end_of_log']) self.assertEqual('1', metadatas[0]['offset']) self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) < pendulum.now()) @@ -161,7 +162,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase): ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual([''], logs) + self.assertEqual([[]], logs) self.assertFalse(metadatas[0]['end_of_log']) self.assertEqual('0', metadatas[0]['offset']) # last_log_timestamp won't change if no log lines read. @@ -172,7 +173,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase): logs, metadatas = self.es_task_handler.read(self.ti, 1, {}) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual(self.test_message, logs[0]) + self.assertEqual(self.test_message, logs[0][0][-1]) self.assertFalse(metadatas[0]['end_of_log']) # offset should be initialized to 0 if not provided. self.assertEqual('1', metadatas[0]['offset']) @@ -185,7 +186,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase): logs, metadatas = self.es_task_handler.read(self.ti, 1, {'end_of_log': False}) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual([''], logs) + self.assertEqual([[]], logs) self.assertFalse(metadatas[0]['end_of_log']) # offset should be initialized to 0 if not provided. self.assertEqual('0', metadatas[0]['offset']) @@ -202,7 +203,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase): ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual([''], logs) + self.assertEqual([[]], logs) self.assertTrue(metadatas[0]['end_of_log']) # offset should be initialized to 0 if not provided. 
self.assertEqual('0', metadatas[0]['offset']) @@ -217,7 +218,8 @@ class TestElasticsearchTaskHandler(unittest.TestCase): ) self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual(self.test_message, logs[0]) + self.assertEqual(len(logs[0]), 1) + self.assertEqual(self.test_message, logs[0][0][-1]) self.assertFalse(metadatas[0]['end_of_log']) self.assertTrue(metadatas[0]['download_logs']) self.assertEqual('1', metadatas[0]['offset']) @@ -234,7 +236,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase): self.assertEqual(1, len(logs)) self.assertEqual(len(logs), len(metadatas)) - self.assertEqual([''], logs) + self.assertEqual([[]], logs) self.assertFalse(metadatas[0]['end_of_log']) self.assertEqual('0', metadatas[0]['offset']) diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py index 9a51ec5..d5bf169 100644 --- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py +++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py @@ -101,7 +101,15 @@ class TestWasbTaskHandler(unittest.TestCase): self.assertEqual( self.wasb_task_handler.read(self.ti), ( - ['*** Reading remote log from wasb://container/remote/log/location/1.log.\nLog line\n'], + [ + [ + ( + '', + '*** Reading remote log from wasb://container/remote/log/location/1.log.\n' + 'Log line\n', + ) + ] + ], [{'end_of_log': True}], ), ) diff --git a/tests/sensors/test_smart_sensor_operator.py b/tests/sensors/test_smart_sensor_operator.py new file mode 100644 index 0000000..28eeee6 --- /dev/null +++ b/tests/sensors/test_smart_sensor_operator.py @@ -0,0 +1,326 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import datetime +import logging +import os +import time +import unittest + +from freezegun import freeze_time +from mock import Mock + +from airflow import DAG, settings +from airflow.configuration import conf +from airflow.models import DagRun, SensorInstance, TaskInstance +from airflow.operators.dummy_operator import DummyOperator +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.sensors.smart_sensor_operator import SmartSensorOperator +from airflow.utils import timezone +from airflow.utils.state import State + +DEFAULT_DATE = timezone.datetime(2015, 1, 1) +TEST_DAG_ID = 'unit_test_dag' +TEST_SENSOR_DAG_ID = 'unit_test_sensor_dag' +DUMMY_OP = 'dummy_op' +SMART_OP = 'smart_op' +SENSOR_OP = 'sensor_op' + + +class DummySmartSensor(SmartSensorOperator): + def __init__(self, + shard_max=conf.getint('smart_sensor', 'shard_code_upper_limit'), + shard_min=0, + **kwargs): + super(DummySmartSensor, self).__init__(shard_min=shard_min, + shard_max=shard_max, + **kwargs) + + +class DummySensor(BaseSensorOperator): + poke_context_fields = ('input_field', 'return_value') + exec_fields = ('soft_fail', 'execution_timeout', 'timeout') + + def __init__(self, input_field='test', return_value=False, **kwargs): + super(DummySensor, self).__init__(**kwargs) + self.input_field = input_field + self.return_value = return_value + + def poke(self, context): + return context.get('return_value', False) + + def is_smart_sensor_compatible(self): + return not self.on_failure_callback + + +class SmartSensorTest(unittest.TestCase): + def setUp(self): + os.environ['AIRFLOW__SMART_SENSER__USE_SMART_SENSOR'] = 'true' + os.environ['AIRFLOW__SMART_SENSER__SENSORS_ENABLED'] = 'DummySmartSensor' + + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + self.dag = DAG(TEST_DAG_ID, default_args=args) + self.sensor_dag = DAG(TEST_SENSOR_DAG_ID, default_args=args) + self.log = logging.getLogger('BaseSmartTest') + + session = settings.Session() + session.query(DagRun).delete() + session.query(TaskInstance).delete() + session.query(SensorInstance).delete() + session.commit() + + def tearDown(self): + session = settings.Session() + session.query(DagRun).delete() + session.query(TaskInstance).delete() + session.query(SensorInstance).delete() + session.commit() + + os.environ.pop('AIRFLOW__SMART_SENSER__USE_SMART_SENSOR') + os.environ.pop('AIRFLOW__SMART_SENSER__SENSORS_ENABLED') + + def _make_dag_run(self): + return self.dag.create_dagrun( + run_id='manual__' + TEST_DAG_ID, + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + def _make_sensor_dag_run(self): + return self.sensor_dag.create_dagrun( + run_id='manual__' + TEST_SENSOR_DAG_ID, + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + def _make_sensor(self, return_value, **kwargs): + poke_interval = 'poke_interval' + timeout = 'timeout' + if poke_interval not in kwargs: + kwargs[poke_interval] = 0 + if timeout not in kwargs: + kwargs[timeout] = 0 + + sensor = DummySensor( + task_id=SENSOR_OP, + return_value=return_value, + dag=self.sensor_dag, + **kwargs + ) + + return sensor + + def _make_sensor_instance(self, index, return_value, **kwargs): + poke_interval = 'poke_interval' + timeout = 'timeout' + if poke_interval not in kwargs: + kwargs[poke_interval] = 0 + if timeout not in kwargs: + kwargs[timeout] = 0 + + task_id = SENSOR_OP + str(index) + sensor = DummySensor( + task_id=task_id, + return_value=return_value, + dag=self.sensor_dag, + **kwargs + ) + + 
ti = TaskInstance(task=sensor, execution_date=DEFAULT_DATE) + + return ti + + def _make_smart_operator(self, index, **kwargs): + poke_interval = 'poke_interval' + smart_sensor_timeout = 'smart_sensor_timeout' + if poke_interval not in kwargs: + kwargs[poke_interval] = 0 + if smart_sensor_timeout not in kwargs: + kwargs[smart_sensor_timeout] = 0 + + smart_task = DummySmartSensor( + task_id=SMART_OP + "_" + str(index), + dag=self.dag, + **kwargs + ) + + dummy_op = DummyOperator( + task_id=DUMMY_OP, + dag=self.dag + ) + dummy_op.set_upstream(smart_task) + return smart_task + + @classmethod + def _run(cls, task): + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + def test_load_sensor_works(self): + # Mock two sensor tasks return True and one return False + # The hashcode for si1 and si2 should be same. Test dedup on these two instances + si1 = self._make_sensor_instance(1, True) + si2 = self._make_sensor_instance(2, True) + si3 = self._make_sensor_instance(3, False) + + # Confirm initial state + smart = self._make_smart_operator(0) + smart.flush_cached_sensor_poke_results() + self.assertEqual(len(smart.cached_dedup_works), 0) + self.assertEqual(len(smart.cached_sensor_exceptions), 0) + + si1.run(ignore_all_deps=True) + # Test single sensor + smart._load_sensor_works() + self.assertEqual(len(smart.sensor_works), 1) + self.assertEqual(len(smart.cached_dedup_works), 0) + self.assertEqual(len(smart.cached_sensor_exceptions), 0) + + si2.run(ignore_all_deps=True) + si3.run(ignore_all_deps=True) + + # Test multiple sensors with duplication + smart._load_sensor_works() + self.assertEqual(len(smart.sensor_works), 3) + self.assertEqual(len(smart.cached_dedup_works), 0) + self.assertEqual(len(smart.cached_sensor_exceptions), 0) + + def test_execute_single_task_with_dup(self): + sensor_dr = self._make_sensor_dag_run() + si1 = self._make_sensor_instance(1, True) + si2 = self._make_sensor_instance(2, True) + si3 = self._make_sensor_instance(3, False, timeout=0) + + si1.run(ignore_all_deps=True) + si2.run(ignore_all_deps=True) + si3.run(ignore_all_deps=True) + + smart = self._make_smart_operator(0) + smart.flush_cached_sensor_poke_results() + + smart._load_sensor_works() + self.assertEqual(len(smart.sensor_works), 3) + + for sensor_work in smart.sensor_works: + _, task_id, _ = sensor_work.ti_key + if task_id == SENSOR_OP + "1": + smart._execute_sensor_work(sensor_work) + break + + self.assertEqual(len(smart.cached_dedup_works), 1) + + tis = sensor_dr.get_task_instances() + for ti in tis: + if ti.task_id == SENSOR_OP + "1": + self.assertEqual(ti.state, State.SUCCESS) + if ti.task_id == SENSOR_OP + "2": + self.assertEqual(ti.state, State.SUCCESS) + if ti.task_id == SENSOR_OP + "3": + self.assertEqual(ti.state, State.SENSING) + + for sensor_work in smart.sensor_works: + _, task_id, _ = sensor_work.ti_key + if task_id == SENSOR_OP + "2": + smart._execute_sensor_work(sensor_work) + break + + self.assertEqual(len(smart.cached_dedup_works), 1) + + time.sleep(1) + for sensor_work in smart.sensor_works: + _, task_id, _ = sensor_work.ti_key + if task_id == SENSOR_OP + "3": + smart._execute_sensor_work(sensor_work) + break + + self.assertEqual(len(smart.cached_dedup_works), 2) + + tis = sensor_dr.get_task_instances() + for ti in tis: + # Timeout=0, the Failed poke lead to task fail + if ti.task_id == SENSOR_OP + "3": + self.assertEqual(ti.state, State.FAILED) + + def test_smart_operator_timeout(self): + sensor_dr = self._make_sensor_dag_run() + si1 = self._make_sensor_instance(1, 
False, timeout=10) + smart = self._make_smart_operator(0, poke_interval=6) + smart.poke = Mock(side_effect=[False, False, False, False]) + + date1 = timezone.utcnow() + with freeze_time(date1): + si1.run(ignore_all_deps=True) + smart.flush_cached_sensor_poke_results() + smart._load_sensor_works() + + for sensor_work in smart.sensor_works: + smart._execute_sensor_work(sensor_work) + + # Before timeout the state should be SENSING + sis = sensor_dr.get_task_instances() + for sensor_instance in sis: + if sensor_instance.task_id == SENSOR_OP + "1": + self.assertEqual(sensor_instance.state, State.SENSING) + + date2 = date1 + datetime.timedelta(seconds=smart.poke_interval) + with freeze_time(date2): + smart.flush_cached_sensor_poke_results() + smart._load_sensor_works() + + for sensor_work in smart.sensor_works: + smart._execute_sensor_work(sensor_work) + + sis = sensor_dr.get_task_instances() + for sensor_instance in sis: + if sensor_instance.task_id == SENSOR_OP + "1": + self.assertEqual(sensor_instance.state, State.SENSING) + + date3 = date2 + datetime.timedelta(seconds=smart.poke_interval) + with freeze_time(date3): + smart.flush_cached_sensor_poke_results() + smart._load_sensor_works() + + for sensor_work in smart.sensor_works: + smart._execute_sensor_work(sensor_work) + + sis = sensor_dr.get_task_instances() + for sensor_instance in sis: + if sensor_instance.task_id == SENSOR_OP + "1": + self.assertEqual(sensor_instance.state, State.FAILED) + + def test_register_in_sensor_service(self): + si1 = self._make_sensor_instance(1, True) + si1.run(ignore_all_deps=True) + self.assertEqual(si1.state, State.SENSING) + + session = settings.Session() + + SI = SensorInstance + sensor_instance = session.query(SI).filter( + SI.dag_id == si1.dag_id, + SI.task_id == si1.task_id, + SI.execution_date == si1.execution_date) \ + .first() + + self.assertIsNotNone(sensor_instance) + self.assertEqual(sensor_instance.state, State.SENSING) + self.assertEqual(sensor_instance.operator, "DummySensor") diff --git a/tests/test_config_templates.py b/tests/test_config_templates.py index c8fea0b..7b513d9 100644 --- a/tests/test_config_templates.py +++ b/tests/test_config_templates.py @@ -53,7 +53,8 @@ DEFAULT_AIRFLOW_SECTIONS = [ 'kubernetes_node_selectors', 'kubernetes_environment_variables', 'kubernetes_secrets', - 'kubernetes_labels' + 'kubernetes_labels', + 'smart_sensor' ] DEFAULT_TEST_SECTIONS = [ diff --git a/tests/utils/log/test_log_reader.py b/tests/utils/log/test_log_reader.py index 0fe2a7a..fc65496 100644 --- a/tests/utils/log/test_log_reader.py +++ b/tests/utils/log/test_log_reader.py @@ -105,11 +105,12 @@ class TestLogView(unittest.TestCase): self.assertEqual( [ - f"*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" - f"try_number=1.\n" + ('', + f"*** Reading local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" + f"try_number=1.\n") ], - logs, + logs[0], ) self.assertEqual({"end_of_log": True}, metadatas) @@ -119,15 +120,18 @@ class TestLogView(unittest.TestCase): self.assertEqual( [ - "*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" - "try_number=1.\n", - f"*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n" - f"try_number=2.\n", - f"*** Reading local file: " - f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n" - f"try_number=3.\n", + [('', + "*** Reading 
local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" + "try_number=1.\n")], + [('', + f"*** Reading local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n" + f"try_number=2.\n")], + [('', + f"*** Reading local file: " + f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n" + f"try_number=3.\n")], ], logs, ) @@ -139,7 +143,7 @@ class TestLogView(unittest.TestCase): self.assertEqual( [ - "*** Reading local file: " + "\n*** Reading local file: " f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" "try_number=1.\n" "\n" @@ -152,15 +156,15 @@ class TestLogView(unittest.TestCase): stream = task_log_reader.read_log_stream(ti=self.ti, try_number=None, metadata={}) self.assertEqual( [ - "*** Reading local file: " + "\n*** Reading local file: " f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n" "try_number=1.\n" "\n", - "*** Reading local file: " + "\n*** Reading local file: " f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n" "try_number=2.\n" "\n", - "*** Reading local file: " + "\n*** Reading local file: " f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n" "try_number=3.\n" "\n", @@ -170,15 +174,15 @@ class TestLogView(unittest.TestCase): @mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") def test_read_log_stream_should_support_multiple_chunks(self, mock_read): - first_return = (["1st line"], [{}]) - second_return = (["2nd line"], [{"end_of_log": False}]) - third_return = (["3rd line"], [{"end_of_log": True}]) - fourth_return = (["should never be read"], [{"end_of_log": True}]) + first_return = ([[('', "1st line")]], [{}]) + second_return = ([[('', "2nd line")]], [{"end_of_log": False}]) + third_return = ([[('', "3rd line")]], [{"end_of_log": True}]) + fourth_return = ([[('', "should never be read")]], [{"end_of_log": True}]) mock_read.side_effect = [first_return, second_return, third_return, fourth_return] task_log_reader = TaskLogReader() log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1, metadata={}) - self.assertEqual(["1st line\n", "2nd line\n", "3rd line\n"], list(log_stream)) + self.assertEqual(["\n1st line\n", "\n2nd line\n", "\n3rd line\n"], list(log_stream)) mock_read.assert_has_calls( [ @@ -191,15 +195,15 @@ class TestLogView(unittest.TestCase): @mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") def test_read_log_stream_should_read_each_try_in_turn(self, mock_read): - first_return = (["try_number=1."], [{"end_of_log": True}]) - second_return = (["try_number=2."], [{"end_of_log": True}]) - third_return = (["try_number=3."], [{"end_of_log": True}]) - fourth_return = (["should never be read"], [{"end_of_log": True}]) + first_return = ([[('', "try_number=1.")]], [{"end_of_log": True}]) + second_return = ([[('', "try_number=2.")]], [{"end_of_log": True}]) + third_return = ([[('', "try_number=3.")]], [{"end_of_log": True}]) + fourth_return = ([[('', "should never be read")]], [{"end_of_log": True}]) mock_read.side_effect = [first_return, second_return, third_return, fourth_return] task_log_reader = TaskLogReader() log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=None, metadata={}) - self.assertEqual(['try_number=1.\n', 'try_number=2.\n', 'try_number=3.\n'], list(log_stream)) + self.assertEqual(['\ntry_number=1.\n', '\ntry_number=2.\n', '\ntry_number=3.\n'], 
list(log_stream)) mock_read.assert_has_calls( [ diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 277feb3..f1ef957 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -105,7 +105,7 @@ class TestFileTaskLogHandler(unittest.TestCase): # We should expect our log line from the callable above to appear in # the logs we read back self.assertRegex( - logs[0], + logs[0][0][-1], target_re, "Logs were " + str(logs) ) diff --git a/tests/www/test_views.py b/tests/www/test_views.py index f276dd8..8060ab8 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -506,9 +506,9 @@ class TestAirflowBaseViews(TestBase): val_state_color_mapping = 'const STATE_COLOR = {"failed": "red", ' \ '"null": "lightblue", "queued": "gray", ' \ '"removed": "lightgrey", "running": "lime", ' \ - '"scheduled": "tan", "shutdown": "blue", ' \ - '"skipped": "pink", "success": "green", ' \ - '"up_for_reschedule": "turquoise", ' \ + '"scheduled": "tan", "sensing": "lightseagreen", ' \ + '"shutdown": "blue", "skipped": "pink", ' \ + '"success": "green", "up_for_reschedule": "turquoise", ' \ '"up_for_retry": "gold", "upstream_failed": "orange"};' self.check_content_in_response(val_state_color_mapping, resp) @@ -1163,9 +1163,9 @@ class TestLogView(TestBase): self.assertEqual(response.status_code, 200) self.assertIn('Log by attempts', response.data.decode('utf-8')) for num in range(1, expected_num_logs_visible + 1): - self.assertIn('try-{}'.format(num), response.data.decode('utf-8')) - self.assertNotIn('try-0', response.data.decode('utf-8')) - self.assertNotIn('try-{}'.format(expected_num_logs_visible + 1), response.data.decode('utf-8')) + self.assertIn('log-group-{}'.format(num), response.data.decode('utf-8')) + self.assertNotIn('log-group-0', response.data.decode('utf-8')) + self.assertNotIn('log-group-{}'.format(expected_num_logs_visible + 1), response.data.decode('utf-8')) def test_get_logs_with_metadata_as_download_file(self): url_template = "get_logs_with_metadata?dag_id={}&" \ @@ -1191,10 +1191,10 @@ class TestLogView(TestBase): def test_get_logs_with_metadata_as_download_large_file(self): with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as read_mock: - first_return = (['1st line'], [{}]) - second_return = (['2nd line'], [{'end_of_log': False}]) - third_return = (['3rd line'], [{'end_of_log': True}]) - fourth_return = (['should never be read'], [{'end_of_log': True}]) + first_return = ([[('default_log', '1st line')]], [{}]) + second_return = ([[('default_log', '2nd line')]], [{'end_of_log': False}]) + third_return = ([[('default_log', '3rd line')]], [{'end_of_log': True}]) + fourth_return = ([[('default_log', 'should never be read')]], [{'end_of_log': True}]) read_mock.side_effect = [first_return, second_return, third_return, fourth_return] url_template = "get_logs_with_metadata?dag_id={}&" \ "task_id={}&execution_date={}&" \ @@ -1300,7 +1300,7 @@ class TestLogView(TestBase): response = self.client.get(url) self.assertIn('message', response.json) self.assertIn('metadata', response.json) - self.assertIn('Log for testing.', response.json['message']) + self.assertIn('Log for testing.', response.json['message'][0][1]) self.assertEqual(200, response.status_code) @mock.patch("airflow.www.views.TaskLogReader")