jedcunningham commented on code in PR #24908:
URL: https://github.com/apache/airflow/pull/24908#discussion_r917081354


##########
airflow/models/dataset.py:
##########
@@ -199,3 +199,86 @@ def __repr__(self):
         for attr in [x.name for x in self.__mapper__.primary_key]:
             args.append(f"{attr}={getattr(self, attr)!r}")
         return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetEvent(Base):
+    """
+    A table to store datasets events.
+
+    :param dataset_id: reference to Dataset record
+    :param extra: JSON field for arbitrary extra info
+    :param source_task_id: the task_id of the TI which updated the dataset
+    :param source_dag_id: the dag_id of the TI which updated the dataset
+    :param source_run_id: the run_id of the TI which updated the dataset
+    :param source_map_index: the map_index of the TI which updated the dataset
+
+    We use relationships instead of foreign keys so that dataset events are 
not deleted even
+    if the foreign key object is.
+    """
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    dataset_id = Column(Integer, nullable=False)
+    extra = Column(ExtendedJSON, nullable=True)
+    source_task_id = Column(StringID(), nullable=True)
+    source_dag_id = Column(StringID(), nullable=True)
+    source_run_id = Column(StringID(), nullable=True)
+    source_map_index = Column(Integer, nullable=True, 
server_default=text("-1"))
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+    __tablename__ = "dataset_event"
+    __table_args__ = (
+        Index('idx_dataset_id_created_at', dataset_id, created_at, 
mssql_clustered=True),
+        {'sqlite_autoincrement': True},  # ensures PK values not reused
+    )
+
+    source_task_instance = relationship(
+        "TaskInstance",
+        primaryjoin="""and_(
+            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
+            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
+            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
+            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    source_dag_run = relationship(
+        "DagRun",
+        primaryjoin="""and_(
+            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
+            DatasetEvent.source_run_id == foreign(DagRun.run_id),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    dataset = relationship(
+        Dataset,
+        primaryjoin="DatasetEvent.dataset_id == foreign(Dataset.id)",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return self.dataset_id == other.dataset_id and self.created_at == 
other.created_at
+        else:
+            return NotImplemented
+
+    def __hash__(self):
+        return hash((self.dataset_id, self.created_at))
+
+    def __repr__(self):

Review Comment:
   ```suggestion
       def __repr__(self) -> str:
   ```



##########
airflow/models/taskinstance.py:
##########
@@ -1528,6 +1528,16 @@ def _create_dataset_dag_run_queue_records(self, *, 
session):
                     continue
                 downstream_dag_ids = [x.dag_id for x in dataset.dag_references]
                 self.log.debug("downstream dag ids %s", downstream_dag_ids)
+                session.add(
+                    DatasetEvent(
+                        dataset_id=dataset.id,
+                        extra=None,

Review Comment:
   ```suggestion
   ```
   
   This is the default, no?



##########
tests/models/test_taskinstance.py:
##########
@@ -1499,10 +1499,28 @@ def test_outlet_datasets(self, create_task_instance):
         ti._run_raw_task()
         ti.refresh_from_db()
         assert ti.state == State.SUCCESS
+
+        # check that one queue record created for each dag that depends on 
dataset 1
         assert session.query(DatasetDagRunQueue.target_dag_id).filter(
             DatasetTaskRef.dag_id == dag1.dag_id, DatasetTaskRef.task_id == 
'upstream_task_1'
         ).all() == [('dag3',), ('dag4',), ('dag5',)]
 
+        # check that one event record created for dataset1 and this TI
+        assert session.query(Dataset.uri).join(DatasetEvent.dataset).filter(
+            DatasetEvent.source_task_instance == ti
+        ).one() == ('s3://dag1/output_1.txt',)
+
+        # check that no other dataset events recorded
+        assert (
+            len(
+                session.query(Dataset.uri)
+                .join(DatasetEvent.dataset)
+                .filter(DatasetEvent.source_task_instance == ti)
+                .all()
+            )
+            == 1
+        )

Review Comment:
   ```suggestion
           assert (
                   session.query(Dataset.uri)
                   .join(DatasetEvent.dataset)
                   .filter(DatasetEvent.source_task_instance == ti)
                   .count()
               ) == 1
   ```
   
   Not sure I got the formatting quite right...



##########
airflow/models/dataset.py:
##########
@@ -199,3 +199,86 @@ def __repr__(self):
         for attr in [x.name for x in self.__mapper__.primary_key]:
             args.append(f"{attr}={getattr(self, attr)!r}")
         return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetEvent(Base):
+    """
+    A table to store datasets events.
+
+    :param dataset_id: reference to Dataset record
+    :param extra: JSON field for arbitrary extra info
+    :param source_task_id: the task_id of the TI which updated the dataset
+    :param source_dag_id: the dag_id of the TI which updated the dataset
+    :param source_run_id: the run_id of the TI which updated the dataset
+    :param source_map_index: the map_index of the TI which updated the dataset
+
+    We use relationships instead of foreign keys so that dataset events are 
not deleted even
+    if the foreign key object is.
+    """
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    dataset_id = Column(Integer, nullable=False)
+    extra = Column(ExtendedJSON, nullable=True)
+    source_task_id = Column(StringID(), nullable=True)
+    source_dag_id = Column(StringID(), nullable=True)
+    source_run_id = Column(StringID(), nullable=True)
+    source_map_index = Column(Integer, nullable=True, 
server_default=text("-1"))
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+    __tablename__ = "dataset_event"
+    __table_args__ = (
+        Index('idx_dataset_id_created_at', dataset_id, created_at, 
mssql_clustered=True),
+        {'sqlite_autoincrement': True},  # ensures PK values not reused
+    )
+
+    source_task_instance = relationship(
+        "TaskInstance",
+        primaryjoin="""and_(
+            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
+            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
+            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
+            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    source_dag_run = relationship(
+        "DagRun",
+        primaryjoin="""and_(
+            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
+            DatasetEvent.source_run_id == foreign(DagRun.run_id),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    dataset = relationship(
+        Dataset,
+        primaryjoin="DatasetEvent.dataset_id == foreign(Dataset.id)",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return self.dataset_id == other.dataset_id and self.created_at == 
other.created_at
+        else:
+            return NotImplemented
+
+    def __hash__(self):

Review Comment:
   ```suggestion
       def __hash__(self) -> int:
   ```



##########
airflow/models/dataset.py:
##########
@@ -199,3 +199,86 @@ def __repr__(self):
         for attr in [x.name for x in self.__mapper__.primary_key]:
             args.append(f"{attr}={getattr(self, attr)!r}")
         return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetEvent(Base):
+    """
+    A table to store datasets events.
+
+    :param dataset_id: reference to Dataset record
+    :param extra: JSON field for arbitrary extra info
+    :param source_task_id: the task_id of the TI which updated the dataset
+    :param source_dag_id: the dag_id of the TI which updated the dataset
+    :param source_run_id: the run_id of the TI which updated the dataset
+    :param source_map_index: the map_index of the TI which updated the dataset
+
+    We use relationships instead of foreign keys so that dataset events are 
not deleted even
+    if the foreign key object is.
+    """
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    dataset_id = Column(Integer, nullable=False)
+    extra = Column(ExtendedJSON, nullable=True)
+    source_task_id = Column(StringID(), nullable=True)
+    source_dag_id = Column(StringID(), nullable=True)
+    source_run_id = Column(StringID(), nullable=True)
+    source_map_index = Column(Integer, nullable=True, 
server_default=text("-1"))
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+    __tablename__ = "dataset_event"
+    __table_args__ = (
+        Index('idx_dataset_id_created_at', dataset_id, created_at, 
mssql_clustered=True),
+        {'sqlite_autoincrement': True},  # ensures PK values not reused
+    )
+
+    source_task_instance = relationship(
+        "TaskInstance",
+        primaryjoin="""and_(
+            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
+            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
+            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
+            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    source_dag_run = relationship(
+        "DagRun",
+        primaryjoin="""and_(
+            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
+            DatasetEvent.source_run_id == foreign(DagRun.run_id),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    dataset = relationship(
+        Dataset,
+        primaryjoin="DatasetEvent.dataset_id == foreign(Dataset.id)",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+
+    def __eq__(self, other):

Review Comment:
   ```suggestion
       def __eq__(self, other) -> bool:
   ```
   
   Use this if it passes mypy — typing `__eq__` can be a bit finicky with type checkers.



##########
airflow/models/taskinstance.py:
##########
@@ -1513,10 +1513,10 @@ def _run_raw_task(
         if not test_mode:
             session.add(Log(self.state, self))
             session.merge(self)
-            self._create_dataset_dag_run_queue_records(session=session)
+            self._create_dataset_dag_run_queue_records(context=context, 
session=session)
             session.commit()
 
-    def _create_dataset_dag_run_queue_records(self, *, session):
+    def _create_dataset_dag_run_queue_records(self, *, context=None, 
session=NEW_SESSION):

Review Comment:
   ```suggestion
       def _create_dataset_dag_run_queue_records(self, *, context: Context = 
None, session: Session):
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to