ashb commented on a change in pull request #20286:
URL: https://github.com/apache/airflow/pull/20286#discussion_r781126481
##########
File path: airflow/models/taskinstance.py
##########
@@ -2128,6 +2135,28 @@ def set_duration(self) -> None:
self.duration = None
self.log.debug("Task Duration set to %s", self.duration)
+ @provide_session
+ def _record_task_map_for_downstreams(self, value: Any, *, session: Session
= NEW_SESSION) -> None:
+ if not self.task.has_mapped_dependants():
+ return
+ if not isinstance(value, collections.abc.Collection):
+ return # TODO: Error if the pushed value is not mappable?
Review comment:
Yeah, I think this should fail the task.
##########
File path: airflow/migrations/versions/e655c0453f75_add_taskmap_and_map_id_on_taskinstance.py
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add TaskMap and map_index on TaskInstance.
+
+Revision ID: e655c0453f75
+Revises: 587bdf053233
+Create Date: 2021-12-13 22:59:41.052584
+"""
+
+from alembic import op
+from sqlalchemy import Column, ForeignKeyConstraint, Integer
+
+from airflow.models.base import StringID
+from airflow.utils.sqlalchemy import ExtendedJSON
+
+# Revision identifiers, used by Alembic.
+revision = "e655c0453f75"
+down_revision = "587bdf053233"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ """Add TaskMap and map_index on TaskInstance."""
+ # We need to first remove constraints on task_reschedule since they depend
on task_instance.
+ with op.batch_alter_table("task_reschedule") as batch_op:
+ batch_op.drop_constraint("task_reschedule_ti_fkey", "foreignkey")
+ batch_op.drop_index("idx_task_reschedule_dag_task_run")
+
+ # Change task_instance's primary key.
+ with op.batch_alter_table("task_instance") as batch_op:
+ # I think we always use this name for TaskInstance after 7b2661a43ba3?
+ batch_op.drop_constraint("task_instance_pkey", type_="primary")
+ batch_op.add_column(Column("map_index", Integer, nullable=False,
default=-1))
+ batch_op.create_primary_key("task_instance_pkey", ["dag_id",
"task_id", "run_id", "map_index"])
+
+ # Re-create task_reschedule's constraints.
+ with op.batch_alter_table("task_reschedule") as batch_op:
+ batch_op.add_column(Column("map_index", Integer, nullable=False,
default=-1))
+ batch_op.create_foreign_key(
+ "task_reschedule_ti_fkey",
+ "task_instance",
+ ["dag_id", "task_id", "run_id", "map_index"],
+ ["dag_id", "task_id", "run_id", "map_index"],
+ ondelete="CASCADE",
+ )
+ batch_op.create_index(
+ "idx_task_reschedule_dag_task_run",
+ ["dag_id", "task_id", "run_id", "map_index"],
+ unique=False,
+ )
+
+ # Create task_map.
+ op.create_table(
+ "task_map",
+ Column("dag_id", StringID(), primary_key=True),
+ Column("task_id", StringID(), primary_key=True),
+ Column("run_id", StringID(), primary_key=True),
+ Column("map_index", Integer, primary_key=True),
Review comment:
Not that it matters, but for consistency should we have the same default
here as we do on TI's map_index column?
##########
File path: airflow/models/taskinstance.py
##########
@@ -2128,6 +2135,28 @@ def set_duration(self) -> None:
self.duration = None
self.log.debug("Task Duration set to %s", self.duration)
+ @provide_session
+ def _record_task_map_for_downstreams(self, value: Any, *, session: Session
= NEW_SESSION) -> None:
Review comment:
```suggestion
def _record_task_map_for_downstreams(self, value: Any, *, session:
Session) -> None:
```
I generally avoid using `provide_session` on internal funcs and favour being
explicit about passing the session instead.
##########
File path: airflow/models/baseoperator.py
##########
@@ -1632,6 +1632,33 @@ def defer(
def map(self, **kwargs) -> "MappedOperator":
return MappedOperator.from_operator(self, kwargs)
+ def has_mapped_dependants(self) -> bool:
+ """Whether any downstream dependencies depend on this task for
mapping."""
+ from airflow.utils.task_group import MappedTaskGroup, TaskGroup
+
+ if not self.has_dag():
+ return False
+
+ def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
+ """Recursively walk children in a task group.
+
+ This yields all direct children (including both tasks and task
+ groups), and all children of any task groups.
+ """
+ for key, child in group.children.items():
+ yield key, child
+ if isinstance(child, TaskGroup):
+ yield from _walk_group(child)
+
+ for key, child in _walk_group(self.dag.task_group):
+ if key == self.task_id:
+ continue
+ if not isinstance(child, (MappedOperator, MappedTaskGroup)):
+ continue
+ if self.task_id in child.upstream_task_ids:
+ return True
+ return False
Review comment:
Why do we walk the entire dag, rather than looking at
`self.downstream_list`?
##########
File path: airflow/models/taskinstance.py
##########
@@ -2128,6 +2135,28 @@ def set_duration(self) -> None:
self.duration = None
self.log.debug("Task Duration set to %s", self.duration)
+ @provide_session
+ def _record_task_map_for_downstreams(self, value: Any, *, session: Session
= NEW_SESSION) -> None:
+ if not self.task.has_mapped_dependants():
+ return
+ if not isinstance(value, collections.abc.Collection):
+ return # TODO: Error if the pushed value is not mappable?
+ session.query(TaskMap).filter_by(
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ run_id=self.run_id,
+ map_index=self.map_index,
+ ).delete()
Review comment:
If we are doing session.merge we don't _have to_ delete it as SQLA would
handle doing the right INSERT/UPDATE itself.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]