uranusjr commented on code in PR #47677:
URL: https://github.com/apache/airflow/pull/47677#discussion_r1992973470
##########
airflow/models/taskinstance.py:
##########
@@ -2744,104 +2733,126 @@ def _run_raw_task(
def register_asset_changes_in_db(
ti: TaskInstance,
task_outlets: list[AssetProfile],
- outlet_events: list[Any],
+ outlet_events: list[dict[str, Any]],
session: Session = NEW_SESSION,
) -> None:
- # One task only triggers one asset event for each asset with the same
extra.
- # This tuple[asset uri, extra] to sets alias names mapping is used to
find whether
- # there're assets with same uri but different extra that we need to
emit more than one asset events.
- asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] =
defaultdict(set)
- asset_name_refs: set[str] = set()
- asset_uri_refs: set[str] = set()
-
- for obj in task_outlets:
- ti.log.debug("outlet obj %s", obj)
- # Lineage can have other types of objects besides assets
- if obj.type == Asset.__name__:
- asset_manager.register_asset_change(
- task_instance=ti,
- asset=Asset(name=obj.name, uri=obj.uri), # type: ignore
- extra=outlet_events[0]["extra"],
- session=session,
- )
- elif obj.type == AssetNameRef.__name__:
- asset_name_refs.add(obj.name) # type: ignore
- elif obj.type == AssetUriRef.__name__:
- asset_uri_refs.add(obj.uri) # type: ignore
- elif obj.type == AssetAlias.__name__:
- outlet_events = list(
- map(
- lambda event: {**event, "dest_asset_key":
AssetUniqueKey(**event["dest_asset_key"])},
- outlet_events,
- )
- )
- for asset_alias_event in outlet_events:
- asset_alias_name = asset_alias_event["source_alias_name"]
- asset_unique_key = asset_alias_event["dest_asset_key"]
- frozen_extra =
frozenset(asset_alias_event["extra"].items())
- asset_alias_names[(asset_unique_key,
frozen_extra)].add(asset_alias_name)
-
- asset_unique_keys = {key for key, _ in asset_alias_names}
- existing_aliased_assets: set[AssetUniqueKey] = {
- AssetUniqueKey.from_asset(asset_obj)
- for asset_obj in session.scalars(
+ asset_keys = {
+ AssetUniqueKey(o.name, o.uri)
+ for o in task_outlets
+ if o.type == Asset.__name__ and o.name and o.uri
+ }
+ asset_name_refs = {
+ Asset.ref(name=o.name) for o in task_outlets if o.type ==
AssetNameRef.__name__ and o.name
+ }
+ asset_uri_refs = {
+ Asset.ref(uri=o.uri) for o in task_outlets if o.type ==
AssetUriRef.__name__ and o.uri
+ }
+
+ asset_models: dict[AssetUniqueKey, AssetModel] = {
+ AssetUniqueKey.from_asset(am): am
+ for am in session.scalars(
select(AssetModel).where(
- tuple_(AssetModel.name, AssetModel.uri).in_(
- attrs.astuple(key) for key in asset_unique_keys
- )
+ AssetModel.active.has(),
+ or_(
+ tuple_(AssetModel.name,
AssetModel.uri).in_(attrs.astuple(k) for k in asset_keys),
+ AssetModel.name.in_(r.name for r in asset_name_refs),
+ AssetModel.uri.in_(r.uri for r in asset_uri_refs),
+ ),
)
)
}
- inactive_asset_unique_keys =
TaskInstance._get_inactive_asset_unique_keys(
- asset_unique_keys={key for key in asset_unique_keys if key in
existing_aliased_assets},
- session=session,
- )
- if inactive_asset_unique_keys:
- raise
AirflowInactiveAssetAddedToAssetAliasException(inactive_asset_unique_keys)
-
- if missing_assets := [
- asset_unique_key.to_asset()
- for asset_unique_key, _ in asset_alias_names
- if asset_unique_key not in existing_aliased_assets
- ]:
- asset_manager.create_assets(missing_assets, session=session)
- ti.log.warning("Created new assets for alias reference: %s",
missing_assets)
- session.flush() # Needed because we need the id for fk.
-
- for (unique_key, extra_items), alias_names in
asset_alias_names.items():
- ti.log.info(
- 'Creating event for %r through aliases "%s"',
- unique_key,
- ", ".join(alias_names),
- )
- asset_manager.register_asset_change(
- task_instance=ti,
- asset=unique_key,
- aliases=[AssetAlias(name=name) for name in alias_names],
- extra=dict(extra_items),
- session=session,
- source_alias_names=alias_names,
- )
- # Handle events derived from references.
- asset_stmt =
select(AssetModel).where(AssetModel.name.in_(asset_name_refs),
AssetModel.active.has())
- for asset_model in session.scalars(asset_stmt):
- ti.log.info("Creating event through asset name reference %r",
asset_model.name)
+ asset_event_extras: dict[AssetUniqueKey, dict] = {
+ AssetUniqueKey(**event["dest_asset_key"]): event["extra"]
+ for event in outlet_events
+ if "source_alias_name" not in event
+ }
+
+ bad_asset_keys: set[AssetUniqueKey | AssetNameRef | AssetUriRef] =
set()
+
+ for key in asset_keys:
+ try:
+ am = asset_models[key]
+ except KeyError:
+ bad_asset_keys.add(key)
+ continue
+ ti.log.debug("register event for asset %s", am)
asset_manager.register_asset_change(
task_instance=ti,
- asset=asset_model,
- extra=outlet_events[asset_model].extra,
+ asset=am,
+ extra=asset_event_extras.get(key),
session=session,
)
- asset_stmt =
select(AssetModel).where(AssetModel.uri.in_(asset_uri_refs),
AssetModel.active.has())
- for asset_model in session.scalars(asset_stmt):
- ti.log.info("Creating event for through asset URI reference %r",
asset_model.uri)
- asset_manager.register_asset_change(
- task_instance=ti,
- asset=asset_model,
- extra=outlet_events[asset_model].extra,
+
+ if asset_name_refs:
+ asset_models_by_name = {key.name: am for key, am in
asset_models.items()}
+ asset_event_extras_by_name = {key.name: extra for key, extra in
asset_event_extras.items()}
+ for nref in asset_name_refs:
+ try:
+ am = asset_models_by_name[nref.name]
+ except KeyError:
+ bad_asset_keys.add(nref)
+ continue
+ ti.log.debug("register event for asset name ref %s", am)
+ asset_manager.register_asset_change(
+ task_instance=ti,
+ asset=am,
+ extra=asset_event_extras_by_name.get(nref.name),
+ session=session,
+ )
+ if asset_uri_refs:
+ asset_models_by_uri = {key.uri: am for key, am in
asset_models.items()}
+ asset_event_extras_by_uri = {key.uri: extra for key, extra in
asset_event_extras.items()}
+ for uref in asset_uri_refs:
+ try:
+ am = asset_models_by_uri[uref.uri]
+ except KeyError:
+ bad_asset_keys.add(uref)
+ continue
+ ti.log.debug("register event for asset uri ref %s", am)
+ asset_manager.register_asset_change(
+ task_instance=ti,
+ asset=am,
+ extra=asset_event_extras_by_uri.get(uref.uri),
+ session=session,
+ )
+
+ if bad_asset_keys:
Review Comment:
In practice, 2 can’t fire unless you send a faulty request. I guess I can
move it to the end instead? (Or to just before 4, but it’s easier to move it
to the end, since 3 and 4 are in the same if block.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]