This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new ac943c9 [AIRFLOW-3964][AIP-17] Consolidate and de-dup sensor tasks using Smart Sensor (#5499)
ac943c9 is described below
commit ac943c9e18f75259d531dbda8c51e650f57faa4c
Author: Yingbo Wang <[email protected]>
AuthorDate: Tue Sep 8 14:47:59 2020 -0700
[AIRFLOW-3964][AIP-17] Consolidate and de-dup sensor tasks using Smart Sensor (#5499)
Co-authored-by: Yingbo Wang <[email protected]>
---
airflow/config_templates/config.yml | 33 +
airflow/config_templates/default_airflow.cfg | 15 +
airflow/exceptions.py | 7 +
airflow/jobs/scheduler_job.py | 7 +-
.../e38be357a868_update_schema_for_smart_sensor.py | 92 +++
airflow/models/__init__.py | 1 +
airflow/models/baseoperator.py | 7 +
airflow/models/dagbag.py | 12 +-
airflow/models/sensorinstance.py | 166 +++++
airflow/models/taskinstance.py | 121 +++-
.../apache/hive/sensors/metastore_partition.py | 1 +
.../apache/hive/sensors/named_hive_partition.py | 10 +
.../providers/elasticsearch/log/es_task_handler.py | 40 +-
airflow/sensors/base_sensor_operator.py | 74 +-
airflow/sensors/smart_sensor_operator.py | 764 +++++++++++++++++++++
airflow/smart_sensor_dags/__init__.py | 17 +
airflow/smart_sensor_dags/smart_sensor_group.py | 63 ++
airflow/utils/file.py | 16 +-
airflow/utils/log/file_processor_handler.py | 15 +-
airflow/utils/log/file_task_handler.py | 22 +-
airflow/utils/log/log_reader.py | 3 +-
airflow/utils/state.py | 25 +-
airflow/www/static/css/graph.css | 4 +
airflow/www/static/css/tree.css | 4 +
airflow/www/templates/airflow/ti_log.html | 28 +-
docs/img/smart_sensor_architecture.png | Bin 0 -> 80325 bytes
docs/img/smart_sensor_single_task_execute_flow.png | Bin 0 -> 75462 bytes
docs/index.rst | 1 +
docs/logging-monitoring/metrics.rst | 5 +
docs/operators-and-hooks-ref.rst | 4 +
docs/smart-sensor.rst | 86 +++
tests/api_connexion/endpoints/test_log_endpoint.py | 14 +-
tests/jobs/test_scheduler_job.py | 22 +-
tests/models/test_dagbag.py | 8 +-
tests/models/test_sensorinstance.py | 46 ++
.../amazon/aws/log/test_cloudwatch_task_handler.py | 9 +-
.../amazon/aws/log/test_s3_task_handler.py | 11 +-
.../elasticsearch/log/test_es_task_handler.py | 20 +-
.../microsoft/azure/log/test_wasb_task_handler.py | 10 +-
tests/sensors/test_smart_sensor_operator.py | 326 +++++++++
tests/test_config_templates.py | 3 +-
tests/utils/log/test_log_reader.py | 58 +-
tests/utils/test_log_handlers.py | 2 +-
tests/www/test_views.py | 22 +-
44 files changed, 2062 insertions(+), 132 deletions(-)
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 10b8987..25b24da 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -2337,3 +2337,36 @@
to identify the task.
Should be supplied in the format: ``key = value``
options: []
+- name: smart_sensor
+ description: ~
+ options:
+ - name: use_smart_sensor
+ description: |
+ When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to
+ smart sensor task.
+ version_added: ~
+ type: boolean
+ example: ~
+ default: "False"
+ - name: shard_code_upper_limit
+ description: |
+ `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated
+ by `hashcode % shard_code_upper_limit`.
+ version_added: ~
+ type: int
+ example: ~
+ default: "10000"
+ - name: shards
+ description: |
+ The number of running smart sensor processes for each service.
+ version_added: ~
+ type: int
+ example: ~
+ default: "5"
+ - name: sensors_enabled
+ description: |
+ comma separated sensor classes support in smart_sensor.
+ version_added: ~
+ type: string
+ example: ~
+ default: "NamedHivePartitionSensor"
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index d5c5262..24ba5a1 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -1179,3 +1179,18 @@ worker_resources =
# The worker pods will be given these static labels, as well as some additional dynamic labels
# to identify the task.
# Should be supplied in the format: ``key = value``
+
+[smart_sensor]
+# When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to
+# smart sensor task.
+use_smart_sensor = False
+
+# `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated
+# by `hashcode % shard_code_upper_limit`.
+shard_code_upper_limit = 10000
+
+# The number of running smart sensor processes for each service.
+shards = 5
+
+# comma separated sensor classes support in smart_sensor.
+sensors_enabled = NamedHivePartitionSensor
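For context on the two settings above, a minimal self-contained Python sketch (illustrative only, not part of this patch; the example poke context values are made up) of how a sensor's shard code is derived from its poke context, mirroring what SensorInstance.register does further down:

    import json

    shard_code_upper_limit = 10000  # default from the [smart_sensor] section above

    # The poke context is the small dict of fields that identifies one poking job,
    # e.g. for NamedHivePartitionSensor: partition_names and metastore_conn_id.
    poke_context = {"partition_names": ["default.my_table/ds=2020-09-08"],
                    "metastore_conn_id": "metastore_default"}

    encoded_poke = json.dumps(poke_context)
    hashcode = hash(encoded_poke)                    # signature of the poking job
    shard_code = hashcode % shard_code_upper_limit   # routes the task to a smart sensor shard
    print(hashcode, shard_code)
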
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 22d20c5..0406ce7 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -63,6 +63,13 @@ class AirflowRescheduleException(AirflowException):
self.reschedule_date = reschedule_date
+class AirflowSmartSensorException(AirflowException):
+ """
+ Raise after the task register itself in the smart sensor service
+ It should exit without failing a task
+ """
+
+
class InvalidStatsNameException(AirflowException):
"""Raise when name of the stats is invalid"""
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 422f5a9..cb43f8d 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -862,7 +862,7 @@ class DagFileProcessor(LoggingMixin):
self.log.info("Processing file %s for tasks to queue", file_path)
try:
- dagbag = DagBag(file_path, include_examples=False)
+ dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False)
except Exception: # pylint: disable=broad-except
self.log.exception("Failed at reloading the DAG file %s", file_path)
Stats.incr('dag_file_refresh_error', 1, 1)
@@ -1743,7 +1743,10 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
# NONE so we don't try to re-run it.
self._change_state_for_tis_without_dagrun(
simple_dag_bag=simple_dag_bag,
- old_states=[State.QUEUED, State.SCHEDULED, State.UP_FOR_RESCHEDULE],
+ old_states=[State.QUEUED,
+ State.SCHEDULED,
+ State.UP_FOR_RESCHEDULE,
+ State.SENSING],
new_state=State.NONE
)
self._execute_task_instances(simple_dag_bag)
diff --git a/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py
new file mode 100644
index 0000000..27227ae
--- /dev/null
+++ b/airflow/migrations/versions/e38be357a868_update_schema_for_smart_sensor.py
@@ -0,0 +1,92 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add sensor_instance table
+
+Revision ID: e38be357a868
+Revises: 939bb1e647c8
+Create Date: 2019-06-07 04:03:17.003939
+
+"""
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy import func
+from sqlalchemy.dialects import mysql
+
+# revision identifiers, used by Alembic.
+revision = 'e38be357a868'
+down_revision = 'da3f683c3a5a'
+branch_labels = None
+depends_on = None
+
+
+def mssql_timestamp(): # noqa: D103
+ return sa.DateTime()
+
+
+def mysql_timestamp(): # noqa: D103
+ return mysql.TIMESTAMP(fsp=6)
+
+
+def sa_timestamp(): # noqa: D103
+ return sa.TIMESTAMP(timezone=True)
+
+
+def upgrade(): # noqa: D103
+
+ conn = op.get_bind()
+ if conn.dialect.name == 'mysql':
+ timestamp = mysql_timestamp
+ elif conn.dialect.name == 'mssql':
+ timestamp = mssql_timestamp
+ else:
+ timestamp = sa_timestamp
+
+ op.create_table(
+ 'sensor_instance',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('task_id', sa.String(length=250), nullable=False),
+ sa.Column('dag_id', sa.String(length=250), nullable=False),
+ sa.Column('execution_date', timestamp(), nullable=False),
+ sa.Column('state', sa.String(length=20), nullable=True),
+ sa.Column('try_number', sa.Integer(), nullable=True),
+ sa.Column('start_date', timestamp(), nullable=True),
+ sa.Column('operator', sa.String(length=1000), nullable=False),
+ sa.Column('op_classpath', sa.String(length=1000), nullable=False),
+ sa.Column('hashcode', sa.BigInteger(), nullable=False),
+ sa.Column('shardcode', sa.Integer(), nullable=False),
+ sa.Column('poke_context', sa.Text(), nullable=False),
+ sa.Column('execution_context', sa.Text(), nullable=True),
+ sa.Column('created_at', timestamp(), default=func.now(), nullable=False),
+ sa.Column('updated_at', timestamp(), default=func.now(), nullable=False),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_index(
+ 'ti_primary_key',
+ 'sensor_instance',
+ ['dag_id', 'task_id', 'execution_date'],
+ unique=True
+ )
+ op.create_index('si_hashcode', 'sensor_instance', ['hashcode'], unique=False)
+ op.create_index('si_shardcode', 'sensor_instance', ['shardcode'], unique=False)
+ op.create_index('si_state_shard', 'sensor_instance', ['state', 'shardcode'], unique=False)
+ op.create_index('si_updated_at', 'sensor_instance', ['updated_at'], unique=False)
+
+
+def downgrade(): # noqa: D103
+ op.drop_table('sensor_instance')
diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py
index 94a4b2b..4a53845 100644
--- a/airflow/models/__init__.py
+++ b/airflow/models/__init__.py
@@ -27,6 +27,7 @@ from airflow.models.errors import ImportError # pylint: disable=redefined-built
from airflow.models.log import Log
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
+from airflow.models.sensorinstance import SensorInstance # noqa: F401
from airflow.models.skipmixin import SkipMixin
from airflow.models.slamiss import SlaMiss
from airflow.models.taskfail import TaskFail
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 352ebcd..27013f2 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1334,6 +1334,13 @@ class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
return cls.__serialized_fields
+ def is_smart_sensor_compatible(self):
+ """
+ Return if this operator can use smart service. Default False.
+
+ """
+ return False
+
def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
r"""
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index a9def70..4acc4c7 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -70,6 +70,9 @@ class DagBag(BaseDagBag, LoggingMixin):
:param include_examples: whether to include the examples that ship
with airflow or not
:type include_examples: bool
+ :param include_smart_sensor: whether to include the smart sensor native
+ DAGs that create the smart sensor operators for whole cluster
+ :type include_smart_sensor: bool
:param read_dags_from_db: Read DAGs from DB if store_serialized_dags is ``True``.
If ``False`` DAGs are read from python files. This property is not used when
determining whether or not to write Serialized DAGs, that is done by checking
@@ -84,6 +87,7 @@ class DagBag(BaseDagBag, LoggingMixin):
self,
dag_folder: Optional[str] = None,
include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
+ include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
read_dags_from_db: bool = False,
store_serialized_dags: Optional[bool] = None,
@@ -113,6 +117,7 @@ class DagBag(BaseDagBag, LoggingMixin):
self.collect_dags(
dag_folder=dag_folder,
include_examples=include_examples,
+ include_smart_sensor=include_smart_sensor,
safe_mode=safe_mode)
def size(self) -> int:
@@ -391,6 +396,7 @@ class DagBag(BaseDagBag, LoggingMixin):
dag_folder=None,
only_if_updated=True,
include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
+ include_smart_sensor=conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE')):
"""
Given a file path or a folder, this method looks for python modules,
@@ -414,8 +420,10 @@ class DagBag(BaseDagBag, LoggingMixin):
stats = []
dag_folder = correct_maybe_zipped(dag_folder)
- for filepath in list_py_file_paths(dag_folder, safe_mode=safe_mode,
- include_examples=include_examples):
+ for filepath in list_py_file_paths(dag_folder,
+ safe_mode=safe_mode,
+ include_examples=include_examples,
+ include_smart_sensor=include_smart_sensor):
try:
file_parse_start_dttm = timezone.utcnow()
found_dags = self.process_file(
diff --git a/airflow/models/sensorinstance.py b/airflow/models/sensorinstance.py
new file mode 100644
index 0000000..88aba4f
--- /dev/null
+++ b/airflow/models/sensorinstance.py
@@ -0,0 +1,166 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+
+from sqlalchemy import BigInteger, Column, Index, Integer, String, Text
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.models.base import ID_LEN, Base
+from airflow.utils import timezone
+from airflow.utils.session import provide_session
+from airflow.utils.sqlalchemy import UtcDateTime
+from airflow.utils.state import State
+
+
+class SensorInstance(Base):
+ """
+ SensorInstance support the smart sensor service. It stores the sensor task states
+ and context that required for poking include poke context and execution context.
+ In sensor_instance table we also save the sensor operator classpath so that inside
+ smart sensor there is no need to import the dagbag and create task object for each
+ sensor task.
+
+ SensorInstance include another set of columns to support the smart sensor shard on
+ large number of sensor instance. The key idea is to generate the hash code from the
+ poke context and use it to map to a shorter shard code which can be used as an index.
+ Every smart sensor process takes care of tasks whose `shardcode` are in a certain range.
+
+ """
+
+ __tablename__ = "sensor_instance"
+
+ id = Column(Integer, primary_key=True)
+ task_id = Column(String(ID_LEN), nullable=False)
+ dag_id = Column(String(ID_LEN), nullable=False)
+ execution_date = Column(UtcDateTime, nullable=False)
+ state = Column(String(20))
+ _try_number = Column('try_number', Integer, default=0)
+ start_date = Column(UtcDateTime)
+ operator = Column(String(1000), nullable=False)
+ op_classpath = Column(String(1000), nullable=False)
+ hashcode = Column(BigInteger, nullable=False)
+ shardcode = Column(Integer, nullable=False)
+ poke_context = Column(Text, nullable=False)
+ execution_context = Column(Text)
+ created_at = Column(UtcDateTime, default=timezone.utcnow(), nullable=False)
+ updated_at = Column(UtcDateTime,
+ default=timezone.utcnow(),
+ onupdate=timezone.utcnow(),
+ nullable=False)
+
+ __table_args__ = (
+ Index('ti_primary_key', dag_id, task_id, execution_date, unique=True),
+
+ Index('si_hashcode', hashcode),
+ Index('si_shardcode', shardcode),
+ Index('si_state_shard', state, shardcode),
+ Index('si_updated_at', updated_at),
+ )
+
+ def __init__(self, ti):
+ self.dag_id = ti.dag_id
+ self.task_id = ti.task_id
+ self.execution_date = ti.execution_date
+
+ @staticmethod
+ def get_classpath(obj):
+ """
+ Get the object dotted class path. Used for getting operator classpath.
+
+ :param obj:
+ :type obj:
+ :return: The class path of input object
+ :rtype: str
+ """
+ module_name, class_name = obj.__module__, obj.__class__.__name__
+
+ return module_name + "." + class_name
+
+ @classmethod
+ @provide_session
+ def register(cls, ti, poke_context, execution_context, session=None):
+ """
+ Register task instance ti for a sensor in sensor_instance table. Persist the
+ context used for a sensor and set the sensor_instance table state to sensing.
+
+ :param ti: The task instance for the sensor to be registered.
+ :type: ti:
+ :param poke_context: Context used for sensor poke function.
+ :type poke_context: dict
+ :param execution_context: Context used for execute sensor such as timeout
+ setting and email configuration.
+ :type execution_context: dict
+ :param session: SQLAlchemy ORM Session
+ :type session: Session
+ :return: True if the ti was registered successfully.
+ :rtype: Boolean
+ """
+ if poke_context is None:
+ raise AirflowException('poke_context should not be None')
+
+ encoded_poke = json.dumps(poke_context)
+ encoded_execution_context = json.dumps(execution_context)
+
+ sensor = session.query(SensorInstance).filter(
+ SensorInstance.dag_id == ti.dag_id,
+ SensorInstance.task_id == ti.task_id,
+ SensorInstance.execution_date == ti.execution_date
+ ).with_for_update().first()
+
+ if sensor is None:
+ sensor = SensorInstance(ti=ti)
+
+ sensor.operator = ti.operator
+ sensor.op_classpath = SensorInstance.get_classpath(ti.task)
+ sensor.poke_context = encoded_poke
+ sensor.execution_context = encoded_execution_context
+
+ sensor.hashcode = hash(encoded_poke)
+ sensor.shardcode = sensor.hashcode % conf.getint('smart_sensor', 'shard_code_upper_limit')
+ sensor.try_number = ti.try_number
+
+ sensor.state = State.SENSING
+ sensor.start_date = timezone.utcnow()
+ session.add(sensor)
+ session.commit()
+
+ return True
+
+ @property
+ def try_number(self):
+ """
+ Return the try number that this task number will be when it is actually
+ run.
+ If the TI is currently running, this will match the column in the
+ database, in all other cases this will be incremented.
+ """
+ # This is designed so that task logs end up in the right file.
+ if self.state in State.running():
+ return self._try_number
+ return self._try_number + 1
+
+ @try_number.setter
+ def try_number(self, value):
+ self._try_number = value
+
+ def __repr__(self):
+ return "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} " \
+ "execution_context: {self.execution_context} state: {self.state}>".format(
+ self=self)
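To illustrate the dedup idea described in the SensorInstance docstring above, a small self-contained sketch (not part of this patch; the DAG/task names and contexts are made up): sensor tasks whose encoded poke context is identical share a hashcode, so the smart sensor pokes once per signature and marks all matching tasks together:

    import json

    tasks = [
        ("dag_a", "wait_partition", {"partition_names": ["db.t/ds=2020-09-08"], "metastore_conn_id": "metastore_default"}),
        ("dag_b", "wait_partition", {"partition_names": ["db.t/ds=2020-09-08"], "metastore_conn_id": "metastore_default"}),
        ("dag_c", "wait_other", {"partition_names": ["db.other/ds=2020-09-08"], "metastore_conn_id": "metastore_default"}),
    ]

    by_signature = {}
    for dag_id, task_id, poke_context in tasks:
        hashcode = hash(json.dumps(poke_context))
        by_signature.setdefault(hashcode, []).append((dag_id, task_id))

    # dag_a and dag_b collapse into one poking job; dag_c stays separate.
    for hashcode, members in by_signature.items():
        print(hashcode, members)
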
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index cb5a1bc..b3220ae 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -29,6 +29,7 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from urllib.parse import quote
import dill
+import jinja2
import lazy_object_proxy
import pendulum
from jinja2 import TemplateAssertionError, UndefinedError
@@ -41,7 +42,7 @@ from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException, AirflowFailException, AirflowRescheduleException,
AirflowSkipException,
- AirflowTaskTimeout,
+ AirflowSmartSensorException, AirflowTaskTimeout,
)
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.models.log import Log
@@ -281,7 +282,8 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
database, in all other cases this will be incremented.
"""
# This is designed so that task logs end up in the right file.
- if self.state == State.RUNNING:
+ # TODO: whether we need sensing here or not (in sensor and task_instance state machine)
+ if self.state in State.running():
return self._try_number
return self._try_number + 1
@@ -1072,6 +1074,9 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
self._prepare_and_execute_task_with_callbacks(context, task)
self.refresh_from_db(lock_for_update=True)
self.state = State.SUCCESS
+ except AirflowSmartSensorException as e:
+ self.log.info(e)
+ return
except AirflowSkipException as e:
# Recording SKIP
# log only if exception has any arguments to prevent log flooding
@@ -1172,6 +1177,20 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
# Run on_execute callback
self._run_execute_callback(context, task)
+ if task_copy.is_smart_sensor_compatible():
+ # Try to register it in the smart sensor service.
+ registered = False
+ try:
+ registered = task_copy.register_in_sensor_service(self, context)
+ except Exception as e:
+ self.log.warning("Failed to register in sensor service."
+ "Continue to run task in non smart sensor mode.")
+ self.log.exception(e, exc_info=True)
+
+ if registered:
+ # Will raise AirflowSmartSensorException to avoid long running execution.
+ self._update_ti_state_for_sensing()
+
# Execute the task
with set_current_context(context):
result = self._execute_task(context, task_copy)
@@ -1187,6 +1206,16 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
Stats.incr('operator_successes_{}'.format(self.task.__class__.__name__), 1, 1)
Stats.incr('ti_successes')
+ @provide_session
+ def _update_ti_state_for_sensing(self, session=None):
+ self.log.info('Submitting %s to sensor service', self)
+ self.state = State.SENSING
+ self.start_date = timezone.utcnow()
+ session.merge(self)
+ session.commit()
+ # Raise exception for sensing state
+ raise AirflowSmartSensorException("Task successfully registered in smart sensor.")
+
def _run_success_callback(self, context, task):
"""Functions that need to be run if Task is successful"""
# Success callback
@@ -1580,19 +1609,12 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
self.task.render_template_fields(context)
- def email_alert(self, exception):
- """Send Email Alert with exception trace"""
+ def get_email_subject_content(self, exception):
+ """Get the email subject content for exceptions."""
+ # For a ti from DB (without ti.task), return the default value
+ # Reuse it for smart sensor to send default email alert
+ use_default = not hasattr(self, 'task')
exception_html = str(exception).replace('\n', '<br>')
- jinja_context = self.get_template_context()
- # This function is called after changing the state
- # from State.RUNNING so use prev_attempted_tries.
- jinja_context.update(dict(
- exception=exception,
- exception_html=exception_html,
- try_number=self.prev_attempted_tries,
- max_tries=self.max_tries))
-
- jinja_env = self.task.get_template_env()
default_subject = 'Airflow alert: {{ti}}'
# For reporting purposes, we report based on 1-indexed,
@@ -1607,29 +1629,62 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)
- def render(key, content):
- if conf.has_option('email', key):
- path = conf.get('email', key)
- with open(path) as file:
- content = file.read()
+ default_html_content_err = (
+ 'Try {{try_number}} out of {{max_tries + 1}}<br>'
+ 'Exception:<br>Failed attempt to attach error logs<br>'
+ 'Log: <a href="{{ti.log_url}}">Link</a><br>'
+ 'Host: {{ti.hostname}}<br>'
+ 'Log file: {{ti.log_filepath}}<br>'
+ 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
+ )
+
+ if use_default:
+ jinja_context = {'ti': self}
+ # This function is called after changing the state
+ # from State.RUNNING so need to subtract 1 from self.try_number.
+ jinja_context.update(dict(
+ exception=exception,
+ exception_html=exception_html,
+ try_number=self.try_number - 1,
+ max_tries=self.max_tries))
+
+ jinja_env = jinja2.Environment(
+ loader=jinja2.FileSystemLoader(os.path.dirname(__file__)),
+ autoescape=True)
+ subject = jinja_env.from_string(default_subject).render(**jinja_context)
+ html_content = jinja_env.from_string(default_html_content).render(**jinja_context)
+ html_content_err = jinja_env.from_string(default_html_content_err).render(**jinja_context)
+
+ else:
+ jinja_context = self.get_template_context()
+
+ jinja_context.update(dict(
+ exception=exception,
+ exception_html=exception_html,
+ try_number=self.try_number - 1,
+ max_tries=self.max_tries))
- return jinja_env.from_string(content).render(**jinja_context)
+ jinja_env = self.task.get_template_env()
- subject = render('subject_template', default_subject)
- html_content = render('html_content_template', default_html_content)
- # noinspection PyBroadException
+ def render(key, content):
+ if conf.has_option('email', key):
+ path = conf.get('email', key)
+ with open(path) as f:
+ content = f.read()
+ return jinja_env.from_string(content).render(**jinja_context)
+
+ subject = render('subject_template', default_subject)
+ html_content = render('html_content_template', default_html_content)
+ html_content_err = render('html_content_template', default_html_content_err)
+
+ return subject, html_content, html_content_err
+
+ def email_alert(self, exception):
+ """Send alert email with exception information."""
+ subject, html_content, html_content_err = self.get_email_subject_content(exception)
try:
send_email(self.task.email, subject, html_content)
- except Exception: # pylint: disable=broad-except
- default_html_content_err = (
- 'Try {{try_number}} out of {{max_tries + 1}}<br>'
- 'Exception:<br>Failed attempt to attach error logs<br>'
- 'Log: <a href="{{ti.log_url}}">Link</a><br>'
- 'Host: {{ti.hostname}}<br>'
- 'Log file: {{ti.log_filepath}}<br>'
- 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
- )
- html_content_err = render('html_content_template', default_html_content_err)
+ except Exception:
send_email(self.task.email, subject, html_content_err)
def set_duration(self) -> None:
diff --git a/airflow/providers/apache/hive/sensors/metastore_partition.py b/airflow/providers/apache/hive/sensors/metastore_partition.py
index 31376ad..0955291 100644
--- a/airflow/providers/apache/hive/sensors/metastore_partition.py
+++ b/airflow/providers/apache/hive/sensors/metastore_partition.py
@@ -44,6 +44,7 @@ class MetastorePartitionSensor(SqlSensor):
template_fields = ('partition_name', 'table', 'schema')
ui_color = '#8da7be'
+ poke_context_fields = ('partition_name', 'table', 'schema', 'mysql_conn_id')
@apply_defaults
def __init__(
diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py
index 23d9466..b8cb600 100644
--- a/airflow/providers/apache/hive/sensors/named_hive_partition.py
+++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py
@@ -40,6 +40,7 @@ class NamedHivePartitionSensor(BaseSensorOperator):
template_fields = ('partition_names',)
ui_color = '#8d99ae'
+ poke_context_fields = ('partition_names', 'metastore_conn_id')
@apply_defaults
def __init__(
@@ -104,3 +105,12 @@ class NamedHivePartitionSensor(BaseSensorOperator):
self.next_index_to_poke = 0
return True
+
+ def is_smart_sensor_compatible(self):
+ result = (
+ not self.soft_fail
+ and not self.hook
+ and len(self.partition_names) <= 30
+ and super().is_smart_sensor_compatible()
+ )
+ return result
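As a usage sketch (hypothetical DAG, assuming Airflow with the Hive provider installed; not part of this patch), a NamedHivePartitionSensor like the one below is the kind of task that gets redirected to the smart sensor when use_smart_sensor is enabled and the class is listed in sensors_enabled:

    from airflow import DAG
    from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor
    from airflow.utils.dates import days_ago

    with DAG(dag_id="example_smart_sensor_candidate", start_date=days_ago(1),
             schedule_interval="@daily") as dag:
        # Compatible per is_smart_sensor_compatible() above: no soft_fail, no hook
        # created yet, and no more than 30 partition names.
        wait_for_partition = NamedHivePartitionSensor(
            task_id="wait_for_partition",
            partition_names=["default.my_table/ds={{ ds }}"],
            metastore_conn_id="metastore_default",
        )
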
diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index 93f9a1b..d3d07c1 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -18,9 +18,10 @@
import logging
import sys
+from collections import defaultdict
from datetime import datetime
from time import time
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
from urllib.parse import quote
# Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
@@ -36,6 +37,9 @@ from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.json_formatter import JSONFormatter
from airflow.utils.log.logging_mixin import LoggingMixin
+# Elasticsearch hosted log type
+EsLogMsgType = List[Tuple[str, str]]
+
class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin):
"""
@@ -118,7 +122,24 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin):
"""
return execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f")
- def _read(self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None) -> Tuple[str, dict]:
+ @staticmethod
+ def _group_logs_by_host(logs):
+ grouped_logs = defaultdict(list)
+ for log in logs:
+ key = getattr(log, 'host', 'default_host')
+ grouped_logs[key].append(log)
+
+ # return items sorted by timestamp.
+ result = sorted(grouped_logs.items(), key=lambda kv: getattr(kv[1][0], 'message', '_'))
+
+ return result
+
+ def _read_grouped_logs(self):
+ return True
+
+ def _read(
+ self, ti: TaskInstance, try_number: int, metadata: Optional[dict] = None
+ ) -> Tuple[EsLogMsgType, dict]:
"""
Endpoint for streaming log.
@@ -126,7 +147,7 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin):
:param try_number: try_number of the task instance
:param metadata: log metadata,
can be used for steaming log reading and auto-tailing.
- :return: a list of log documents and metadata.
+ :return: a list of tuple with host and log documents, metadata.
"""
if not metadata:
metadata = {'offset': 0}
@@ -137,6 +158,7 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin):
log_id = self._render_log_id(ti, try_number)
logs = self.es_read(log_id, offset, metadata)
+ logs_by_host = self._group_logs_by_host(logs)
next_offset = offset if not logs else logs[-1].offset
@@ -147,7 +169,10 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin):
# end_of_log_mark may contain characters like '\n' which is needed to
# have the log uploaded but will not be stored in elasticsearch.
- metadata['end_of_log'] = False if not logs else logs[-1].message == self.end_of_log_mark.strip()
+ loading_hosts = [
+ item[0] for item in logs_by_host if item[-1][-1].message != self.end_of_log_mark.strip()
+ ]
+ metadata['end_of_log'] = False if not logs else len(loading_hosts) == 0
cur_ts = pendulum.now()
# Assume end of log after not receiving new log for 5 min,
@@ -167,8 +192,11 @@ class ElasticsearchTaskHandler(FileTaskHandler, LoggingMixin):
# If we hit the end of the log, remove the actual end_of_log message
# to prevent it from showing in the UI.
- i = len(logs) if not metadata['end_of_log'] else len(logs) - 1
- message = '\n'.join([log.message for log in logs[0:i]])
+ def concat_logs(lines):
+ log_range = (len(lines) - 1) if lines[-1].message == self.end_of_log_mark.strip() else len(lines)
+ return '\n'.join([lines[i].message for i in range(log_range)])
+
+ message = [(host, concat_logs(hosted_log)) for host, hosted_log in logs_by_host]
return message, metadata
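A minimal, self-contained sketch of the per-host grouping the handler change above introduces (illustrative only; the Doc class is a hypothetical stand-in for elasticsearch hit objects):

    from collections import defaultdict

    class Doc:  # hypothetical stand-in for an ES log document
        def __init__(self, host, message):
            self.host = host
            self.message = message

    logs = [Doc("worker-1", "line 1"), Doc("worker-2", "other host"), Doc("worker-1", "line 2")]

    grouped_logs = defaultdict(list)
    for log in logs:
        grouped_logs[getattr(log, "host", "default_host")].append(log)

    # _read() now returns one (host, joined log text) tuple per host instead of a flat string.
    message = [(host, "\n".join(d.message for d in docs)) for host, docs in grouped_logs.items()]
    print(message)
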
diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py
index c4bfd43..528c04c 100644
--- a/airflow/sensors/base_sensor_operator.py
+++ b/airflow/sensors/base_sensor_operator.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
+import datetime
import hashlib
import os
from datetime import timedelta
@@ -26,7 +27,7 @@ from airflow.configuration import conf
from airflow.exceptions import (
AirflowException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException,
)
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator, SensorInstance
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
@@ -68,6 +69,15 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
ui_color = '#e6f1f2' # type: str
valid_modes = ['poke', 'reschedule'] # type: Iterable[str]
+ # As the poke context in smart sensor defines the poking job signature only,
+ # The execution_fields defines other execution details
+ # for this tasks such as the customer defined timeout, the email and the alert
+ # setup. Smart sensor serialize these attributes into a different DB column so
+ # that smart sensor service is able to handle corresponding execution details
+ # without breaking the sensor poking logic with dedup.
+ execution_fields = ('poke_interval', 'retries', 'execution_timeout', 'timeout',
+ 'email', 'email_on_retry', 'email_on_failure',)
+
@apply_defaults
def __init__(self, *,
poke_interval: float = 60,
@@ -83,6 +93,9 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
self.mode = mode
self.exponential_backoff = exponential_backoff
self._validate_input_values()
+ self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor')
+ self.sensors_support_sensor_service = set(
+ map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(',')))
def _validate_input_values(self) -> None:
if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0:
@@ -106,6 +119,65 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
"""
raise AirflowException('Override me.')
+ def is_smart_sensor_compatible(self):
+ check_list = [not self.sensor_service_enabled,
+ self.on_success_callback,
+ self.on_retry_callback,
+ self.on_failure_callback]
+ for status in check_list:
+ if status:
+ return False
+
+ operator = self.__class__.__name__
+ return operator in self.sensors_support_sensor_service
+
+ def register_in_sensor_service(self, ti, context):
+ """
+ Register ti in smart sensor service
+
+ :param ti: Task instance object.
+ :param context: TaskInstance template context from the ti.
+ :return: boolean
+ """
+ poke_context = self.get_poke_context(context)
+ execution_context = self.get_execution_context(context)
+
+ return SensorInstance.register(ti, poke_context, execution_context)
+
+ def get_poke_context(self, context):
+ """
+ Return a dictionary with all attributes in poke_context_fields. The
+ poke_context with operator class can be used to identify a unique
+ sensor job.
+
+ :param context: TaskInstance template context.
+ :return: A dictionary with key in poke_context_fields.
+ """
+ if not context:
+ self.log.info("Function get_poke_context doesn't have a context input.")
+
+ poke_context_fields = getattr(self.__class__, "poke_context_fields", None)
+ result = {key: getattr(self, key, None) for key in poke_context_fields}
+ return result
+
+ def get_execution_context(self, context):
+ """
+ Return a dictionary with all attributes in execution_fields. The
+ execution_context include execution requirement for each sensor task
+ such as timeout setup, email_alert setup.
+
+ :param context: TaskInstance template context.
+ :return: A dictionary with key in execution_fields.
+ """
+ if not context:
+ self.log.info("Function get_execution_context doesn't have a context input.")
+ execution_fields = self.__class__.execution_fields
+
+ result = {key: getattr(self, key, None) for key in execution_fields}
+ if result['execution_timeout'] and isinstance(result['execution_timeout'], datetime.timedelta):
+ result['execution_timeout'] = result['execution_timeout'].total_seconds()
+ return result
+
def execute(self, context: Dict) -> Any:
started_at = timezone.utcnow()
try_number = 1
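To make the two context dictionaries above concrete, a self-contained sketch (illustrative; FakeSensor and its attribute values are made up) of what get_poke_context and get_execution_context extract for a NamedHivePartitionSensor-like task:

    import datetime

    class FakeSensor:  # hypothetical stand-in for a smart-sensor-compatible sensor
        poke_context_fields = ('partition_names', 'metastore_conn_id')
        execution_fields = ('poke_interval', 'retries', 'execution_timeout', 'timeout',
                            'email', 'email_on_retry', 'email_on_failure')

        def __init__(self):
            self.partition_names = ['default.my_table/ds=2020-09-08']
            self.metastore_conn_id = 'metastore_default'
            self.poke_interval = 60
            self.retries = 1
            self.execution_timeout = datetime.timedelta(hours=1)
            self.timeout = 60 * 60 * 24 * 7
            self.email = None
            self.email_on_retry = False
            self.email_on_failure = False

    s = FakeSensor()
    poke_context = {k: getattr(s, k, None) for k in s.poke_context_fields}
    execution_context = {k: getattr(s, k, None) for k in s.execution_fields}
    if isinstance(execution_context['execution_timeout'], datetime.timedelta):
        execution_context['execution_timeout'] = execution_context['execution_timeout'].total_seconds()
    print(poke_context)       # identifies the poking job (the dedup signature)
    print(execution_context)  # per-task timeout / email settings
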
diff --git a/airflow/sensors/smart_sensor_operator.py b/airflow/sensors/smart_sensor_operator.py
new file mode 100644
index 0000000..d6b6df2
--- /dev/null
+++ b/airflow/sensors/smart_sensor_operator.py
@@ -0,0 +1,764 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import datetime
+import json
+import logging
+import time
+import traceback
+from logging.config import DictConfigurator # type: ignore
+from time import sleep
+
+from sqlalchemy import and_, or_, tuple_
+
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
+from airflow.models import BaseOperator, SensorInstance, SkipMixin, TaskInstance
+from airflow.settings import LOGGING_CLASS_PATH
+from airflow.stats import Stats
+from airflow.utils import helpers, timezone
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.email import send_email
+from airflow.utils.log.logging_mixin import set_context
+from airflow.utils.module_loading import import_string
+from airflow.utils.net import get_hostname
+from airflow.utils.session import provide_session
+from airflow.utils.state import PokeState, State
+from airflow.utils.timeout import timeout
+
+config = import_string(LOGGING_CLASS_PATH)
+handler_config = config['handlers']['task']
+try:
+ formatter_config = config['formatters'][handler_config['formatter']]
+except Exception as err: # pylint: disable=broad-except
+ formatter_config = None
+ print(err)
+dictConfigurator = DictConfigurator(config)
+
+
+class SensorWork:
+ """
+ This class stores a sensor work with decoded context value. It is only used
+ inside of smart sensor. Create a sensor work based on sensor instance record.
+ A sensor work object has the following attributes:
+ `dag_id`: sensor_instance dag_id.
+ `task_id`: sensor_instance task_id.
+ `execution_date`: sensor_instance execution_date.
+ `try_number`: sensor_instance try_number
+ `poke_context`: Decoded poke_context for the sensor task.
+ `execution_context`: Decoded execution_context.
+ `hashcode`: This is the signature of poking job.
+ `operator`: The sensor operator class.
+ `op_classpath`: The sensor operator class path
+ `encoded_poke_context`: The raw data from sensor_instance poke_context column.
+ `log`: The sensor work logger which will mock the corresponding task instance log.
+
+ :param si: The sensor_instance ORM object.
+ """
+ def __init__(self, si):
+ self.dag_id = si.dag_id
+ self.task_id = si.task_id
+ self.execution_date = si.execution_date
+ self.try_number = si.try_number
+
+ self.poke_context = json.loads(si.poke_context) if si.poke_context else {}
+ self.execution_context = json.loads(si.execution_context) if si.execution_context else {}
+ try:
+ self.log = self._get_sensor_logger(si)
+ except Exception as e: # pylint: disable=broad-except
+ self.log = None
+ print(e)
+ self.hashcode = si.hashcode
+ self.start_date = si.start_date
+ self.operator = si.operator
+ self.op_classpath = si.op_classpath
+ self.encoded_poke_context = si.poke_context
+
+ def __eq__(self, other):
+ if not isinstance(other, SensorWork):
+ return NotImplemented
+
+ return self.dag_id == other.dag_id and \
+ self.task_id == other.task_id and \
+ self.execution_date == other.execution_date and \
+ self.try_number == other.try_number
+
+ @staticmethod
+ def create_new_task_handler():
+ """
+ Create task log handler for a sensor work.
+ :return: log handler
+ """
+ handler_config_copy = {k: handler_config[k] for k in handler_config}
+ formatter_config_copy = {k: formatter_config[k] for k in formatter_config}
+ handler = dictConfigurator.configure_handler(handler_config_copy)
+ formatter = dictConfigurator.configure_formatter(formatter_config_copy)
+ handler.setFormatter(formatter)
+ return handler
+
+ def _get_sensor_logger(self, si):
+ """
+ Return logger for a sensor instance object.
+ """
+ # The created log_id is used inside of smart sensor as the key to fetch
+ # the corresponding in memory log handler.
+ si.raw = False # Otherwise set_context will fail
+ log_id = "-".join([si.dag_id,
+ si.task_id,
+ si.execution_date.strftime("%Y_%m_%dT%H_%M_%S_%f"),
+ str(si.try_number)])
+ logger = logging.getLogger('airflow.task' + '.' + log_id)
+
+ if len(logger.handlers) == 0:
+ handler = self.create_new_task_handler()
+ logger.addHandler(handler)
+ set_context(logger, si)
+
+ line_break = ("-" * 120)
+ logger.info(line_break)
+ logger.info("Processing sensor task %s in smart sensor service on host: %s",
+ self.ti_key, get_hostname())
+ logger.info(line_break)
+ return logger
+
+ def close_sensor_logger(self):
+ """
+ Close log handler for a sensor work.
+ """
+ for handler in self.log.handlers:
+ try:
+ handler.close()
+ except Exception as e: # pylint: disable=broad-except
+ print(e)
+
+ @property
+ def ti_key(self):
+ """
+ Key for the task instance that maps to the sensor work.
+ """
+ return self.dag_id, self.task_id, self.execution_date
+
+ @property
+ def cache_key(self):
+ """
+ Key used to query in smart sensor for cached sensor work.
+ """
+ return self.operator, self.encoded_poke_context
+
+
+class CachedPokeWork:
+ """
+ Wrapper class for the poke work inside smart sensor. It saves
+ the sensor_task used to poke and recent poke result state.
+ state: poke state.
+ sensor_task: The cached object for executing the poke function.
+ last_poke_time: The latest time this cached work being called.
+ to_flush: If we should flush the cached work.
+ """
+ def __init__(self):
+ self.state = None
+ self.sensor_task = None
+ self.last_poke_time = None
+ self.to_flush = False
+
+ def set_state(self, state):
+ """
+ Set state for cached poke work.
+ :param state: The sensor_instance state.
+ """
+ self.state = state
+ self.last_poke_time = timezone.utcnow()
+
+ def clear_state(self):
+ """
+ Clear state for cached poke work.
+ """
+ self.state = None
+
+ def set_to_flush(self):
+ """
+ Mark this poke work to be popped from cached dict after current loop.
+ """
+ self.to_flush = True
+
+ def is_expired(self):
+ """
+ The cached task object expires if there is no poke for 20 minutes.
+ :return: Boolean
+ """
+ return self.to_flush or (timezone.utcnow() - self.last_poke_time).total_seconds() > 1200
+
+
+class SensorExceptionInfo:
+ """
+ Hold sensor exception information and the type of exception. For possible transient
+ infra failure, give the task more chance to retry before fail it.
+ """
+ def __init__(self,
+ exception_info,
+ is_infra_failure=False,
+ infra_failure_retry_window=datetime.timedelta(minutes=130)):
+ self._exception_info = exception_info
+ self._is_infra_failure = is_infra_failure
+ self._infra_failure_retry_window = infra_failure_retry_window
+
+ self._infra_failure_timeout = None
+ self.set_infra_failure_timeout()
+ self.fail_current_run = self.should_fail_current_run()
+
+ def set_latest_exception(self, exception_info, is_infra_failure=False):
+ """
+ This function set the latest exception information for sensor exception. If the exception
+ implies an infra failure, this function will check the recorded infra failure timeout
+ which was set at the first infra failure exception arrives. There is a 6 hours window
+ for retry without failing current run.
+
+ :param exception_info: Details of the exception information.
+ :param is_infra_failure: If current exception was caused by transient infra failure.
+ There is a retry window _infra_failure_retry_window that the smart sensor will
+ retry poke function without failing current task run.
+ """
+ self._exception_info = exception_info
+ self._is_infra_failure = is_infra_failure
+
+ self.set_infra_failure_timeout()
+ self.fail_current_run = self.should_fail_current_run()
+
+ def set_infra_failure_timeout(self):
+ """
+ Set the time point when the sensor should be failed if it kept getting infra
+ failure.
+ :return:
+ """
+ # Only set the infra_failure_timeout if there is no existing one
+ if not self._is_infra_failure:
+ self._infra_failure_timeout = None
+ elif self._infra_failure_timeout is None:
+ self._infra_failure_timeout = timezone.utcnow() + self._infra_failure_retry_window
+
+ def should_fail_current_run(self):
+ """
+ :return: Should the sensor fail
+ :type: boolean
+ """
+ return not self.is_infra_failure or timezone.utcnow() > self._infra_failure_timeout
+
+ @property
+ def exception_info(self):
+ """
+ :return: exception msg.
+ """
+ return self._exception_info
+
+ @property
+ def is_infra_failure(self):
+ """
+
+ :return: If the exception is an infra failure
+ :type: boolean
+ """
+ return self._is_infra_failure
+
+ def is_expired(self):
+ """
+ :return: If current exception need to be kept.
+ :type: boolean
+ """
+ if not self._is_infra_failure:
+ return True
+ return timezone.utcnow() > self._infra_failure_timeout + datetime.timedelta(minutes=30)
+
+
+class SmartSensorOperator(BaseOperator, SkipMixin):
+ """
+ Smart sensor operators are derived from this class.
+
+ Smart Sensor operators keep refresh a dictionary by visiting DB.
+ Taking qualified active sensor tasks. Different from sensor operator,
+ Smart sensor operators poke for all sensor tasks in the dictionary at
+ a time interval. When a criteria is met or fail by time out, it update
+ all sensor task state in task_instance table
+
+ :param soft_fail: Set to true to mark the task as SKIPPED on failure
+ :type soft_fail: bool
+ :param poke_interval: Time in seconds that the job should wait in
+ between each tries.
+ :type poke_interval: int
+ :param smart_sensor_timeout: Time, in seconds before the internal sensor
+ job times out if poke_timeout is not defined.
+ :type smart_sensor_timeout: int
+ :param shard_min: shard code lower bound (inclusive)
+ :type shard_min: int
+ :param shard_max: shard code upper bound (exclusive)
+ :type shard_max: int
+ :param poke_exception_cache_ttl: Time, in seconds before the current
+ exception expires and being cleaned up.
+ :type poke_exception_cache_ttl: int
+ :param poke_timeout: Time, in seconds before the task times out and fails.
+ :type poke_timeout: int
+ """
+ ui_color = '#e6f1f2'
+
+ @apply_defaults
+ def __init__(self,
+ poke_interval=180,
+ smart_sensor_timeout=60 * 60 * 24 * 7,
+ soft_fail=False,
+ shard_min=0,
+ shard_max=100000,
+ poke_exception_cache_ttl=600,
+ poke_timeout=6,
+ *args,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ # super(SmartSensorOperator, self).__init__(*args, **kwargs)
+ self.poke_interval = poke_interval
+ self.soft_fail = soft_fail
+ self.timeout = smart_sensor_timeout
+ self._validate_input_values()
+ self.hostname = ""
+
+ self.sensor_works = []
+ self.cached_dedup_works = {}
+ self.cached_sensor_exceptions = {}
+
+ self.max_tis_per_query = 50
+ self.shard_min = shard_min
+ self.shard_max = shard_max
+ self.poke_exception_cache_ttl = poke_exception_cache_ttl
+ self.poke_timeout = poke_timeout
+
+ def _validate_input_values(self):
+ if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0:
+ raise AirflowException(
+ "The poke_interval must be a non-negative number")
+ if not isinstance(self.timeout, (int, float)) or self.timeout < 0:
+ raise AirflowException(
+ "The timeout must be a non-negative number")
+
+ @provide_session
+ def _load_sensor_works(self, session=None):
+ """
+ Refresh sensor instances need to be handled by this operator. Create smart sensor
+ internal object based on the information persisted in the sensor_instance table.
+
+ """
+ SI = SensorInstance
+ start_query_time = time.time()
+ query = session.query(SI) \
+ .filter(SI.state == State.SENSING)\
+ .filter(SI.shardcode < self.shard_max,
+ SI.shardcode >= self.shard_min)
+ tis = query.all()
+
+ self.log.info("Performance query %s tis, time: %s", len(tis), time.time() - start_query_time)
+
+ # Query without checking dagrun state might keep some failed dag_run tasks alive.
+ # Join with DagRun table will be very slow based on the number of sensor tasks we
+ # need to handle. We query all smart tasks in this operator
+ # and expect scheduler correct the states in _change_state_for_tis_without_dagrun()
+
+ sensor_works = []
+ for ti in tis:
+ try:
+ sensor_works.append(SensorWork(ti))
+ except Exception as e: # pylint: disable=broad-except
+ self.log.exception("Exception at creating sensor work for ti %s", ti.key)
+ self.log.exception(e, exc_info=True)
+
+ self.log.info("%d tasks detected.", len(sensor_works))
+
+ new_sensor_works = [x for x in sensor_works if x not in self.sensor_works]
+
+ self._update_ti_hostname(new_sensor_works)
+
+ self.sensor_works = sensor_works
+
+ @provide_session
+ def _update_ti_hostname(self, sensor_works, session=None):
+ """
+ Update task instance hostname for new sensor works.
+
+ :param sensor_works: Smart sensor internal object for a sensor task.
+ :param session: The sqlalchemy session.
+ """
+ TI = TaskInstance
+ ti_keys = [(x.dag_id, x.task_id, x.execution_date) for x in sensor_works]
+
+ def update_ti_hostname_with_count(count, ti_keys):
+ # Using or_ instead of in_ here to prevent from full table scan.
+ tis = session.query(TI) \
+ .filter(or_(tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key
+ for ti_key in ti_keys)) \
+ .all()
+
+ for ti in tis:
+ ti.hostname = self.hostname
+ session.commit()
+
+ return count + len(ti_keys)
+
+ count = helpers.reduce_in_chunks(update_ti_hostname_with_count, ti_keys, 0, self.max_tis_per_query)
+ if count:
+ self.log.info("Updated hostname on %s tis.", count)
+
+ @provide_session
+ def _mark_multi_state(self, operator, poke_hash, encoded_poke_context, state, session=None):
+ """
+ Mark state for multiple tasks in the task_instance table to a new state if they have
+ the same signature as the poke_hash.
+
+ :param operator: The sensor's operator class name.
+ :param poke_hash: The hash code generated from sensor's poke context.
+ :param encoded_poke_context: The raw encoded poke_context.
+ :param state: Set multiple sensor tasks to this state.
+ :param session: The sqlalchemy session.
+ """
+ def mark_state(ti, sensor_instance):
+ ti.state = state
+ sensor_instance.state = state
+ if state in State.finished():
+ ti.end_date = end_date
+ ti.set_duration()
+
+ SI = SensorInstance
+ TI = TaskInstance
+
+ count_marked = 0
+ try:
+ query_result = session.query(TI, SI)\
+ .join(TI, and_(TI.dag_id == SI.dag_id,
+ TI.task_id == SI.task_id,
+ TI.execution_date == SI.execution_date)) \
+ .filter(SI.state == State.SENSING) \
+ .filter(SI.hashcode == poke_hash) \
+ .filter(SI.operator == operator) \
+ .with_for_update().all()
+
+ end_date = timezone.utcnow()
+ for ti, sensor_instance in query_result:
+ if sensor_instance.poke_context != encoded_poke_context:
+ continue
+
+ ti.hostname = self.hostname
+ if ti.state == State.SENSING:
+ mark_state(ti=ti, sensor_instance=sensor_instance)
+ count_marked += 1
+ else:
+ # ti.state != State.SENSING
+ sensor_instance.state = ti.state
+
+ session.commit()
+
+ except Exception as e: # pylint: disable=broad-except
+ self.log.warning("Exception _mark_multi_state in smart sensor for hashcode %s",
+ str(poke_hash))
+ self.log.exception(e, exc_info=True)
+ self.log.info("Marked %s tasks out of %s to state %s", count_marked, len(query_result), state)
+
+ @provide_session
+ def _retry_or_fail_task(self, sensor_work, error, session=None):
+ """
+ Change single task state for sensor task. For final state, set the end_date.
+ Since smart sensor take care all retries in one process. Failed sensor tasks
+ logically experienced all retries and the try_number should be set to max_tries.
+
+ :param sensor_work: The sensor_work with exception.
+ :type sensor_work: SensorWork
+ :param error: The error message for this sensor_work.
+ :type error: str.
+ :param session: The sqlalchemy session.
+ """
+ def email_alert(task_instance, error_info):
+ try:
+ subject, html_content, _ = task_instance.get_email_subject_content(error_info)
+ email = sensor_work.execution_context.get('email')
+
+ send_email(email, subject, html_content)
+ except Exception as e: # pylint: disable=broad-except
+ sensor_work.log.warning("Exception alerting email.")
+ sensor_work.log.exception(e, exc_info=True)
+
+ def handle_failure(sensor_work, ti):
+ if sensor_work.execution_context.get('retries', None) and \
+ ti.try_number <= ti.max_tries:
+ # retry
+ ti.state = State.UP_FOR_RETRY
+ if sensor_work.execution_context.get('email_on_retry', None) and \
+ sensor_work.execution_context.get('email', None):
+ sensor_work.log.info("%s sending email alert for retry", sensor_work.ti_key)
+ email_alert(ti, error)
+ else:
+ ti.state = State.FAILED
+ if sensor_work.execution_context.get('email_on_failure', None) and \
+ sensor_work.execution_context.get('email', None):
+ sensor_work.log.info("%s sending email alert for failure", sensor_work.ti_key)
+ email_alert(ti, error)
+
+ try:
+ dag_id, task_id, execution_date = sensor_work.ti_key
+ TI = TaskInstance
+ SI = SensorInstance
+ sensor_instance = session.query(SI).filter(
+ SI.dag_id == dag_id,
+ SI.task_id == task_id,
+ SI.execution_date == execution_date) \
+ .with_for_update() \
+ .first()
+
+ if sensor_instance.hashcode != sensor_work.hashcode:
+ # Return without setting state
+ return
+
+ ti = session.query(TI).filter(
+ TI.dag_id == dag_id,
+ TI.task_id == task_id,
+ TI.execution_date == execution_date) \
+ .with_for_update() \
+ .first()
+
+ if ti:
+ if ti.state == State.SENSING:
+ ti.hostname = self.hostname
+ handle_failure(sensor_work, ti)
+
+ sensor_instance.state = State.FAILED
+ ti.end_date = timezone.utcnow()
+ ti.set_duration()
+ else:
+ sensor_instance.state = ti.state
+ session.merge(sensor_instance)
+ session.merge(ti)
+ session.commit()
+
+ sensor_work.log.info("Task %s got an error: %s. Set the state to failed. Exit.",
+ str(sensor_work.ti_key), error)
+ sensor_work.close_sensor_logger()
+
+ except AirflowException as e:
+ sensor_work.log.warning("Exception on failing %s", sensor_work.ti_key)
+ sensor_work.log.exception(e, exc_info=True)
+
+ def _check_and_handle_ti_timeout(self, sensor_work):
+ """
+ Check if a sensor task in smart sensor is timeout. Could be either sensor operator timeout
+ or general operator execution_timeout.
+
+ :param sensor_work: SensorWork
+ """
+ task_timeout = sensor_work.execution_context.get('timeout', self.timeout)
+ task_execution_timeout = sensor_work.execution_context.get('execution_timeout', None)
+ if task_execution_timeout:
+ task_timeout = min(task_timeout, task_execution_timeout)
+
+ if (timezone.utcnow() - sensor_work.start_date).total_seconds() > task_timeout:
+ error = "Sensor Timeout"
+ sensor_work.log.exception(error)
+ self._retry_or_fail_task(sensor_work, error)
+
+ def _handle_poke_exception(self, sensor_work):
+ """
+ Fail task if accumulated exceptions exceeds retries.
+
+ :param sensor_work: SensorWork
+ """
+ sensor_exception = self.cached_sensor_exceptions.get(sensor_work.cache_key)
+ error = sensor_exception.exception_info
+ sensor_work.log.exception("Handling poke exception: %s", error)
+
+ if sensor_exception.fail_current_run:
+ if sensor_exception.is_infra_failure:
+ sensor_work.log.exception("Task %s failed by infra failure in smart sensor.",
+ sensor_work.ti_key)
+ # There is a risk for sensor object cached in smart sensor keep throwing
+ # exception and cause an infra failure. To make sure the sensor tasks after
+ # retry will not fall into same object and have endless infra failure,
+ # we mark the sensor task after an infra failure so that it can be popped
+ # before next poke loop.
+ cache_key = sensor_work.cache_key
+ self.cached_dedup_works[cache_key].set_to_flush()
+ else:
+ sensor_work.log.exception("Task %s failed by exceptions.", sensor_work.ti_key)
+ self._retry_or_fail_task(sensor_work, error)
+ else:
+ sensor_work.log.info("Exception detected, retrying without failing current run.")
+ self._check_and_handle_ti_timeout(sensor_work)
+
+ def _process_sensor_work_with_cached_state(self, sensor_work, state):
+ if state == PokeState.LANDED:
+ sensor_work.log.info("Task %s succeeded", str(sensor_work.ti_key))
+ sensor_work.close_sensor_logger()
+
+ if state == PokeState.NOT_LANDED:
+ # Handle timeout if connection valid but not landed yet
+ self._check_and_handle_ti_timeout(sensor_work)
+ elif state == PokeState.POKE_EXCEPTION:
+ self._handle_poke_exception(sensor_work)
+
+ def _execute_sensor_work(self, sensor_work):
+ ti_key = sensor_work.ti_key
+ log = sensor_work.log or self.log
+ log.info("Sensing ti: %s", str(ti_key))
+ log.info("Poking with arguments: %s", sensor_work.encoded_poke_context)
+
+ cache_key = sensor_work.cache_key
+ if cache_key not in self.cached_dedup_works:
+ # create an empty cached_work for a new cache_key
+ self.cached_dedup_works[cache_key] = CachedPokeWork()
+
+ cached_work = self.cached_dedup_works[cache_key]
+
+ if cached_work.state is not None:
+ # A valid cached state exists; don't poke the same target twice within one loop interval
+ self._process_sensor_work_with_cached_state(sensor_work,
cached_work.state)
+ return
+
+ try:
+ with timeout(seconds=self.poke_timeout):
+ if self.poke(sensor_work):
+ # Got a landed signal, mark all tasks waiting for this
partition
+ cached_work.set_state(PokeState.LANDED)
+
+ self._mark_multi_state(sensor_work.operator,
+ sensor_work.hashcode,
+ sensor_work.encoded_poke_context,
+ State.SUCCESS)
+
+ log.info("Task %s succeeded", str(ti_key))
+ sensor_work.close_sensor_logger()
+ else:
+ # Not landed yet. Handle possible timeout
+ cached_work.set_state(PokeState.NOT_LANDED)
+ self._check_and_handle_ti_timeout(sensor_work)
+
+ self.cached_sensor_exceptions.pop(cache_key, None)
+ except Exception as e: # pylint: disable=broad-except
+ # The retry_infra_failure decorator inside hive_hooks raises exceptions with
+ # is_infra_failure == True. A long poking timeout here is also treated as an infra
+ # failure. Other exceptions fail the current run.
+ is_infra_failure = getattr(e, 'is_infra_failure', False) or
isinstance(e, AirflowTaskTimeout)
+ exception_info = traceback.format_exc()
+ cached_work.set_state(PokeState.POKE_EXCEPTION)
+
+ if cache_key in self.cached_sensor_exceptions:
+ self.cached_sensor_exceptions[cache_key].set_latest_exception(
+ exception_info,
+ is_infra_failure=is_infra_failure)
+ else:
+ self.cached_sensor_exceptions[cache_key] = SensorExceptionInfo(
+ exception_info,
+ is_infra_failure=is_infra_failure)
+
+ self._handle_poke_exception(sensor_work)
+
+ def flush_cached_sensor_poke_results(self):
+ """
+ Flush outdated cached sensor states saved in previous loop.
+
+ """
+ for key, cached_work in list(self.cached_dedup_works.items()):
+ if cached_work.is_expired():
+ self.cached_dedup_works.pop(key, None)
+ else:
+ cached_work.state = None
+
+ for ti_key, sensor_exception in list(self.cached_sensor_exceptions.items()):
+ if sensor_exception.fail_current_run or
sensor_exception.is_expired():
+ self.cached_sensor_exceptions.pop(ti_key, None)
+
+ def poke(self, sensor_work):
+ """
+ Poke for a single sensor work: instantiate and cache the original sensor operator from the
+ serialized poke context on first use, then delegate to its poke() method.
+
+ """
+ cached_work = self.cached_dedup_works[sensor_work.cache_key]
+ if not cached_work.sensor_task:
+ init_args = dict(list(sensor_work.poke_context.items())
+ + [('task_id', sensor_work.task_id)])
+ operator_class = import_string(sensor_work.op_classpath)
+ cached_work.sensor_task = operator_class(**init_args)
+
+ return cached_work.sensor_task.poke(sensor_work.poke_context)
+
+ def _emit_loop_stats(self):
+ try:
+ count_poke = 0
+ count_poke_success = 0
+ count_poke_exception = 0
+ count_exception_failures = 0
+ count_infra_failure = 0
+ for cached_work in self.cached_dedup_works.values():
+ if cached_work.state is None:
+ continue
+ count_poke += 1
+ if cached_work.state == PokeState.LANDED:
+ count_poke_success += 1
+ elif cached_work.state == PokeState.POKE_EXCEPTION:
+ count_poke_exception += 1
+ for cached_exception in self.cached_sensor_exceptions.values():
+ if cached_exception.is_infra_failure and
cached_exception.fail_current_run:
+ count_infra_failure += 1
+ if cached_exception.fail_current_run:
+ count_exception_failures += 1
+
+ Stats.gauge("smart_sensor_operator.poked_tasks", count_poke)
+ Stats.gauge("smart_sensor_operator.poked_success",
count_poke_success)
+ Stats.gauge("smart_sensor_operator.poked_exception",
count_poke_exception)
+ Stats.gauge("smart_sensor_operator.exception_failures",
count_exception_failures)
+ Stats.gauge("smart_sensor_operator.infra_failures",
count_infra_failure)
+ except Exception as e: # pylint: disable=broad-except
+ self.log.exception("Exception at getting loop stats %s")
+ self.log.exception(e, exc_info=True)
+
+ def execute(self, context):
+ started_at = timezone.utcnow()
+
+ self.hostname = get_hostname()
+ while True:
+ poke_start_time = timezone.utcnow()
+
+ self.flush_cached_sensor_poke_results()
+
+ self._load_sensor_works()
+ self.log.info("Loaded %s sensor_works", len(self.sensor_works))
+ Stats.gauge("smart_sensor_operator.loaded_tasks",
len(self.sensor_works))
+
+ for sensor_work in self.sensor_works:
+ self._execute_sensor_work(sensor_work)
+
+ duration = (timezone.utcnow() - poke_start_time).total_seconds()
+
+ self.log.info("Taking %s to execute %s tasks.", duration,
len(self.sensor_works))
+
+ Stats.timing("smart_sensor_operator.loop_duration", duration)
+ Stats.gauge("smart_sensor_operator.executed_tasks",
len(self.sensor_works))
+ self._emit_loop_stats()
+
+ if duration < self.poke_interval:
+ sleep(self.poke_interval - duration)
+ if (timezone.utcnow() - started_at).total_seconds() > self.timeout:
+ self.log.info("Time is out for smart senosr.")
+ return
+
+ def on_kill(self):
+ pass
+
+
+if __name__ == '__main__':
+ SmartSensorOperator(task_id='test').execute({})
diff --git a/airflow/smart_sensor_dags/__init__.py
b/airflow/smart_sensor_dags/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/smart_sensor_dags/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/smart_sensor_dags/smart_sensor_group.py
b/airflow/smart_sensor_dags/smart_sensor_group.py
new file mode 100644
index 0000000..0187fa9
--- /dev/null
+++ b/airflow/smart_sensor_dags/smart_sensor_group.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Smart sensor DAGs managing all smart sensor tasks."""
+from builtins import range
+from datetime import timedelta
+
+from airflow.configuration import conf
+from airflow.models import DAG
+from airflow.sensors.smart_sensor_operator import SmartSensorOperator
+from airflow.utils.dates import days_ago
+
+args = {
+ 'owner': 'airflow',
+}
+
+num_smart_sensor_shard = conf.getint("smart_sensor", "shards")
+shard_code_upper_limit = conf.getint('smart_sensor', 'shard_code_upper_limit')
+
+for i in range(num_smart_sensor_shard):
+ shard_min = (i * shard_code_upper_limit) / num_smart_sensor_shard
+ shard_max = ((i + 1) * shard_code_upper_limit) / num_smart_sensor_shard
+
+ dag_id = 'smart_sensor_group_shard_{}'.format(i)
+ dag = DAG(
+ dag_id=dag_id,
+ default_args=args,
+ schedule_interval=timedelta(minutes=5),
+ concurrency=1,
+ max_active_runs=1,
+ catchup=False,
+ dagrun_timeout=timedelta(hours=24),
+ start_date=days_ago(2),
+ )
+
+ SmartSensorOperator(
+ task_id='smart_sensor_task',
+ dag=dag,
+ retries=100,
+ retry_delay=timedelta(seconds=10),
+ priority_weight=999,
+ shard_min=shard_min,
+ shard_max=shard_max,
+ poke_timeout=10,
+ smart_sensor_timeout=timedelta(hours=24).total_seconds(),
+ )
+
+ globals()[dag_id] = dag
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index d7e32ce..f2a5b9f 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -132,7 +132,9 @@ def find_path_from_directory(
def list_py_file_paths(directory: str,
safe_mode: bool = conf.getboolean('core',
'DAG_DISCOVERY_SAFE_MODE', fallback=True),
- include_examples: Optional[bool] = None):
+ include_examples: Optional[bool] = None,
+ include_smart_sensor: Optional[bool] =
+ conf.getboolean('smart_sensor', 'use_smart_sensor')):
"""
Traverse a directory and look for Python files.
@@ -145,6 +147,8 @@ def list_py_file_paths(directory: str,
:type safe_mode: bool
:param include_examples: include example DAGs
:type include_examples: bool
+ :param include_smart_sensor: include smart sensor native control DAGs
+ :type include_smart_sensor: bool
:return: a list of paths to Python files in the specified directory
:rtype: list[unicode]
"""
@@ -152,15 +156,19 @@ def list_py_file_paths(directory: str,
include_examples = conf.getboolean('core', 'LOAD_EXAMPLES')
file_paths: List[str] = []
if directory is None:
- return []
+ file_paths = []
elif os.path.isfile(directory):
- return [directory]
+ file_paths = [directory]
elif os.path.isdir(directory):
find_dag_file_paths(directory, file_paths, safe_mode)
if include_examples:
from airflow import example_dags
example_dag_folder = example_dags.__path__[0] # type: ignore
- file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode,
False))
+ file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode,
False, False))
+ if include_smart_sensor:
+ from airflow import smart_sensor_dags
+ smart_sensor_dag_folder = smart_sensor_dags.__path__[0] # type: ignore
+ file_paths.extend(list_py_file_paths(smart_sensor_dag_folder,
safe_mode, False, False))
return file_paths
diff --git a/airflow/utils/log/file_processor_handler.py
b/airflow/utils/log/file_processor_handler.py
index a552c11..d98ec85 100644
--- a/airflow/utils/log/file_processor_handler.py
+++ b/airflow/utils/log/file_processor_handler.py
@@ -76,7 +76,20 @@ class FileProcessorHandler(logging.Handler):
self.handler.close()
def _render_filename(self, filename):
- filename = os.path.relpath(filename, self.dag_dir)
+
+ # The Airflow log path used to be generated from the relative path
+ # `os.path.relpath(filename, self.dag_dir)`. However, the DAG `smart_sensor_group` and
+ # all other DAGs shipped in the airflow source code do not live in the DAG dir like
+ # regular DAGs, which could produce a log filepath outside of the log dir. The change
+ # here makes sure the log path for DAGs in the airflow code is always inside the log
+ # dir, like other DAGs. To differentiate them from regular DAGs, their logs are written
+ # under `log_dir/native_dags`.
+ import airflow
+ airflow_directory = airflow.__path__[0]
+ if filename.startswith(airflow_directory):
+ filename = os.path.join("native_dags", os.path.relpath(filename,
airflow_directory))
+ else:
+ filename = os.path.relpath(filename, self.dag_dir)
ctx = {}
ctx['filename'] = filename
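A standalone sketch of the same path rendering logic, with hypothetical directories
("/opt/airflow" and "/home/user/dags" are illustrative only, not from this change):

    import os

    def render_processor_log_path(filename, airflow_dir, dag_dir):
        # Files under the airflow package are grouped under native_dags/ inside the log dir;
        # everything else stays relative to the configured DAG folder, as before.
        if filename.startswith(airflow_dir):
            return os.path.join("native_dags", os.path.relpath(filename, airflow_dir))
        return os.path.relpath(filename, dag_dir)

    # render_processor_log_path("/opt/airflow/smart_sensor_dags/smart_sensor_group.py",
    #                           "/opt/airflow", "/home/user/dags")
    # -> "native_dags/smart_sensor_dags/smart_sensor_group.py"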
diff --git a/airflow/utils/log/file_task_handler.py
b/airflow/utils/log/file_task_handler.py
index 281a1a1..482f7ea 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -71,8 +71,15 @@ class FileTaskHandler(logging.Handler):
def _render_filename(self, ti, try_number):
if self.filename_jinja_template:
- jinja_context = ti.get_template_context()
- jinja_context['try_number'] = try_number
+ if hasattr(ti, 'task'):
+ jinja_context = ti.get_template_context()
+ jinja_context['try_number'] = try_number
+ else:
+ jinja_context = {
+ 'ti': ti,
+ 'ts': ti.execution_date.isoformat(),
+ 'try_number': try_number,
+ }
return self.filename_jinja_template.render(**jinja_context)
return self.filename_template.format(dag_id=ti.dag_id,
@@ -80,6 +87,9 @@ class FileTaskHandler(logging.Handler):
execution_date=ti.execution_date.isoformat(),
try_number=try_number)
+ def _read_grouped_logs(self):
+ return False
+
def _read(self, ti, try_number, metadata=None): # pylint:
disable=unused-argument
"""
Template method that contains custom logic of reading
@@ -168,7 +178,7 @@ class FileTaskHandler(logging.Handler):
it returns all logs separated by try_number
:param metadata: log metadata,
can be used for streaming log reading and auto-tailing.
- :return: a list of logs
+ :return: a list of log lists, one per try number, each containing (hostname, log string) tuples
"""
# Task instance increments its try number when it starts to run.
# So the log for a particular task try will only show up when
@@ -180,7 +190,7 @@ class FileTaskHandler(logging.Handler):
try_numbers = list(range(1, next_try))
elif try_number < 1:
logs = [
- 'Error fetching the logs. Try number {} is
invalid.'.format(try_number),
+ [('default_host', 'Error fetching the logs. Try number {} is
invalid.'.format(try_number))],
]
return logs
else:
@@ -190,7 +200,9 @@ class FileTaskHandler(logging.Handler):
metadata_array = [{}] * len(try_numbers)
for i, try_number_element in enumerate(try_numbers):
log, metadata = self._read(task_instance, try_number_element,
metadata)
- logs[i] += log
+ # es_task_handler returns logs grouped by host. Wrap the log string returned by other
+ # handlers with a default/empty host so that the UI can render the response in the same way
+ logs[i] = log if self._read_grouped_logs() else
[(task_instance.hostname, log)]
metadata_array[i] = metadata
return logs, metadata_array
diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py
index dbfe476..b54e901 100644
--- a/airflow/utils/log/log_reader.py
+++ b/airflow/utils/log/log_reader.py
@@ -84,7 +84,8 @@ class TaskLogReader:
metadata.pop('offset', None)
while 'end_of_log' not in metadata or not metadata['end_of_log']:
logs, metadata = self.read_log_chunks(ti, current_try_number,
metadata)
- yield "\n".join(logs) + "\n"
+ for host, log in logs[0]:
+ yield "\n".join([host, log]) + "\n"
@cached_property
def log_handler(self):
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index b2abbff..4999cd8 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -43,6 +43,7 @@ class State:
UP_FOR_RESCHEDULE = "up_for_reschedule"
UPSTREAM_FAILED = "upstream_failed"
SKIPPED = "skipped"
+ SENSING = "sensing"
task_states = (
SUCCESS,
@@ -55,6 +56,7 @@ class State:
QUEUED,
NONE,
SCHEDULED,
+ SENSING,
)
dag_states = (
@@ -76,6 +78,7 @@ class State:
REMOVED: 'lightgrey',
SCHEDULED: 'tan',
NONE: 'lightblue',
+ SENSING: 'lightseagreen',
}
state_color.update(STATE_COLORS) # type: ignore
@@ -97,6 +100,16 @@ class State:
return 'black'
@classmethod
+ def running(cls):
+ """
+ A list of states indicating that a task is being executed.
+ """
+ return [
+ cls.RUNNING,
+ cls.SENSING
+ ]
+
+ @classmethod
def finished(cls):
"""
A list of states indicating that a task started and completed a
@@ -120,7 +133,17 @@ class State:
cls.SCHEDULED,
cls.QUEUED,
cls.RUNNING,
+ cls.SENSING,
cls.SHUTDOWN,
cls.UP_FOR_RETRY,
- cls.UP_FOR_RESCHEDULE
+ cls.UP_FOR_RESCHEDULE,
]
+
+
+class PokeState:
+ """
+ Static class with poke state constants used in the smart sensor operator.
+ """
+ LANDED = 'landed'
+ NOT_LANDED = 'not_landed'
+ POKE_EXCEPTION = 'poke_exception'
diff --git a/airflow/www/static/css/graph.css b/airflow/www/static/css/graph.css
index cd9bedf..7f1b818 100644
--- a/airflow/www/static/css/graph.css
+++ b/airflow/www/static/css/graph.css
@@ -63,6 +63,10 @@ g.node.up_for_reschedule rect {
stroke: turquoise;
}
+g.node.sensing rect {
+ stroke: lightseagreen;
+}
+
g.node.queued rect {
stroke: grey;
}
diff --git a/airflow/www/static/css/tree.css b/airflow/www/static/css/tree.css
index b26f6ef..e69444a 100644
--- a/airflow/www/static/css/tree.css
+++ b/airflow/www/static/css/tree.css
@@ -83,6 +83,10 @@ rect.skipped {
fill: pink;
}
+rect.sensing {
+ fill: lightseagreen;
+}
+
.tooltip.in {
opacity: 1;
filter: alpha(opacity=100);
diff --git a/airflow/www/templates/airflow/ti_log.html
b/airflow/www/templates/airflow/ti_log.html
index 1dcf6d3..7b9af69 100644
--- a/airflow/www/templates/airflow/ti_log.html
+++ b/airflow/www/templates/airflow/ti_log.html
@@ -44,7 +44,7 @@
<div role="tabpanel" class="tab-pane {{ 'active' if loop.last else '' }}"
id="{{ loop.index }}">
<img id="loading-{{ loop.index }}" style="margin-top:0%; margin-left:50%;
height:50px; width:50px; position: absolute;"
alt="spinner" src="{{ url_for('static', filename='loading.gif') }}">
- <pre><code id="try-{{ loop.index }}" class="{{ 'wrap' if wrapped else ''
}}">{{ log }}</code></pre>
+ <div id="log-group-{{ loop.index }}"></div>
</div>
{% endfor %}
</div>
@@ -122,16 +122,30 @@
if(auto_tailing && checkAutoTailingCondition()) {
var should_scroll = true;
}
- // The message may contain HTML, so either have to escape it or
write it as text.
- var escaped_message = escapeHtml(res.message);
// Detect urls
var url_regex =
/http(s)?:\/\/[\w\.\-]+(\.?:[\w\.\-]+)*([\/?#][\w\-\._~:/?#[\]@!\$&'\(\)\*\+,;=\.%]+)?/g;
- var linkified_message = escaped_message.replace(url_regex,
function(url) {
- return "<a href=\"" + url + "\" target=\"_blank\">" + url +
"</a>";
- });
- document.getElementById(`try-${try_number}`).innerHTML +=
linkified_message + "<br/>";
+ res.message.forEach(function(item, index){
+ var log_block_element_id = "try-" + try_number + "-" + item[0];
+ var log_block = document.getElementById(log_block_element_id);
+ if (!log_block) {
+              var log_div_block = document.createElement('div');
+              var log_pre_block = document.createElement('pre');
+              log_div_block.appendChild(log_pre_block);
+ log_pre_block.innerHTML = "<code id=\"" + log_block_element_id
+ "\" ></code>";
+ document.getElementById("log-group-" +
try_number).appendChild(log_div_block);
+ log_block = document.getElementById(log_block_element_id);
+ }
+
+ // The message may contain HTML, so either have to escape it or
write it as text.
+ var escaped_message = escapeHtml(item[1]);
+ var linkified_message = escaped_message.replace(url_regex,
function(url) {
+ return "<a href=\"" + url + "\" target=\"_blank\">" + url +
"</a>";
+ });
+ log_block.innerHTML += linkified_message + "\n";
+ })
+
// Auto scroll window to the end if current window location is
near the end.
if(should_scroll) {
scrollBottom();
diff --git a/docs/img/smart_sensor_architecture.png
b/docs/img/smart_sensor_architecture.png
new file mode 100644
index 0000000..4fdf3e9
Binary files /dev/null and b/docs/img/smart_sensor_architecture.png differ
diff --git a/docs/img/smart_sensor_single_task_execute_flow.png
b/docs/img/smart_sensor_single_task_execute_flow.png
new file mode 100644
index 0000000..c3ec2e0
Binary files /dev/null and b/docs/img/smart_sensor_single_task_execute_flow.png
differ
diff --git a/docs/index.rst b/docs/index.rst
index f532d9a..45ea799 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -95,6 +95,7 @@ Content
lineage
dag-serialization
modules_management
+ smart-sensor
changelog
best-practices
faq
diff --git a/docs/logging-monitoring/metrics.rst
b/docs/logging-monitoring/metrics.rst
index 491f2fa..afe7801 100644
--- a/docs/logging-monitoring/metrics.rst
+++ b/docs/logging-monitoring/metrics.rst
@@ -112,6 +112,11 @@ Name
Description
``pool.queued_slots.<pool_name>`` Number of queued slots in
the pool
``pool.running_slots.<pool_name>`` Number of running slots in
the pool
``pool.starving_tasks.<pool_name>`` Number of starving tasks
in the pool
+``smart_sensor_operator.poked_tasks`` Number of tasks poked by
the smart sensor in the previous poking loop
+``smart_sensor_operator.poked_success`` Number of newly succeeded
tasks poked by the smart sensor in the previous poking loop
+``smart_sensor_operator.poked_exception`` Number of exceptions in
the previous smart sensor poking loop
+``smart_sensor_operator.exception_failures`` Number of failures caused
by exception in the previous smart sensor poking loop
+``smart_sensor_operator.infra_failures`` Number of infrastructure
failures in the previous smart sensor poking loop
===================================================
========================================================================
Timers
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 4f7c9e7..9b2f03f 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -131,6 +131,10 @@ Fundamentals
* - :mod:`airflow.hooks.filesystem`
-
+ * - :mod:`airflow.sensors.smart_sensor_operator`
+ -
+
+
.. _Apache:
ASF: Apache Software Foundation
diff --git a/docs/smart-sensor.rst b/docs/smart-sensor.rst
new file mode 100644
index 0000000..2664f8b
--- /dev/null
+++ b/docs/smart-sensor.rst
@@ -0,0 +1,86 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+
+Smart Sensor
+============
+
+The smart sensor is a service (run by a builtin DAG) which greatly reduces Airflow's
+infrastructure cost by consolidating some of Airflow's long-running, lightweight sensor tasks.
+
+.. image:: img/smart_sensor_architecture.png
+
+Instead of using one process for each task, the main idea of the smart sensor service is to
+improve the efficiency of these long-running tasks by using centralized processes to execute
+them in batches.
+
+To do that, a task runs in two steps: the first step serializes the task information into the
+database, and the second step uses a few centralized processes to execute the serialized tasks
+in batches.
+
+In this way, we only need a handful of running processes.
+
+.. image:: img/smart_sensor_single_task_execute_flow.png
+
+The smart sensor service is supported in a new mode called “smart sensor mode”. In smart sensor
+mode, instead of holding a long-running process for each sensor and poking periodically, a sensor
+only stores its poke context in the sensor_instance table and then exits with a ‘sensing’ state.
+
+When smart sensor mode is enabled, a special set of builtin smart sensor DAGs
+(named smart_sensor_group_shard_xxx) is created by the system. These DAGs contain a
+``SmartSensorOperator`` task and manage the smart sensor jobs for the airflow cluster. The
+SmartSensorOperator task can fetch hundreds of ‘sensing’ instances from the sensor_instance
+table and poke them in batches on their behalf. Users don’t need to change their existing DAGs.
+
+Enable/Disable Smart Sensor
+---------------------------
+
+Upgrading from an older version might require a schema change. If there is no ``sensor_instance``
+table in the DB, make sure to run ``airflow db upgrade``.
+
+Add the following settings in the ``airflow.cfg``:
+
+.. code-block::
+
+ [smart_sensor]
+ use_smart_sensor = true
+ shard_code_upper_limit = 10000
+
+ # Users can change the following config based on their requirements
+ shards = 5
+ sensors_enabled = NamedHivePartitionSensor, MetastorePartitionSensor
+
+* ``use_smart_sensor``: This config indicates if the smart sensor is enabled.
+* ``shards``: This config indicates the number of concurrently running smart
sensor jobs for
+ the airflow cluster.
+* ``sensors_enabled``: This config is a list of sensor class names that will use the smart sensor.
+  Users keep the same class names (e.g. HivePartitionSensor) in their DAGs; whether a task goes
+  through the smart sensor is not under their control, unless they exclude their tasks explicitly.
+
+Enabling/disabling the smart sensor service is a system-level configuration change and is
+transparent to individual users. Existing DAGs don't need to be changed to enable or disable
+the smart sensor. Rotating centralized smart sensor tasks will not cause any user's sensor
+task to fail.
+
+Support new operators in the smart sensor service
+-------------------------------------------------
+
+* Define ``poke_context_fields`` as a class attribute in the sensor. ``poke_context_fields``
+  includes all key names used for initializing a sensor object.
+* In ``airflow.cfg``, add the new operator's classname to ``[smart_sensor] sensors_enabled``.
+  All supported sensors' class names should be comma-separated; a sketch of such a sensor
+  follows below.
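For illustration only, a minimal sketch of a sensor prepared for the smart sensor service; the
``FileExistsSensor`` name and its ``filepath``/``recursive`` arguments are hypothetical and not
part of this change:

.. code-block:: python

    import glob
    import os

    from airflow.sensors.base_sensor_operator import BaseSensorOperator


    class FileExistsSensor(BaseSensorOperator):
        # Key names the smart sensor uses to rebuild this sensor from its serialized poke context.
        poke_context_fields = ('filepath', 'recursive')

        def __init__(self, filepath, recursive=False, **kwargs):
            super().__init__(**kwargs)
            self.filepath = filepath
            self.recursive = recursive

        def poke(self, context):
            if self.recursive:
                # Treat filepath as a glob pattern when recursive matching is requested.
                return bool(glob.glob(self.filepath, recursive=True))
            return os.path.exists(self.filepath)

The class name would then be listed in ``[smart_sensor] sensors_enabled``, e.g.
``sensors_enabled = NamedHivePartitionSensor, FileExistsSensor``.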
diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py
b/tests/api_connexion/endpoints/test_log_endpoint.py
index 2921a90..ff41609 100644
--- a/tests/api_connexion/endpoints/test_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_log_endpoint.py
@@ -158,7 +158,8 @@ class TestGetLog(unittest.TestCase):
self.log_dir, self.DAG_ID, self.TASK_ID,
self.default_time.replace(":", ".")
)
self.assertEqual(
- response.json['content'], f"*** Reading local file:
{expected_filename}\nLog for testing."
+ response.json['content'],
+ f"[('', '*** Reading local file: {expected_filename}\\nLog for
testing.')]",
)
info = serializer.loads(response.json['continuation_token'])
self.assertEqual(info, {'end_of_log': True})
@@ -182,7 +183,8 @@ class TestGetLog(unittest.TestCase):
)
self.assertEqual(200, response.status_code)
self.assertEqual(
- response.data.decode('utf-8'), f"*** Reading local file:
{expected_filename}\nLog for testing.\n"
+ response.data.decode('utf-8'),
+ f"\n*** Reading local file: {expected_filename}\nLog for
testing.\n",
)
@provide_session
@@ -204,10 +206,10 @@ class TestGetLog(unittest.TestCase):
def test_get_logs_with_metadata_as_download_large_file(self, session):
self._create_dagrun(session)
with
mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as
read_mock:
- first_return = (['1st line'], [{}])
- second_return = (['2nd line'], [{'end_of_log': False}])
- third_return = (['3rd line'], [{'end_of_log': True}])
- fourth_return = (['should never be read'], [{'end_of_log': True}])
+ first_return = ([[('', '1st line')]], [{}])
+ second_return = ([[('', '2nd line')]], [{'end_of_log': False}])
+ third_return = ([[('', '3rd line')]], [{'end_of_log': True}])
+ fourth_return = ([[('', 'should never be read')]], [{'end_of_log':
True}])
read_mock.side_effect = [first_return, second_return,
third_return, fourth_return]
response = self.client.get(
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 00325fe..030ff3b 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -33,6 +33,7 @@ from mock import MagicMock, patch
from parameterized import parameterized
import airflow.example_dags
+import airflow.smart_sensor_dags
from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -197,7 +198,9 @@ class TestDagFileProcessor(unittest.TestCase):
executor = MockExecutor(do_update=True, parallelism=3)
with create_session() as session:
- dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER,
"no_dags.py"))
+ dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER,
"no_dags.py"),
+ include_examples=False,
+ include_smart_sensor=False)
dag = self.create_test_dag()
dag.clear()
dagbag.bag_dag(dag=dag, root_dag=dag)
@@ -1309,7 +1312,9 @@ class TestDagFileProcessorQueriesCount(unittest.TestCase):
}), conf_vars({
('scheduler', 'use_job_schedule'): 'True',
}):
- dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE,
include_examples=False)
+ dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE,
+ include_examples=False,
+ include_smart_sensor=False)
processor = DagFileProcessor([], mock.MagicMock())
for expected_query_count in expected_query_counts:
with assert_queries_count(expected_query_count):
@@ -3298,6 +3303,19 @@ class TestSchedulerJob(unittest.TestCase):
detected_files.add(file_path)
self.assertEqual(detected_files, expected_files)
+ smart_sensor_dag_folder = airflow.smart_sensor_dags.__path__[0]
+ for root, _, files in os.walk(smart_sensor_dag_folder):
+ for file_name in files:
+ if (file_name.endswith('.py') or file_name.endswith('.zip'))
and \
+ file_name not in ['__init__.py']:
+ expected_files.add(os.path.join(root, file_name))
+ detected_files.clear()
+ for file_path in list_py_file_paths(TEST_DAG_FOLDER,
+ include_examples=True,
+ include_smart_sensor=True):
+ detected_files.add(file_path)
+ self.assertEqual(detected_files, expected_files)
+
def test_reset_orphaned_tasks_nothing(self):
"""Try with nothing. """
scheduler = SchedulerJob()
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 9173ad1..c2fe045 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -728,7 +728,9 @@ class TestDagBag(unittest.TestCase):
"""
dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py")
- dagbag = DagBag(dag_folder=dag_file)
+ dagbag = DagBag(dag_folder=dag_file,
+ include_smart_sensor=False,
+ include_examples=False)
self.assertEqual(set(), set(dagbag.dag_ids))
expected_import_errors = {
dag_file: (
@@ -747,7 +749,9 @@ class TestDagBag(unittest.TestCase):
dag_file = os.path.join(TEST_DAGS_FOLDER,
"test_with_non_default_owner.py")
- dagbag = DagBag(dag_folder=dag_file)
+ dagbag = DagBag(dag_folder=dag_file,
+ include_examples=False,
+ include_smart_sensor=False)
self.assertEqual({"test_with_non_default_owner"}, set(dagbag.dag_ids))
self.assertEqual({}, dagbag.import_errors)
diff --git a/tests/models/test_sensorinstance.py
b/tests/models/test_sensorinstance.py
new file mode 100644
index 0000000..168d97f
--- /dev/null
+++ b/tests/models/test_sensorinstance.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow.models import SensorInstance
+from airflow.providers.apache.hive.sensors.named_hive_partition import
NamedHivePartitionSensor
+from airflow.sensors.python import PythonSensor
+
+
+class SensorInstanceTest(unittest.TestCase):
+
+ def test_get_classpath(self):
+ # Test the classpath in/out airflow
+ obj1 = NamedHivePartitionSensor(
+ partition_names=['test_partition'],
+ task_id='meta_partition_test_1')
+ obj1_classpath = SensorInstance.get_classpath(obj1)
+ obj1_importpath = "airflow.providers.apache.hive." \
+
"sensors.named_hive_partition.NamedHivePartitionSensor"
+
+ self.assertEqual(obj1_classpath, obj1_importpath)
+
+ def test_callable():
+ return
+ obj3 = PythonSensor(python_callable=test_callable,
+ task_id='python_sensor_test')
+ obj3_classpath = SensorInstance.get_classpath(obj3)
+ obj3_importpath = "airflow.sensors.python.PythonSensor"
+
+ self.assertEqual(obj3_classpath, obj3_importpath)
diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index 7509be2..9eedd6b 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -127,7 +127,10 @@ class TestCloudwatchTaskHandler(unittest.TestCase):
)
self.assertEqual(
self.cloudwatch_task_handler.read(self.ti),
- ([expected.format(self.remote_log_group, self.remote_log_stream)],
[{'end_of_log': True}]),
+ (
+ [[('', expected.format(self.remote_log_group,
self.remote_log_stream))]],
+ [{'end_of_log': True}],
+ ),
)
def test_read_wrong_log_stream(self):
@@ -149,7 +152,7 @@ class TestCloudwatchTaskHandler(unittest.TestCase):
self.assertEqual(
self.cloudwatch_task_handler.read(self.ti),
(
- [msg_template.format(self.remote_log_group,
self.remote_log_stream, error_msg)],
+ [[('', msg_template.format(self.remote_log_group,
self.remote_log_stream, error_msg))]],
[{'end_of_log': True}],
),
)
@@ -173,7 +176,7 @@ class TestCloudwatchTaskHandler(unittest.TestCase):
self.assertEqual(
self.cloudwatch_task_handler.read(self.ti),
(
- [msg_template.format(self.remote_log_group,
self.remote_log_stream, error_msg)],
+ [[('', msg_template.format(self.remote_log_group,
self.remote_log_stream, error_msg))]],
[{'end_of_log': True}],
),
)
diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py
b/tests/providers/amazon/aws/log/test_s3_task_handler.py
index 4c737e0..2c0eebf 100644
--- a/tests/providers/amazon/aws/log/test_s3_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -128,20 +128,19 @@ class TestS3TaskHandler(unittest.TestCase):
def test_read(self):
self.conn.put_object(Bucket='bucket', Key=self.remote_log_key,
Body=b'Log line\n')
+ log, metadata = self.s3_task_handler.read(self.ti)
self.assertEqual(
- self.s3_task_handler.read(self.ti),
- (
- ['*** Reading remote log from
s3://bucket/remote/log/location/1.log.\nLog line\n\n'],
- [{'end_of_log': True}],
- ),
+ log[0][0][-1],
+ '*** Reading remote log from
s3://bucket/remote/log/location/1.log.\n' 'Log line\n\n',
)
+ self.assertEqual(metadata, [{'end_of_log': True}])
def test_read_when_s3_log_missing(self):
log, metadata = self.s3_task_handler.read(self.ti)
self.assertEqual(1, len(log))
self.assertEqual(len(log), len(metadata))
- self.assertIn('*** Log file does not exist:', log[0])
+ self.assertIn('*** Log file does not exist:', log[0][0][-1])
self.assertEqual({'end_of_log': True}, metadata[0])
def test_read_raises_return_error(self):
diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py
b/tests/providers/elasticsearch/log/test_es_task_handler.py
index deec540..5423df3 100644
--- a/tests/providers/elasticsearch/log/test_es_task_handler.py
+++ b/tests/providers/elasticsearch/log/test_es_task_handler.py
@@ -114,7 +114,8 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0])
+ self.assertEqual(len(logs[0]), 1)
+ self.assertEqual(self.test_message, logs[0][0][-1])
self.assertFalse(metadatas[0]['end_of_log'])
self.assertEqual('1', metadatas[0]['offset'])
self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) >
ts)
@@ -134,7 +135,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
)
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0])
+ self.assertEqual(self.test_message, logs[0][0][-1])
self.assertNotEqual(another_test_message, logs[0])
self.assertFalse(metadatas[0]['end_of_log'])
@@ -145,7 +146,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
logs, metadatas = self.es_task_handler.read(self.ti, 1)
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0])
+ self.assertEqual(self.test_message, logs[0][0][-1])
self.assertFalse(metadatas[0]['end_of_log'])
self.assertEqual('1', metadatas[0]['offset'])
self.assertTrue(timezone.parse(metadatas[0]['last_log_timestamp']) <
pendulum.now())
@@ -161,7 +162,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
)
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([''], logs)
+ self.assertEqual([[]], logs)
self.assertFalse(metadatas[0]['end_of_log'])
self.assertEqual('0', metadatas[0]['offset'])
# last_log_timestamp won't change if no log lines read.
@@ -172,7 +173,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
logs, metadatas = self.es_task_handler.read(self.ti, 1, {})
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0])
+ self.assertEqual(self.test_message, logs[0][0][-1])
self.assertFalse(metadatas[0]['end_of_log'])
# offset should be initialized to 0 if not provided.
self.assertEqual('1', metadatas[0]['offset'])
@@ -185,7 +186,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
logs, metadatas = self.es_task_handler.read(self.ti, 1, {'end_of_log':
False})
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([''], logs)
+ self.assertEqual([[]], logs)
self.assertFalse(metadatas[0]['end_of_log'])
# offset should be initialized to 0 if not provided.
self.assertEqual('0', metadatas[0]['offset'])
@@ -202,7 +203,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
)
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([''], logs)
+ self.assertEqual([[]], logs)
self.assertTrue(metadatas[0]['end_of_log'])
# offset should be initialized to 0 if not provided.
self.assertEqual('0', metadatas[0]['offset'])
@@ -217,7 +218,8 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
)
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual(self.test_message, logs[0])
+ self.assertEqual(len(logs[0]), 1)
+ self.assertEqual(self.test_message, logs[0][0][-1])
self.assertFalse(metadatas[0]['end_of_log'])
self.assertTrue(metadatas[0]['download_logs'])
self.assertEqual('1', metadatas[0]['offset'])
@@ -234,7 +236,7 @@ class TestElasticsearchTaskHandler(unittest.TestCase):
self.assertEqual(1, len(logs))
self.assertEqual(len(logs), len(metadatas))
- self.assertEqual([''], logs)
+ self.assertEqual([[]], logs)
self.assertFalse(metadatas[0]['end_of_log'])
self.assertEqual('0', metadatas[0]['offset'])
diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
index 9a51ec5..d5bf169 100644
--- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
+++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
@@ -101,7 +101,15 @@ class TestWasbTaskHandler(unittest.TestCase):
self.assertEqual(
self.wasb_task_handler.read(self.ti),
(
- ['*** Reading remote log from
wasb://container/remote/log/location/1.log.\nLog line\n'],
+ [
+ [
+ (
+ '',
+ '*** Reading remote log from
wasb://container/remote/log/location/1.log.\n'
+ 'Log line\n',
+ )
+ ]
+ ],
[{'end_of_log': True}],
),
)
diff --git a/tests/sensors/test_smart_sensor_operator.py
b/tests/sensors/test_smart_sensor_operator.py
new file mode 100644
index 0000000..28eeee6
--- /dev/null
+++ b/tests/sensors/test_smart_sensor_operator.py
@@ -0,0 +1,326 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import logging
+import os
+import time
+import unittest
+
+from freezegun import freeze_time
+from mock import Mock
+
+from airflow import DAG, settings
+from airflow.configuration import conf
+from airflow.models import DagRun, SensorInstance, TaskInstance
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.sensors.smart_sensor_operator import SmartSensorOperator
+from airflow.utils import timezone
+from airflow.utils.state import State
+
+DEFAULT_DATE = timezone.datetime(2015, 1, 1)
+TEST_DAG_ID = 'unit_test_dag'
+TEST_SENSOR_DAG_ID = 'unit_test_sensor_dag'
+DUMMY_OP = 'dummy_op'
+SMART_OP = 'smart_op'
+SENSOR_OP = 'sensor_op'
+
+
+class DummySmartSensor(SmartSensorOperator):
+ def __init__(self,
+ shard_max=conf.getint('smart_sensor',
'shard_code_upper_limit'),
+ shard_min=0,
+ **kwargs):
+ super(DummySmartSensor, self).__init__(shard_min=shard_min,
+ shard_max=shard_max,
+ **kwargs)
+
+
+class DummySensor(BaseSensorOperator):
+ poke_context_fields = ('input_field', 'return_value')
+ exec_fields = ('soft_fail', 'execution_timeout', 'timeout')
+
+ def __init__(self, input_field='test', return_value=False, **kwargs):
+ super(DummySensor, self).__init__(**kwargs)
+ self.input_field = input_field
+ self.return_value = return_value
+
+ def poke(self, context):
+ return context.get('return_value', False)
+
+ def is_smart_sensor_compatible(self):
+ return not self.on_failure_callback
+
+
+class SmartSensorTest(unittest.TestCase):
+ def setUp(self):
+ os.environ['AIRFLOW__SMART_SENSOR__USE_SMART_SENSOR'] = 'true'
+ os.environ['AIRFLOW__SMART_SENSOR__SENSORS_ENABLED'] = 'DummySmartSensor'
+
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE
+ }
+ self.dag = DAG(TEST_DAG_ID, default_args=args)
+ self.sensor_dag = DAG(TEST_SENSOR_DAG_ID, default_args=args)
+ self.log = logging.getLogger('BaseSmartTest')
+
+ session = settings.Session()
+ session.query(DagRun).delete()
+ session.query(TaskInstance).delete()
+ session.query(SensorInstance).delete()
+ session.commit()
+
+ def tearDown(self):
+ session = settings.Session()
+ session.query(DagRun).delete()
+ session.query(TaskInstance).delete()
+ session.query(SensorInstance).delete()
+ session.commit()
+
+ os.environ.pop('AIRFLOW__SMART_SENSOR__USE_SMART_SENSOR')
+ os.environ.pop('AIRFLOW__SMART_SENSOR__SENSORS_ENABLED')
+
+ def _make_dag_run(self):
+ return self.dag.create_dagrun(
+ run_id='manual__' + TEST_DAG_ID,
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ def _make_sensor_dag_run(self):
+ return self.sensor_dag.create_dagrun(
+ run_id='manual__' + TEST_SENSOR_DAG_ID,
+ start_date=timezone.utcnow(),
+ execution_date=DEFAULT_DATE,
+ state=State.RUNNING
+ )
+
+ def _make_sensor(self, return_value, **kwargs):
+ poke_interval = 'poke_interval'
+ timeout = 'timeout'
+ if poke_interval not in kwargs:
+ kwargs[poke_interval] = 0
+ if timeout not in kwargs:
+ kwargs[timeout] = 0
+
+ sensor = DummySensor(
+ task_id=SENSOR_OP,
+ return_value=return_value,
+ dag=self.sensor_dag,
+ **kwargs
+ )
+
+ return sensor
+
+ def _make_sensor_instance(self, index, return_value, **kwargs):
+ poke_interval = 'poke_interval'
+ timeout = 'timeout'
+ if poke_interval not in kwargs:
+ kwargs[poke_interval] = 0
+ if timeout not in kwargs:
+ kwargs[timeout] = 0
+
+ task_id = SENSOR_OP + str(index)
+ sensor = DummySensor(
+ task_id=task_id,
+ return_value=return_value,
+ dag=self.sensor_dag,
+ **kwargs
+ )
+
+ ti = TaskInstance(task=sensor, execution_date=DEFAULT_DATE)
+
+ return ti
+
+ def _make_smart_operator(self, index, **kwargs):
+ poke_interval = 'poke_interval'
+ smart_sensor_timeout = 'smart_sensor_timeout'
+ if poke_interval not in kwargs:
+ kwargs[poke_interval] = 0
+ if smart_sensor_timeout not in kwargs:
+ kwargs[smart_sensor_timeout] = 0
+
+ smart_task = DummySmartSensor(
+ task_id=SMART_OP + "_" + str(index),
+ dag=self.dag,
+ **kwargs
+ )
+
+ dummy_op = DummyOperator(
+ task_id=DUMMY_OP,
+ dag=self.dag
+ )
+ dummy_op.set_upstream(smart_task)
+ return smart_task
+
+ @classmethod
+ def _run(cls, task):
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
+
+ def test_load_sensor_works(self):
+ # Mock two sensor tasks returning True and one returning False.
+ # The hashcodes for si1 and si2 should be the same; test dedup on these two instances.
+ si1 = self._make_sensor_instance(1, True)
+ si2 = self._make_sensor_instance(2, True)
+ si3 = self._make_sensor_instance(3, False)
+
+ # Confirm initial state
+ smart = self._make_smart_operator(0)
+ smart.flush_cached_sensor_poke_results()
+ self.assertEqual(len(smart.cached_dedup_works), 0)
+ self.assertEqual(len(smart.cached_sensor_exceptions), 0)
+
+ si1.run(ignore_all_deps=True)
+ # Test single sensor
+ smart._load_sensor_works()
+ self.assertEqual(len(smart.sensor_works), 1)
+ self.assertEqual(len(smart.cached_dedup_works), 0)
+ self.assertEqual(len(smart.cached_sensor_exceptions), 0)
+
+ si2.run(ignore_all_deps=True)
+ si3.run(ignore_all_deps=True)
+
+ # Test multiple sensors with duplication
+ smart._load_sensor_works()
+ self.assertEqual(len(smart.sensor_works), 3)
+ self.assertEqual(len(smart.cached_dedup_works), 0)
+ self.assertEqual(len(smart.cached_sensor_exceptions), 0)
+
+ def test_execute_single_task_with_dup(self):
+ sensor_dr = self._make_sensor_dag_run()
+ si1 = self._make_sensor_instance(1, True)
+ si2 = self._make_sensor_instance(2, True)
+ si3 = self._make_sensor_instance(3, False, timeout=0)
+
+ si1.run(ignore_all_deps=True)
+ si2.run(ignore_all_deps=True)
+ si3.run(ignore_all_deps=True)
+
+ smart = self._make_smart_operator(0)
+ smart.flush_cached_sensor_poke_results()
+
+ smart._load_sensor_works()
+ self.assertEqual(len(smart.sensor_works), 3)
+
+ for sensor_work in smart.sensor_works:
+ _, task_id, _ = sensor_work.ti_key
+ if task_id == SENSOR_OP + "1":
+ smart._execute_sensor_work(sensor_work)
+ break
+
+ self.assertEqual(len(smart.cached_dedup_works), 1)
+
+ tis = sensor_dr.get_task_instances()
+ for ti in tis:
+ if ti.task_id == SENSOR_OP + "1":
+ self.assertEqual(ti.state, State.SUCCESS)
+ if ti.task_id == SENSOR_OP + "2":
+ self.assertEqual(ti.state, State.SUCCESS)
+ if ti.task_id == SENSOR_OP + "3":
+ self.assertEqual(ti.state, State.SENSING)
+
+ for sensor_work in smart.sensor_works:
+ _, task_id, _ = sensor_work.ti_key
+ if task_id == SENSOR_OP + "2":
+ smart._execute_sensor_work(sensor_work)
+ break
+
+ self.assertEqual(len(smart.cached_dedup_works), 1)
+
+ time.sleep(1)
+ for sensor_work in smart.sensor_works:
+ _, task_id, _ = sensor_work.ti_key
+ if task_id == SENSOR_OP + "3":
+ smart._execute_sensor_work(sensor_work)
+ break
+
+ self.assertEqual(len(smart.cached_dedup_works), 2)
+
+ tis = sensor_dr.get_task_instances()
+ for ti in tis:
+ # Timeout=0, so the failed poke leads to task failure
+ if ti.task_id == SENSOR_OP + "3":
+ self.assertEqual(ti.state, State.FAILED)
+
+ def test_smart_operator_timeout(self):
+ sensor_dr = self._make_sensor_dag_run()
+ si1 = self._make_sensor_instance(1, False, timeout=10)
+ smart = self._make_smart_operator(0, poke_interval=6)
+ smart.poke = Mock(side_effect=[False, False, False, False])
+
+ date1 = timezone.utcnow()
+ with freeze_time(date1):
+ si1.run(ignore_all_deps=True)
+ smart.flush_cached_sensor_poke_results()
+ smart._load_sensor_works()
+
+ for sensor_work in smart.sensor_works:
+ smart._execute_sensor_work(sensor_work)
+
+ # Before timeout the state should be SENSING
+ sis = sensor_dr.get_task_instances()
+ for sensor_instance in sis:
+ if sensor_instance.task_id == SENSOR_OP + "1":
+ self.assertEqual(sensor_instance.state, State.SENSING)
+
+ date2 = date1 + datetime.timedelta(seconds=smart.poke_interval)
+ with freeze_time(date2):
+ smart.flush_cached_sensor_poke_results()
+ smart._load_sensor_works()
+
+ for sensor_work in smart.sensor_works:
+ smart._execute_sensor_work(sensor_work)
+
+ sis = sensor_dr.get_task_instances()
+ for sensor_instance in sis:
+ if sensor_instance.task_id == SENSOR_OP + "1":
+ self.assertEqual(sensor_instance.state, State.SENSING)
+
+ date3 = date2 + datetime.timedelta(seconds=smart.poke_interval)
+ with freeze_time(date3):
+ smart.flush_cached_sensor_poke_results()
+ smart._load_sensor_works()
+
+ for sensor_work in smart.sensor_works:
+ smart._execute_sensor_work(sensor_work)
+
+ sis = sensor_dr.get_task_instances()
+ for sensor_instance in sis:
+ if sensor_instance.task_id == SENSOR_OP + "1":
+ self.assertEqual(sensor_instance.state, State.FAILED)
+
+ def test_register_in_sensor_service(self):
+ si1 = self._make_sensor_instance(1, True)
+ si1.run(ignore_all_deps=True)
+ self.assertEqual(si1.state, State.SENSING)
+
+ session = settings.Session()
+
+ SI = SensorInstance
+ sensor_instance = session.query(SI).filter(
+ SI.dag_id == si1.dag_id,
+ SI.task_id == si1.task_id,
+ SI.execution_date == si1.execution_date) \
+ .first()
+
+ self.assertIsNotNone(sensor_instance)
+ self.assertEqual(sensor_instance.state, State.SENSING)
+ self.assertEqual(sensor_instance.operator, "DummySensor")
diff --git a/tests/test_config_templates.py b/tests/test_config_templates.py
index c8fea0b..7b513d9 100644
--- a/tests/test_config_templates.py
+++ b/tests/test_config_templates.py
@@ -53,7 +53,8 @@ DEFAULT_AIRFLOW_SECTIONS = [
'kubernetes_node_selectors',
'kubernetes_environment_variables',
'kubernetes_secrets',
- 'kubernetes_labels'
+ 'kubernetes_labels',
+ 'smart_sensor'
]
DEFAULT_TEST_SECTIONS = [
diff --git a/tests/utils/log/test_log_reader.py
b/tests/utils/log/test_log_reader.py
index 0fe2a7a..fc65496 100644
--- a/tests/utils/log/test_log_reader.py
+++ b/tests/utils/log/test_log_reader.py
@@ -105,11 +105,12 @@ class TestLogView(unittest.TestCase):
self.assertEqual(
[
- f"*** Reading local file: "
-
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- f"try_number=1.\n"
+ ('',
+ f"*** Reading local file: "
+
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
+ f"try_number=1.\n")
],
- logs,
+ logs[0],
)
self.assertEqual({"end_of_log": True}, metadatas)
@@ -119,15 +120,18 @@ class TestLogView(unittest.TestCase):
self.assertEqual(
[
- "*** Reading local file: "
-
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
- "try_number=1.\n",
- f"*** Reading local file: "
-
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
- f"try_number=2.\n",
- f"*** Reading local file: "
-
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
- f"try_number=3.\n",
+ [('',
+ "*** Reading local file: "
+
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
+ "try_number=1.\n")],
+ [('',
+ f"*** Reading local file: "
+
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
+ f"try_number=2.\n")],
+ [('',
+ f"*** Reading local file: "
+
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
+ f"try_number=3.\n")],
],
logs,
)
@@ -139,7 +143,7 @@ class TestLogView(unittest.TestCase):
self.assertEqual(
[
- "*** Reading local file: "
+ "\n*** Reading local file: "
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
"try_number=1.\n"
"\n"
@@ -152,15 +156,15 @@ class TestLogView(unittest.TestCase):
stream = task_log_reader.read_log_stream(ti=self.ti, try_number=None,
metadata={})
self.assertEqual(
[
- "*** Reading local file: "
+ "\n*** Reading local file: "
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log\n"
"try_number=1.\n"
"\n",
- "*** Reading local file: "
+ "\n*** Reading local file: "
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/2.log\n"
"try_number=2.\n"
"\n",
- "*** Reading local file: "
+ "\n*** Reading local file: "
f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log\n"
"try_number=3.\n"
"\n",
@@ -170,15 +174,15 @@ class TestLogView(unittest.TestCase):
@mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read")
def test_read_log_stream_should_support_multiple_chunks(self, mock_read):
- first_return = (["1st line"], [{}])
- second_return = (["2nd line"], [{"end_of_log": False}])
- third_return = (["3rd line"], [{"end_of_log": True}])
- fourth_return = (["should never be read"], [{"end_of_log": True}])
+ first_return = ([[('', "1st line")]], [{}])
+ second_return = ([[('', "2nd line")]], [{"end_of_log": False}])
+ third_return = ([[('', "3rd line")]], [{"end_of_log": True}])
+ fourth_return = ([[('', "should never be read")]], [{"end_of_log":
True}])
mock_read.side_effect = [first_return, second_return, third_return,
fourth_return]
task_log_reader = TaskLogReader()
log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1,
metadata={})
- self.assertEqual(["1st line\n", "2nd line\n", "3rd line\n"],
list(log_stream))
+ self.assertEqual(["\n1st line\n", "\n2nd line\n", "\n3rd line\n"],
list(log_stream))
mock_read.assert_has_calls(
[
@@ -191,15 +195,15 @@ class TestLogView(unittest.TestCase):
@mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read")
def test_read_log_stream_should_read_each_try_in_turn(self, mock_read):
- first_return = (["try_number=1."], [{"end_of_log": True}])
- second_return = (["try_number=2."], [{"end_of_log": True}])
- third_return = (["try_number=3."], [{"end_of_log": True}])
- fourth_return = (["should never be read"], [{"end_of_log": True}])
+ first_return = ([[('', "try_number=1.")]], [{"end_of_log": True}])
+ second_return = ([[('', "try_number=2.")]], [{"end_of_log": True}])
+ third_return = ([[('', "try_number=3.")]], [{"end_of_log": True}])
+ fourth_return = ([[('', "should never be read")]], [{"end_of_log":
True}])
mock_read.side_effect = [first_return, second_return, third_return,
fourth_return]
task_log_reader = TaskLogReader()
log_stream = task_log_reader.read_log_stream(ti=self.ti,
try_number=None, metadata={})
- self.assertEqual(['try_number=1.\n', 'try_number=2.\n',
'try_number=3.\n'], list(log_stream))
+ self.assertEqual(['\ntry_number=1.\n', '\ntry_number=2.\n',
'\ntry_number=3.\n'], list(log_stream))
mock_read.assert_has_calls(
[
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 277feb3..f1ef957 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -105,7 +105,7 @@ class TestFileTaskLogHandler(unittest.TestCase):
# We should expect our log line from the callable above to appear in
# the logs we read back
self.assertRegex(
- logs[0],
+ logs[0][0][-1],
target_re,
"Logs were " + str(logs)
)
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index f276dd8..8060ab8 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -506,9 +506,9 @@ class TestAirflowBaseViews(TestBase):
val_state_color_mapping = 'const STATE_COLOR = {"failed": "red", '
\
'"null": "lightblue", "queued": "gray",
' \
'"removed": "lightgrey", "running":
"lime", ' \
- '"scheduled": "tan", "shutdown": "blue",
' \
- '"skipped": "pink", "success": "green",
' \
- '"up_for_reschedule": "turquoise", ' \
+ '"scheduled": "tan", "sensing":
"lightseagreen", ' \
+ '"shutdown": "blue", "skipped": "pink",
' \
+ '"success": "green",
"up_for_reschedule": "turquoise", ' \
'"up_for_retry": "gold",
"upstream_failed": "orange"};'
self.check_content_in_response(val_state_color_mapping, resp)
@@ -1163,9 +1163,9 @@ class TestLogView(TestBase):
self.assertEqual(response.status_code, 200)
self.assertIn('Log by attempts', response.data.decode('utf-8'))
for num in range(1, expected_num_logs_visible + 1):
- self.assertIn('try-{}'.format(num), response.data.decode('utf-8'))
- self.assertNotIn('try-0', response.data.decode('utf-8'))
- self.assertNotIn('try-{}'.format(expected_num_logs_visible + 1),
response.data.decode('utf-8'))
+ self.assertIn('log-group-{}'.format(num),
response.data.decode('utf-8'))
+ self.assertNotIn('log-group-0', response.data.decode('utf-8'))
+ self.assertNotIn('log-group-{}'.format(expected_num_logs_visible + 1),
response.data.decode('utf-8'))
def test_get_logs_with_metadata_as_download_file(self):
url_template = "get_logs_with_metadata?dag_id={}&" \
@@ -1191,10 +1191,10 @@ class TestLogView(TestBase):
def test_get_logs_with_metadata_as_download_large_file(self):
with
mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as
read_mock:
- first_return = (['1st line'], [{}])
- second_return = (['2nd line'], [{'end_of_log': False}])
- third_return = (['3rd line'], [{'end_of_log': True}])
- fourth_return = (['should never be read'], [{'end_of_log': True}])
+ first_return = ([[('default_log', '1st line')]], [{}])
+ second_return = ([[('default_log', '2nd line')]], [{'end_of_log':
False}])
+ third_return = ([[('default_log', '3rd line')]], [{'end_of_log':
True}])
+ fourth_return = ([[('default_log', 'should never be read')]],
[{'end_of_log': True}])
read_mock.side_effect = [first_return, second_return,
third_return, fourth_return]
url_template = "get_logs_with_metadata?dag_id={}&" \
"task_id={}&execution_date={}&" \
@@ -1300,7 +1300,7 @@ class TestLogView(TestBase):
response = self.client.get(url)
self.assertIn('message', response.json)
self.assertIn('metadata', response.json)
- self.assertIn('Log for testing.', response.json['message'])
+ self.assertIn('Log for testing.', response.json['message'][0][1])
self.assertEqual(200, response.status_code)
@mock.patch("airflow.www.views.TaskLogReader")