This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 06c382314fc Add missing attributes "name" and "group" for Asset and "group" for AssetAlias in serialization, API and methods (#43774)
06c382314fc is described below
commit 06c382314fcb65745e33d27408d6c48fd85de21a
Author: Wei Lee <[email protected]>
AuthorDate: Mon Dec 2 18:34:42 2024 +0800
Add missing attributes "name" and "group" for Asset and "group" for AssetAlias in serialization, API and methods (#43774)
* test(tests/www/views/test_views_grid): extend Asset test cases to include
both uri and name
* test(utils/test_json): extend Asset test cases to include both uri and
name
* test(timetables/test_assets_timetable): extend Asset test cases to
include both uri and name
* test(listeners/test_asset_listener): extend Asset test cases to include
both uri and name
* test(jobs/test_scheduler_job): extend Asset test cases to include both
uri and name
* test(providers/openlineage): extend Asset test cases to include both uri
and name
* test(decorators/test_python): extend Asset test cases to include both uri
and name
* test(models/test_dag): extend asset test cases to cover name, uri, group
* test(api_connexion/schemas/dag_run): extend asset test cases to cover name, uri, group
* test(serialization/serialized_objects): extend asset test cases to cover
name, uri, group and asset alias test cases to cover name and group
* test(serialization/dag_serialization): extend asset test cases to cover
name, uri, group
* test(models/dag): extend asset test cases to cover name, uri, group
* test(serialization/serde): extend asset test cases to cover name, uri,
group
* test(api_connexion/schemas/asset): extend asset test cases to cover name,
uri, group
* test(api_connexion/schemas/asset): extend asset alias test cases to cover
name, group
* test(api_connexion/schemas/dag): extend asset test cases to cover name, uri, group
* test(api_connexion/schemas/dag_run): extend asset test cases to cover name, uri, group
* test(dags/test_assets): extend asset test cases to cover name, uri, group
* test(dags/test_only_empty_tasks): extend asset test cases to cover name,
uri, group
* test(api_fastapi): extend asset test cases to cover name, uri, group
* test(assets/manager): extend asset test cases to cover name, uri, group
* test(task_sdk/assets): extend asset test cases to cover name, uri, group
* test(api_connexion/endpoints/asset): extend asset test cases to cover
name, uri, group
* test: add missing session
* test(www/views/asset): extend asset test cases to cover name, uri, group
* test(models/serialized_dag): extend asset test cases to cover name, uri, group
* test(lineage/hook): extend asset test cases to cover name, uri, group
* test(io/path): extend asset test cases to cover name, uri, group
* test(jobs): enhance test_activate_referenced_assets_with_no_existing_warning to cover an extra edge case
* fix(serialization): serialize name, uri and group for Asset
* fix(assets): extend Asset as_expression methods to include name, group
fields (also AssetAlias group field)
* fix(serialization/serialized_objects): fix missing AssetAlias.group
serialization
* fix(serialization): change dependency_id to use name instead of uri
* feat(api_connexion/schemas/asset): add name, group to asset schema and
group to asset alias schema
* feat(assets/manager): filter asset by name, uri, group instead of uri only
* style(assets/manager): rename argument asset in _add_asset_alias_association to asset_model
* fix(asset): use name to evaluate instead of uri
* fix(api_connexion/endpoints/asset): fix how the asset event is fetched in create asset event
* fix(api_fastapi/asset): fix how the asset event is fetched in create asset event
* fix(lineage/hook): extend asset-related methods to include name and group
* fix(task_sdk/asset): change iter_assets to return ((name, uri), obj)
instead of (uri, obj)
* fix(fastapi/asset): add missing group column to asset alias schema
* build: build autogen ts files
* feat(lineage/hook): make create_asset keyword-only
* docs(newsfragments): add 43774.significant.rst
* refactor(task_sdk/asset): add from_asset_alias to AssetAliasCondition to
remove duplicate code
* refactor(task_sdk/asset): add AssetUniqueKey.from_asset to reduce
duplicate code
* Revert "fix(asset): use name to evaluate instead of uri"
This reverts commit e812b8ada59e925beeb52c8ddb0d14b0dfec1abf.
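For illustration, a minimal sketch (not part of this commit) of what
encode_asset_condition/decode_asset_condition now produce and accept, assuming
the Asset and AssetAlias definitions in the diff below; the DAT.* string
values are kept symbolic:

    from airflow.sdk.definitions.asset import Asset, AssetAlias
    from airflow.serialization.serialized_objects import (
        decode_asset_condition,
        encode_asset_condition,
    )

    asset = Asset(name="orders", uri="s3://bucket/orders", group="asset")
    encoded = encode_asset_condition(asset)
    # {"__type": DAT.ASSET, "name": "orders", "uri": "s3://bucket/orders",
    #  "group": "asset", "extra": {}}
    assert decode_asset_condition(encoded).group == "asset"

    alias = AssetAlias(name="orders-alias", group="asset")
    # {"__type": DAT.ASSET_ALIAS, "name": "orders-alias", "group": "asset"}
    encoded_alias = encode_asset_condition(alias)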
---
airflow/api_connexion/endpoints/asset_endpoint.py | 8 +-
airflow/api_connexion/schemas/asset_schema.py | 3 +
airflow/api_fastapi/core_api/datamodels/assets.py | 1 +
.../api_fastapi/core_api/openapi/v1-generated.yaml | 4 +
.../api_fastapi/core_api/routes/public/assets.py | 7 +-
airflow/assets/manager.py | 12 +-
airflow/lineage/hook.py | 40 ++-
airflow/serialization/serialized_objects.py | 66 ++++-
airflow/timetables/base.py | 4 +-
airflow/timetables/simple.py | 2 +-
airflow/ui/openapi-gen/requests/schemas.gen.ts | 6 +-
airflow/ui/openapi-gen/requests/types.gen.ts | 1 +
newsfragments/43774.significant.rst | 22 ++
providers/tests/openlineage/plugins/test_utils.py | 30 +-
.../src/airflow/sdk/definitions/asset/__init__.py | 44 ++-
task_sdk/tests/defintions/test_asset.py | 117 ++++----
.../api_connexion/endpoints/test_asset_endpoint.py | 10 +
.../endpoints/test_dag_run_endpoint.py | 2 +-
tests/api_connexion/schemas/test_asset_schema.py | 20 +-
tests/api_connexion/schemas/test_dag_schema.py | 4 +-
.../core_api/routes/public/test_assets.py | 4 +-
.../api_fastapi/core_api/routes/ui/test_assets.py | 20 +-
tests/assets/test_manager.py | 24 +-
tests/dags/test_assets.py | 4 +-
tests/dags/test_only_empty_tasks.py | 4 +-
tests/decorators/test_python.py | 3 +-
tests/io/test_path.py | 2 +-
tests/jobs/test_scheduler_job.py | 39 +--
tests/lineage/test_hook.py | 49 +++-
tests/listeners/test_asset_listener.py | 10 +-
tests/models/test_dag.py | 86 ++++--
tests/models/test_serialized_dag.py | 12 +-
tests/serialization/test_dag_serialization.py | 304 ++++++++++++++++-----
tests/serialization/test_serde.py | 2 +-
tests/serialization/test_serialized_objects.py | 58 +++-
tests/timetables/test_assets_timetable.py | 22 +-
tests/utils/test_json.py | 2 +-
tests/www/views/test_views_asset.py | 11 +-
tests/www/views/test_views_grid.py | 13 +-
39 files changed, 774 insertions(+), 298 deletions(-)
diff --git a/airflow/api_connexion/endpoints/asset_endpoint.py
b/airflow/api_connexion/endpoints/asset_endpoint.py
index 1bda4fdb2a2..64930b12494 100644
--- a/airflow/api_connexion/endpoints/asset_endpoint.py
+++ b/airflow/api_connexion/endpoints/asset_endpoint.py
@@ -45,7 +45,6 @@ from airflow.api_connexion.schemas.asset_schema import (
)
from airflow.assets.manager import asset_manager
from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
-from airflow.sdk.definitions.asset import Asset
from airflow.utils import timezone
from airflow.utils.api_migration import mark_fastapi_migration_done
from airflow.utils.db import get_query_count
@@ -341,15 +340,16 @@ def create_asset_event(session: Session = NEW_SESSION) ->
APIResponse:
except ValidationError as err:
raise BadRequest(detail=str(err))
+ # TODO: handle name
uri = json_body["asset_uri"]
- asset = session.scalar(select(AssetModel).where(AssetModel.uri ==
uri).limit(1))
- if not asset:
+ asset_model = session.scalar(select(AssetModel).where(AssetModel.uri ==
uri).limit(1))
+ if not asset_model:
raise NotFound(title="Asset not found", detail=f"Asset with uri:
'{uri}' not found")
timestamp = timezone.utcnow()
extra = json_body.get("extra", {})
extra["from_rest_api"] = True
asset_event = asset_manager.register_asset_change(
- asset=Asset(uri=uri),
+ asset=asset_model.to_public(),
timestamp=timestamp,
extra=extra,
session=session,
diff --git a/airflow/api_connexion/schemas/asset_schema.py
b/airflow/api_connexion/schemas/asset_schema.py
index e83c4f1b427..078ebb3e758 100644
--- a/airflow/api_connexion/schemas/asset_schema.py
+++ b/airflow/api_connexion/schemas/asset_schema.py
@@ -70,6 +70,7 @@ class AssetAliasSchema(SQLAlchemySchema):
id = auto_field()
name = auto_field()
+ group = auto_field()
class AssetSchema(SQLAlchemySchema):
@@ -82,6 +83,8 @@ class AssetSchema(SQLAlchemySchema):
id = auto_field()
uri = auto_field()
+ name = auto_field()
+ group = auto_field()
extra = JsonObjectField()
created_at = auto_field()
updated_at = auto_field()
diff --git a/airflow/api_fastapi/core_api/datamodels/assets.py
b/airflow/api_fastapi/core_api/datamodels/assets.py
index 638ee1cba6e..72bba200fab 100644
--- a/airflow/api_fastapi/core_api/datamodels/assets.py
+++ b/airflow/api_fastapi/core_api/datamodels/assets.py
@@ -47,6 +47,7 @@ class AssetAliasSchema(BaseModel):
id: int
name: str
+ group: str
class AssetResponse(BaseModel):
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index a331a637c2e..20c450cf0a2 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -5786,10 +5786,14 @@ components:
name:
type: string
title: Name
+ group:
+ type: string
+ title: Group
type: object
required:
- id
- name
+ - group
title: AssetAliasSchema
description: Asset alias serializer for assets.
AssetCollectionResponse:
diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py
b/airflow/api_fastapi/core_api/routes/public/assets.py
index 70bf5b047bb..db3fa61767a 100644
--- a/airflow/api_fastapi/core_api/routes/public/assets.py
+++ b/airflow/api_fastapi/core_api/routes/public/assets.py
@@ -51,7 +51,6 @@ from airflow.api_fastapi.core_api.datamodels.assets import (
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
from airflow.assets.manager import asset_manager
from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
-from airflow.sdk.definitions.asset import Asset
from airflow.utils import timezone
assets_router = AirflowRouter(tags=["Asset"])
@@ -171,13 +170,13 @@ def create_asset_event(
session: SessionDep,
) -> AssetEventResponse:
"""Create asset events."""
- asset = session.scalar(select(AssetModel).where(AssetModel.uri ==
body.uri).limit(1))
- if not asset:
+ asset_model = session.scalar(select(AssetModel).where(AssetModel.uri ==
body.uri).limit(1))
+ if not asset_model:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Asset with uri:
`{body.uri}` was not found")
timestamp = timezone.utcnow()
assets_event = asset_manager.register_asset_change(
- asset=Asset(uri=body.uri),
+ asset=asset_model.to_public(),
timestamp=timestamp,
extra=body.extra,
session=session,
diff --git a/airflow/assets/manager.py b/airflow/assets/manager.py
index 40bc97b8134..364d01607e5 100644
--- a/airflow/assets/manager.py
+++ b/airflow/assets/manager.py
@@ -86,16 +86,16 @@ class AssetManager(LoggingMixin):
def _add_asset_alias_association(
cls,
alias_names: Collection[str],
- asset: AssetModel,
+ asset_model: AssetModel,
*,
session: Session,
) -> None:
- already_related = {m.name for m in asset.aliases}
+ already_related = {m.name for m in asset_model.aliases}
existing_aliases = {
m.name: m
for m in
session.scalars(select(AssetAliasModel).where(AssetAliasModel.name.in_(alias_names)))
}
- asset.aliases.extend(
+ asset_model.aliases.extend(
existing_aliases.get(name, AssetAliasModel(name=name))
for name in alias_names
if name not in already_related
@@ -121,7 +121,7 @@ class AssetManager(LoggingMixin):
"""
asset_model = session.scalar(
select(AssetModel)
- .where(AssetModel.uri == asset.uri)
+ .where(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
.options(
joinedload(AssetModel.aliases),
joinedload(AssetModel.consuming_dags).joinedload(DagScheduleAssetReference.dag),
@@ -131,7 +131,9 @@ class AssetManager(LoggingMixin):
cls.logger().warning("AssetModel %s not found", asset)
return None
- cls._add_asset_alias_association({alias.name for alias in aliases},
asset_model, session=session)
+ cls._add_asset_alias_association(
+ alias_names={alias.name for alias in aliases},
asset_model=asset_model, session=session
+ )
event_kwargs = {
"asset_id": asset_model.id,
diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 9e5f8f66482..62a2c7a5493 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -95,24 +95,40 @@ class HookLineageCollector(LoggingMixin):
return f"{asset.uri}_{extra_hash}_{id(context)}"
def create_asset(
- self, scheme: str | None, uri: str | None, asset_kwargs: dict | None,
asset_extra: dict | None
+ self,
+ *,
+ scheme: str | None = None,
+ uri: str | None = None,
+ name: str | None = None,
+ group: str | None = None,
+ asset_kwargs: dict | None = None,
+ asset_extra: dict | None = None,
) -> Asset | None:
"""
Create an asset instance using the provided parameters.
This method attempts to create an asset instance using the given
parameters.
- It first checks if a URI is provided and falls back to using the
default asset factory
- with the given URI if no other information is available.
+ It first checks if a URI or a name is provided and falls back to using
the default asset factory
+ with the given URI or name if no other information is available.
- If a scheme is provided but no URI, it attempts to find an asset
factory that matches
+ If a scheme is provided but no URI or name, it attempts to find an
asset factory that matches
the given scheme. If no such factory is found, it logs an error
message and returns None.
If asset_kwargs is provided, it is used to pass additional parameters
to the asset
factory. The asset_extra parameter is also passed to the factory as an
``extra`` parameter.
"""
- if uri:
+ if uri or name:
# Fallback to default factory using the provided URI
- return Asset(uri=uri, extra=asset_extra)
+ kwargs: dict[str, str | dict] = {}
+ if uri:
+ kwargs["uri"] = uri
+ if name:
+ kwargs["name"] = name
+ if group:
+ kwargs["group"] = group
+ if asset_extra:
+ kwargs["extra"] = asset_extra
+ return Asset(**kwargs) # type: ignore[call-overload]
if not scheme:
self.log.debug(
@@ -137,11 +153,15 @@ class HookLineageCollector(LoggingMixin):
context: LineageContext,
scheme: str | None = None,
uri: str | None = None,
+ name: str | None = None,
+ group: str | None = None,
asset_kwargs: dict | None = None,
asset_extra: dict | None = None,
):
"""Add the input asset and its corresponding hook execution context to
the collector."""
- asset = self.create_asset(scheme=scheme, uri=uri,
asset_kwargs=asset_kwargs, asset_extra=asset_extra)
+ asset = self.create_asset(
+ scheme=scheme, uri=uri, name=name, group=group,
asset_kwargs=asset_kwargs, asset_extra=asset_extra
+ )
if asset:
key = self._generate_key(asset, context)
if key not in self._inputs:
@@ -153,11 +173,15 @@ class HookLineageCollector(LoggingMixin):
context: LineageContext,
scheme: str | None = None,
uri: str | None = None,
+ name: str | None = None,
+ group: str | None = None,
asset_kwargs: dict | None = None,
asset_extra: dict | None = None,
):
"""Add the output asset and its corresponding hook execution context
to the collector."""
- asset = self.create_asset(scheme=scheme, uri=uri,
asset_kwargs=asset_kwargs, asset_extra=asset_extra)
+ asset = self.create_asset(
+ scheme=scheme, uri=uri, name=name, group=group,
asset_kwargs=asset_kwargs, asset_extra=asset_extra
+ )
if asset:
key = self._generate_key(asset, context)
if key not in self._outputs:
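A minimal usage sketch (not part of this diff) of the updated collector API,
with hypothetical asset names; "name" and "group" are the new optional keyword
arguments and flow through to create_asset:

    from airflow.hooks.base import BaseHook
    from airflow.lineage.hook import HookLineageCollector

    collector = HookLineageCollector()
    hook = BaseHook()
    # URI and/or name given: falls back to the default Asset factory.
    collector.add_input_asset(hook, uri="s3://in_bucket/file", name="in-asset", group="asset")
    # Only a scheme given: looks up a registered asset factory for "postgres";
    # if none exists, create_asset logs an error and nothing is collected.
    collector.add_output_asset(hook, scheme="postgres", asset_kwargs={"table": "default"})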
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 1a13430e2fc..f78a2b78b88 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -44,7 +44,11 @@ from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel
from airflow.models.dagrun import DagRun
-from airflow.models.expandinput import EXPAND_INPUT_EMPTY,
create_expand_input, get_map_type_key
+from airflow.models.expandinput import (
+ EXPAND_INPUT_EMPTY,
+ create_expand_input,
+ get_map_type_key,
+)
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
@@ -213,7 +217,9 @@ def _get_registered_timetable(importable_string: str) ->
type[Timetable] | None:
return None
-def _get_registered_priority_weight_strategy(importable_string: str) ->
type[PriorityWeightStrategy] | None:
+def _get_registered_priority_weight_strategy(
+ importable_string: str,
+) -> type[PriorityWeightStrategy] | None:
from airflow import plugins_manager
if importable_string in airflow_priority_weight_strategies:
@@ -256,13 +262,25 @@ def encode_asset_condition(var: BaseAsset) -> dict[str,
Any]:
:meta private:
"""
if isinstance(var, Asset):
- return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri,
"extra": var.extra}
+ return {
+ "__type": DAT.ASSET,
+ "name": var.name,
+ "uri": var.uri,
+ "group": var.group,
+ "extra": var.extra,
+ }
if isinstance(var, AssetAlias):
- return {"__type": DAT.ASSET_ALIAS, "name": var.name}
+ return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group":
var.group}
if isinstance(var, AssetAll):
- return {"__type": DAT.ASSET_ALL, "objects": [encode_asset_condition(x)
for x in var.objects]}
+ return {
+ "__type": DAT.ASSET_ALL,
+ "objects": [encode_asset_condition(x) for x in var.objects],
+ }
if isinstance(var, AssetAny):
- return {"__type": DAT.ASSET_ANY, "objects": [encode_asset_condition(x)
for x in var.objects]}
+ return {
+ "__type": DAT.ASSET_ANY,
+ "objects": [encode_asset_condition(x) for x in var.objects],
+ }
raise ValueError(f"serialization not implemented for
{type(var).__name__!r}")
@@ -274,13 +292,13 @@ def decode_asset_condition(var: dict[str, Any]) ->
BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
- return Asset(uri=var["uri"], name=var["name"], extra=var["extra"])
+ return Asset(name=var["name"], uri=var["uri"], group=var["group"],
extra=var["extra"])
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ANY:
return AssetAny(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ALIAS:
- return AssetAlias(name=var["name"])
+ return AssetAlias(name=var["name"], group=var["group"])
raise ValueError(f"deserialization not implemented for DAT {dat!r}")
@@ -586,7 +604,9 @@ class BaseSerialization:
@classmethod
def serialize_to_json(
- cls, object_to_serialize: BaseOperator | MappedOperator | DAG,
decorated_fields: set
+ cls,
+ object_to_serialize: BaseOperator | MappedOperator | DAG,
+ decorated_fields: set,
) -> dict[str, Any]:
"""Serialize an object to JSON."""
serialized_object: dict[str, Any] = {}
@@ -653,7 +673,11 @@ class BaseSerialization:
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, OutletEventAccessors):
return cls._encode(
- cls.serialize(var._dict, strict=strict,
use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined]
+ cls.serialize(
+ var._dict, # type: ignore[attr-defined]
+ strict=strict,
+ use_pydantic_models=use_pydantic_models,
+ ),
type_=DAT.ASSET_EVENT_ACCESSORS,
)
elif isinstance(var, OutletEventAccessor):
@@ -696,7 +720,11 @@ class BaseSerialization:
elif isinstance(var, (KeyError, AttributeError)):
return cls._encode(
cls.serialize(
- {"exc_cls_name": var.__class__.__name__, "args":
[var.args], "kwargs": {}},
+ {
+ "exc_cls_name": var.__class__.__name__,
+ "args": [var.args],
+ "kwargs": {},
+ },
use_pydantic_models=use_pydantic_models,
strict=strict,
),
@@ -704,7 +732,11 @@ class BaseSerialization:
)
elif isinstance(var, BaseTrigger):
return cls._encode(
- cls.serialize(var.serialize(),
use_pydantic_models=use_pydantic_models, strict=strict),
+ cls.serialize(
+ var.serialize(),
+ use_pydantic_models=use_pydantic_models,
+ strict=strict,
+ ),
type_=DAT.BASE_TRIGGER,
)
elif callable(var):
@@ -1065,11 +1097,11 @@ class DependencyDetector:
source=task.dag_id,
target="asset",
dependency_type="asset",
- dependency_id=obj.uri,
+ dependency_id=obj.name,
)
)
elif isinstance(obj, AssetAlias):
- cond = AssetAliasCondition(obj.name)
+ cond = AssetAliasCondition(name=obj.name, group=obj.group)
deps.extend(cond.iter_dag_dependencies(source=task.dag_id,
target=""))
return deps
@@ -1298,7 +1330,11 @@ class SerializedBaseOperator(BaseOperator,
BaseSerialization):
# The case for "If OperatorLinks are defined in the operator that
is being Serialized"
# is handled in the deserialization loop where it matches k ==
"_operator_extra_links"
if op_extra_links_from_plugin and "_operator_extra_links" not in
encoded_op:
- setattr(op, "operator_extra_links",
list(op_extra_links_from_plugin.values()))
+ setattr(
+ op,
+ "operator_extra_links",
+ list(op_extra_links_from_plugin.values()),
+ )
for k, v in encoded_op.items():
# python_callable_name only serves to detect function name changes
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index 60b1c141209..b80a6323a8c 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from collections.abc import Iterator, Sequence
from typing import TYPE_CHECKING, Any, NamedTuple
-from airflow.sdk.definitions.asset import BaseAsset
+from airflow.sdk.definitions.asset import AssetUniqueKey, BaseAsset
from airflow.typing_compat import Protocol, runtime_checkable
if TYPE_CHECKING:
@@ -55,7 +55,7 @@ class _NullAsset(BaseAsset):
def evaluate(self, statuses: dict[str, bool]) -> bool:
return False
- def iter_assets(self) -> Iterator[tuple[str, Asset]]:
+ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index f282c7fe67f..57eec884b55 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -170,7 +170,7 @@ class AssetTriggeredTimetable(_TrivialTimetable):
super().__init__()
self.asset_condition = assets
if isinstance(self.asset_condition, AssetAlias):
- self.asset_condition =
AssetAliasCondition(self.asset_condition.name)
+ self.asset_condition =
AssetAliasCondition.from_asset_alias(self.asset_condition)
if not next(self.asset_condition.iter_assets(), False):
self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 503910d75ad..ae2eeaf1687 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -99,9 +99,13 @@ export const $AssetAliasSchema = {
type: "string",
title: "Name",
},
+ group: {
+ type: "string",
+ title: "Group",
+ },
},
type: "object",
- required: ["id", "name"],
+ required: ["id", "name", "group"],
title: "AssetAliasSchema",
description: "Asset alias serializer for assets.",
} as const;
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow/ui/openapi-gen/requests/types.gen.ts
index 48861f2bf42..1f1838d0870 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -27,6 +27,7 @@ export type AppBuilderViewResponse = {
export type AssetAliasSchema = {
id: number;
name: string;
+ group: string;
};
/**
diff --git a/newsfragments/43774.significant.rst
b/newsfragments/43774.significant.rst
new file mode 100644
index 00000000000..b716e1fc83f
--- /dev/null
+++ b/newsfragments/43774.significant.rst
@@ -0,0 +1,22 @@
+``HookLineageCollector.create_asset`` now accepts only keyword arguments
+
+To provide AIP-74 support, new arguments "name" and "group" are added to
``HookLineageCollector.create_asset``.
+To make future changes easier, this function now takes keyword arguments
only.
+
+.. Check the type of change that applies to this change
+
+* Types of change
+
+ * [ ] DAG changes
+ * [ ] Config changes
+ * [ ] API changes
+ * [ ] CLI changes
+ * [x] Behaviour changes
+ * [ ] Plugin changes
+ * [ ] Dependency change
+
+.. List the migration rules needed for this change (see
https://github.com/apache/airflow/issues/41641)
+
+* Migration rules needed
+
+ * Calling ``HookLineageCollector.create_asset`` with positional arguments
should raise an error
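A migration sketch for call sites (hypothetical values, not part of this diff):

    # Before: positional arguments were accepted.
    collector.create_asset(None, "s3://bucket/file", None, {"team": "sales"})
    # After: keyword-only, so positional calls raise TypeError;
    # "name" and "group" are new and optional.
    collector.create_asset(uri="s3://bucket/file", name="my-asset",
                           group="asset", asset_extra={"team": "sales"})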
diff --git a/providers/tests/openlineage/plugins/test_utils.py
b/providers/tests/openlineage/plugins/test_utils.py
index e84fac11865..3d41e87cf01 100644
--- a/providers/tests/openlineage/plugins/test_utils.py
+++ b/providers/tests/openlineage/plugins/test_utils.py
@@ -334,9 +334,9 @@ def test_serialize_timetable():
from airflow.timetables.simple import AssetTriggeredTimetable
asset = AssetAny(
- Asset("2"),
- AssetAlias("example-alias"),
- Asset("3"),
+ Asset(name="2", uri="test://2", group="test-group"),
+ AssetAlias(name="example-alias", group="test-group"),
+ Asset(name="3", uri="test://3", group="test-group"),
AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
)
dag = MagicMock()
@@ -347,14 +347,32 @@ def test_serialize_timetable():
"asset_condition": {
"__type": DagAttributeTypes.ASSET_ANY,
"objects": [
- {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "2",
"uri": "2"},
+ {
+ "__type": DagAttributeTypes.ASSET,
+ "extra": {},
+ "uri": "test://2/",
+ "name": "2",
+ "group": "test-group",
+ },
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
- {"__type": DagAttributeTypes.ASSET, "extra": {}, "name": "3",
"uri": "3"},
+ {
+ "__type": DagAttributeTypes.ASSET,
+ "extra": {},
+ "uri": "test://3/",
+ "name": "3",
+ "group": "test-group",
+ },
{
"__type": DagAttributeTypes.ASSET_ALL,
"objects": [
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
- {"__type": DagAttributeTypes.ASSET, "extra": {},
"name": "4", "uri": "4"},
+ {
+ "__type": DagAttributeTypes.ASSET,
+ "extra": {},
+ "uri": "4",
+ "name": "4",
+ "group": "asset",
+ },
],
},
],
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index 620c89473ff..554dd9bb4a7 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -27,6 +27,7 @@ from typing import (
Any,
Callable,
ClassVar,
+ NamedTuple,
cast,
overload,
)
@@ -62,6 +63,15 @@ __all__ = [
log = logging.getLogger(__name__)
+class AssetUniqueKey(NamedTuple):
+ name: str
+ uri: str
+
+ @staticmethod
+ def from_asset(asset: Asset) -> AssetUniqueKey:
+ return AssetUniqueKey(name=asset.name, uri=asset.uri)
+
+
def normalize_noop(parts: SplitResult) -> SplitResult:
"""
Place-hold a :class:`~urllib.parse.SplitResult`` normalizer.
@@ -202,7 +212,7 @@ class BaseAsset:
def evaluate(self, statuses: dict[str, bool]) -> bool:
raise NotImplementedError
- def iter_assets(self) -> Iterator[tuple[str, Asset]]:
+ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
raise NotImplementedError
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
@@ -349,10 +359,10 @@ class Asset(os.PathLike, BaseAsset):
:meta private:
"""
- return self.uri
+ return {"asset": {"uri": self.uri, "name": self.name, "group":
self.group}}
- def iter_assets(self) -> Iterator[tuple[str, Asset]]:
- yield self.uri, self
+ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
+ yield AssetUniqueKey.from_asset(self), self
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
return iter(())
@@ -370,7 +380,7 @@ class Asset(os.PathLike, BaseAsset):
source=source or "asset",
target=target or "asset",
dependency_type="asset",
- dependency_id=self.uri,
+ dependency_id=self.name,
)
@@ -400,7 +410,7 @@ class AssetAlias(BaseAsset):
name: str = attrs.field(validator=_validate_non_empty_identifier)
group: str = attrs.field(kw_only=True, default="",
validator=_validate_identifier)
- def iter_assets(self) -> Iterator[tuple[str, Asset]]:
+ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
@@ -438,13 +448,14 @@ class _AssetBooleanCondition(BaseAsset):
raise TypeError("expect asset expressions in condition")
self.objects = [
- AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else
obj for obj in objects
+ AssetAliasCondition.from_asset_alias(obj) if isinstance(obj,
AssetAlias) else obj
+ for obj in objects
]
def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in
self.objects)
- def iter_assets(self) -> Iterator[tuple[str, Asset]]:
+ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
seen = set() # We want to keep the first instance.
for o in self.objects:
for k, v in o.iter_assets():
@@ -513,8 +524,9 @@ class AssetAliasCondition(AssetAny):
:meta private:
"""
- def __init__(self, name: str) -> None:
+ def __init__(self, name: str, group: str) -> None:
self.name = name
+ self.group = group
self.objects = expand_alias_to_assets(name)
def __repr__(self) -> str:
@@ -526,7 +538,7 @@ class AssetAliasCondition(AssetAny):
:meta private:
"""
- return {"alias": self.name}
+ return {"alias": {"name": self.name, "group": self.group}}
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
yield self.name, AssetAlias(self.name)
@@ -540,18 +552,18 @@ class AssetAliasCondition(AssetAny):
if self.objects:
for obj in self.objects:
asset = cast(Asset, obj)
- uri = asset.uri
+ asset_name = asset.name
# asset
yield DagDependency(
source=f"asset-alias:{self.name}" if source else "asset",
target="asset" if source else f"asset-alias:{self.name}",
dependency_type="asset",
- dependency_id=uri,
+ dependency_id=asset_name,
)
# asset alias
yield DagDependency(
- source=source or f"asset:{uri}",
- target=target or f"asset:{uri}",
+ source=source or f"asset:{asset_name}",
+ target=target or f"asset:{asset_name}",
dependency_type="asset-alias",
dependency_id=self.name,
)
@@ -563,6 +575,10 @@ class AssetAliasCondition(AssetAny):
dependency_id=self.name,
)
+ @staticmethod
+ def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition:
+ return AssetAliasCondition(name=asset_alias.name,
group=asset_alias.group)
+
class AssetAll(_AssetBooleanCondition):
"""Use to combine assets schedule references in an "or" relationship."""
diff --git a/task_sdk/tests/defintions/test_asset.py
b/task_sdk/tests/defintions/test_asset.py
index ef602ea5a22..d9aa6305f57 100644
--- a/task_sdk/tests/defintions/test_asset.py
+++ b/task_sdk/tests/defintions/test_asset.py
@@ -170,18 +170,18 @@ def test_asset_logic_operations():
def test_asset_iter_assets():
- assert list(asset1.iter_assets()) == [("s3://bucket1/data1", asset1)]
+ assert list(asset1.iter_assets()) == [(("asset-1", "s3://bucket1/data1"),
asset1)]
@pytest.mark.db_test
def test_asset_iter_asset_aliases():
base_asset = AssetAll(
- AssetAlias("example-alias-1"),
+ AssetAlias(name="example-alias-1"),
Asset("1"),
AssetAny(
- Asset("2"),
+ Asset(name="2", uri="test://asset1"),
AssetAlias("example-alias-2"),
- Asset("3"),
+ Asset(name="3"),
AssetAll(AssetAlias("example-alias-3"), Asset("4"),
AssetAlias("example-alias-4")),
),
AssetAll(AssetAlias("example-alias-5"), Asset("5")),
@@ -225,8 +225,14 @@ def test_assset_boolean_condition_evaluate_iter():
# Testing iter_assets indirectly through the subclasses
assets_any = dict(any_condition.iter_assets())
assets_all = dict(all_condition.iter_assets())
- assert assets_any == {"s3://bucket1/data1": asset1, "s3://bucket2/data2":
asset2}
- assert assets_all == {"s3://bucket1/data1": asset1, "s3://bucket2/data2":
asset2}
+ assert assets_any == {
+ ("asset-1", "s3://bucket1/data1"): asset1,
+ ("asset-2", "s3://bucket2/data2"): asset2,
+ }
+ assert assets_all == {
+ ("asset-1", "s3://bucket1/data1"): asset1,
+ ("asset-2", "s3://bucket2/data2"): asset2,
+ }
@pytest.mark.parametrize(
@@ -254,7 +260,7 @@ def test_assset_boolean_condition_evaluate_iter():
)
def test_asset_logical_conditions_evaluation_and_serialization(inputs,
scenario, expected):
class_ = AssetAny if scenario == "any" else AssetAll
- assets = [Asset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+ assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in
range(123, 126)]
condition = class_(*assets)
statuses = {asset.uri: status for asset, status in zip(assets, inputs)}
@@ -274,31 +280,31 @@ def
test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario,
(
(False, True, True),
False,
- ), # AssetAll requires all conditions to be True, but d1 is False
+ ), # AssetAll requires all conditions to be True, but asset1 is False
((True, True, True), True), # All conditions are True
(
(True, False, True),
True,
- ), # d1 is True, and AssetAny condition (d2 or d3 being True) is met
+ ), # asset1 is True, and AssetAny condition (asset2 or asset3 being
True) is met
(
(True, False, False),
False,
- ), # d1 is True, but neither d2 nor d3 meet the AssetAny condition
+ ), # asset1 is True, but neither asset2 nor asset3 meet the AssetAny
condition
],
)
def test_nested_asset_conditions_with_serialization(status_values,
expected_evaluation):
# Define assets
- d1 = Asset(uri="s3://abc/123")
- d2 = Asset(uri="s3://abc/124")
- d3 = Asset(uri="s3://abc/125")
+ asset1 = Asset(uri="s3://abc/123")
+ asset2 = Asset(uri="s3://abc/124")
+ asset3 = Asset(uri="s3://abc/125")
- # Create a nested condition: AssetAll with d1 and AssetAny with d2 and d3
- nested_condition = AssetAll(d1, AssetAny(d2, d3))
+ # Create a nested condition: AssetAll with asset1 and AssetAny with asset2
and asset3
+ nested_condition = AssetAll(asset1, AssetAny(asset2, asset3))
statuses = {
- d1.uri: status_values[0],
- d2.uri: status_values[1],
- d3.uri: status_values[2],
+ asset1.uri: status_values[0],
+ asset2.uri: status_values[1],
+ asset3.uri: status_values[2],
}
assert nested_condition.evaluate(statuses) == expected_evaluation,
"Initial evaluation mismatch"
@@ -314,7 +320,7 @@ def
test_nested_asset_conditions_with_serialization(status_values, expected_eval
@pytest.fixture
def create_test_assets(session):
"""Fixture to create test assets and corresponding models."""
- assets = [Asset(uri=f"hello{i}") for i in range(1, 3)]
+ assets = [Asset(uri=f"test://asset{i}", name=f"hello{i}") for i in
range(1, 3)]
for asset in assets:
session.add(AssetModel(uri=asset.uri))
session.commit()
@@ -380,17 +386,17 @@ def test_asset_dag_run_queue_processing(session,
clear_assets, dag_maker, create
@pytest.mark.usefixtures("clear_assets")
def test_dag_with_complex_asset_condition(session, dag_maker):
# Create Asset instances
- d1 = Asset(uri="hello1")
- d2 = Asset(uri="hello2")
+ asset1 = Asset(uri="test://asset1", name="hello1")
+ asset2 = Asset(uri="test://asset2", name="hello2")
# Create and add AssetModel instances to the session
- am1 = AssetModel(uri=d1.uri)
- am2 = AssetModel(uri=d2.uri)
+ am1 = AssetModel(uri=asset1.uri, name=asset1.name, group="asset")
+ am2 = AssetModel(uri=asset2.uri, name=asset2.name, group="asset")
session.add_all([am1, am2])
session.commit()
# Setup a DAG with complex asset triggers (AssetAny with AssetAll)
- with dag_maker(schedule=AssetAny(d1, AssetAll(d2, d1))) as dag:
+ with dag_maker(schedule=AssetAny(asset1, AssetAll(asset2, asset1))) as dag:
EmptyOperator(task_id="hello")
assert isinstance(
@@ -442,11 +448,11 @@ def assets_equal(a1: BaseAsset, a2: BaseAsset) -> bool:
return False
-asset1 = Asset(uri="s3://bucket1/data1")
-asset2 = Asset(uri="s3://bucket2/data2")
-asset3 = Asset(uri="s3://bucket3/data3")
-asset4 = Asset(uri="s3://bucket4/data4")
-asset5 = Asset(uri="s3://bucket5/data5")
+asset1 = Asset(uri="s3://bucket1/data1", name="asset-1")
+asset2 = Asset(uri="s3://bucket2/data2", name="asset-2")
+asset3 = Asset(uri="s3://bucket3/data3", name="asset-3")
+asset4 = Asset(uri="s3://bucket4/data4", name="asset-4")
+asset5 = Asset(uri="s3://bucket5/data5", name="asset-5")
test_cases = [
(lambda: asset1, asset1),
@@ -579,21 +585,27 @@ def test_normalize_uri_valid_uri():
@pytest.mark.usefixtures("clear_assets")
class TestAssetAliasCondition:
@pytest.fixture
- def asset_1(self, session):
+ def asset_model(self, session):
"""Example asset links to asset alias resolved_asset_alias_2."""
- asset_uri = "test_uri"
- asset_1 = AssetModel(id=1, uri=asset_uri)
-
- session.add(asset_1)
+ asset_model = AssetModel(
+ id=1,
+ uri="test://asset1/",
+ name="test_name",
+ group="asset",
+ )
+
+ session.add(asset_model)
session.commit()
- return asset_1
+ return asset_model
@pytest.fixture
def asset_alias_1(self, session):
"""Example asset alias links to no assets."""
- alias_name = "test_name"
- asset_alias_model = AssetAliasModel(name=alias_name)
+ asset_alias_model = AssetAliasModel(
+ name="test_name",
+ group="test",
+ )
session.add(asset_alias_model)
session.commit()
@@ -601,35 +613,34 @@ class TestAssetAliasCondition:
return asset_alias_model
@pytest.fixture
- def resolved_asset_alias_2(self, session, asset_1):
+ def resolved_asset_alias_2(self, session, asset_model):
"""Example asset alias links to asset asset_alias_1."""
- asset_name = "test_name_2"
- asset_alias_2 = AssetAliasModel(name=asset_name)
- asset_alias_2.assets.append(asset_1)
+ asset_alias_2 = AssetAliasModel(name="test_name_2")
+ asset_alias_2.assets.append(asset_model)
session.add(asset_alias_2)
session.commit()
return asset_alias_2
- def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2):
- cond = AssetAliasCondition(name=asset_alias_1.name)
+ def test_init(self, asset_alias_1, asset_model, resolved_asset_alias_2):
+ cond = AssetAliasCondition.from_asset_alias(asset_alias_1)
assert cond.objects == []
- cond = AssetAliasCondition(name=resolved_asset_alias_2.name)
- assert cond.objects == [Asset(uri=asset_1.uri)]
+ cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2)
+ assert cond.objects == [Asset(uri=asset_model.uri,
name=asset_model.name)]
def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
- for assset_alias in (asset_alias_1, resolved_asset_alias_2):
- cond = AssetAliasCondition(assset_alias.name)
- assert cond.as_expression() == {"alias": assset_alias.name}
+ for asset_alias in (asset_alias_1, resolved_asset_alias_2):
+ cond = AssetAliasCondition.from_asset_alias(asset_alias)
+ assert cond.as_expression() == {"alias": {"name":
asset_alias.name, "group": asset_alias.group}}
- def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1):
- cond = AssetAliasCondition(asset_alias_1.name)
- assert cond.evaluate({asset_1.uri: True}) is False
+ def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_model):
+ cond = AssetAliasCondition.from_asset_alias(asset_alias_1)
+ assert cond.evaluate({asset_model.uri: True}) is False
- cond = AssetAliasCondition(resolved_asset_alias_2.name)
- assert cond.evaluate({asset_1.uri: True}) is True
+ cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2)
+ assert cond.evaluate({asset_model.uri: True}) is True
class TestAssetSubclasses:
diff --git a/tests/api_connexion/endpoints/test_asset_endpoint.py
b/tests/api_connexion/endpoints/test_asset_endpoint.py
index db064ac5b44..57bea9c6643 100644
--- a/tests/api_connexion/endpoints/test_asset_endpoint.py
+++ b/tests/api_connexion/endpoints/test_asset_endpoint.py
@@ -80,6 +80,8 @@ class TestAssetEndpoint:
asset_model = AssetModel(
id=1,
uri="s3://bucket/key",
+ name="asset-name",
+ group="asset",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
updated_at=timezone.parse(self.default_time),
@@ -103,6 +105,8 @@ class TestGetAssetEndpoint(TestAssetEndpoint):
assert response.json == {
"id": 1,
"uri": "s3://bucket/key",
+ "name": "asset-name",
+ "group": "asset",
"extra": {"foo": "bar"},
"created_at": self.default_time,
"updated_at": self.default_time,
@@ -136,6 +140,8 @@ class TestGetAssets(TestAssetEndpoint):
AssetModel(
id=i,
uri=f"s3://bucket/key/{i}",
+ name=f"asset_{i}",
+ group="asset",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
updated_at=timezone.parse(self.default_time),
@@ -156,6 +162,8 @@ class TestGetAssets(TestAssetEndpoint):
{
"id": 1,
"uri": "s3://bucket/key/1",
+ "name": "asset_1",
+ "group": "asset",
"extra": {"foo": "bar"},
"created_at": self.default_time,
"updated_at": self.default_time,
@@ -166,6 +174,8 @@ class TestGetAssets(TestAssetEndpoint):
{
"id": 2,
"uri": "s3://bucket/key/2",
+ "name": "asset_2",
+ "group": "asset",
"extra": {"foo": "bar"},
"created_at": self.default_time,
"updated_at": self.default_time,
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index 5b4133c6839..45e6bf53376 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -1738,7 +1738,7 @@ class TestClearDagRun(TestDagRunEndpoint):
@pytest.mark.need_serialized_dag
class TestGetDagRunAssetTriggerEvents(TestDagRunEndpoint):
def test_should_respond_200(self, dag_maker, session):
- asset1 = Asset(uri="ds1")
+ asset1 = Asset(uri="test://asset1", name="asset1")
with dag_maker(dag_id="source_dag", start_date=timezone.utcnow(),
session=session):
EmptyOperator(task_id="task", outlets=[asset1])
diff --git a/tests/api_connexion/schemas/test_asset_schema.py
b/tests/api_connexion/schemas/test_asset_schema.py
index af5e8c08b86..ff5a81961e9 100644
--- a/tests/api_connexion/schemas/test_asset_schema.py
+++ b/tests/api_connexion/schemas/test_asset_schema.py
@@ -54,6 +54,8 @@ class TestAssetSchema(TestAssetSchemaBase):
def test_serialize(self, dag_maker, session):
asset = Asset(
uri="s3://bucket/key",
+ name="test_asset",
+ group="test-group",
extra={"foo": "bar"},
)
with dag_maker(dag_id="test_asset_upstream_schema", serialized=True,
session=session):
@@ -70,6 +72,8 @@ class TestAssetSchema(TestAssetSchemaBase):
assert serialized_data == {
"id": 1,
"uri": "s3://bucket/key",
+ "name": "test_asset",
+ "group": "test-group",
"extra": {"foo": "bar"},
"created_at": self.timestamp,
"updated_at": self.timestamp,
@@ -96,12 +100,14 @@ class TestAssetCollectionSchema(TestAssetSchemaBase):
def test_serialize(self, session):
assets = [
AssetModel(
- uri=f"s3://bucket/key/{i+1}",
+ uri=f"s3://bucket/key/{i}",
+ name=f"asset_{i}",
+ group="test-group",
extra={"foo": "bar"},
)
- for i in range(2)
+ for i in range(1, 3)
]
- asset_aliases = [AssetAliasModel(name=f"alias_{i}") for i in range(2)]
+ asset_aliases = [AssetAliasModel(name=f"alias_{i}",
group="test-alias-group") for i in range(2)]
for asset_alias in asset_aliases:
asset_alias.assets.append(assets[0])
session.add_all(assets)
@@ -117,19 +123,23 @@ class TestAssetCollectionSchema(TestAssetSchemaBase):
{
"id": 1,
"uri": "s3://bucket/key/1",
+ "name": "asset_1",
+ "group": "test-group",
"extra": {"foo": "bar"},
"created_at": self.timestamp,
"updated_at": self.timestamp,
"consuming_dags": [],
"producing_tasks": [],
"aliases": [
- {"id": 1, "name": "alias_0"},
- {"id": 2, "name": "alias_1"},
+ {"id": 1, "name": "alias_0", "group":
"test-alias-group"},
+ {"id": 2, "name": "alias_1", "group":
"test-alias-group"},
],
},
{
"id": 2,
"uri": "s3://bucket/key/2",
+ "name": "asset_2",
+ "group": "test-group",
"extra": {"foo": "bar"},
"created_at": self.timestamp,
"updated_at": self.timestamp,
diff --git a/tests/api_connexion/schemas/test_dag_schema.py
b/tests/api_connexion/schemas/test_dag_schema.py
index 4f1b07fb6e7..d6438045249 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -198,8 +198,8 @@ def
test_serialize_test_dag_detail_schema(url_safe_serializer):
@pytest.mark.db_test
def
test_serialize_test_dag_with_asset_schedule_detail_schema(url_safe_serializer):
- asset1 = Asset(uri="s3://bucket/obj1")
- asset2 = Asset(uri="s3://bucket/obj2")
+ asset1 = Asset(uri="s3://bucket/obj1", name="asset1")
+ asset2 = Asset(uri="s3://bucket/obj2", name="asset2")
dag = DAG(
dag_id="test_dag",
start_date=datetime(2020, 6, 19),
diff --git a/tests/api_fastapi/core_api/routes/public/test_assets.py
b/tests/api_fastapi/core_api/routes/public/test_assets.py
index a20353d32f8..9218cbbf820 100644
--- a/tests/api_fastapi/core_api/routes/public/test_assets.py
+++ b/tests/api_fastapi/core_api/routes/public/test_assets.py
@@ -722,7 +722,7 @@ class TestPostAssetEvents(TestAssets):
}
def test_invalid_attr_not_allowed(self, test_client, session):
- self.create_assets()
+ self.create_assets(session)
event_invalid_payload = {"asset_uri": "s3://bucket/key/1", "extra":
{"foo": "bar"}, "fake": {}}
response = test_client.post("/public/assets/events",
json=event_invalid_payload)
@@ -731,7 +731,7 @@ class TestPostAssetEvents(TestAssets):
@pytest.mark.usefixtures("time_freezer")
@pytest.mark.enable_redact
def test_should_mask_sensitive_extra(self, test_client, session):
- self.create_assets()
+ self.create_assets(session)
event_payload = {"uri": "s3://bucket/key/1", "extra": {"password":
"bar"}}
response = test_client.post("/public/assets/events",
json=event_payload)
assert response.status_code == 200
diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py
b/tests/api_fastapi/core_api/routes/ui/test_assets.py
index 8eafb0f8bdd..7b532918496 100644
--- a/tests/api_fastapi/core_api/routes/ui/test_assets.py
+++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py
@@ -36,7 +36,11 @@ def cleanup():
def test_next_run_assets(test_client, dag_maker):
- with dag_maker(dag_id="upstream",
schedule=[Asset(uri="s3://bucket/key/1")], serialized=True):
+ with dag_maker(
+ dag_id="upstream",
+ schedule=[Asset(uri="s3://bucket/next-run-asset/1", name="asset1")],
+ serialized=True,
+ ):
EmptyOperator(task_id="task1")
dag_maker.create_dagrun()
@@ -46,6 +50,16 @@ def test_next_run_assets(test_client, dag_maker):
assert response.status_code == 200
assert response.json() == {
- "asset_expression": {"all": ["s3://bucket/key/1"]},
- "events": [{"id": 20, "uri": "s3://bucket/key/1", "lastUpdate": None}],
+ "asset_expression": {
+ "all": [
+ {
+ "asset": {
+ "uri": "s3://bucket/next-run-asset/1",
+ "name": "asset1",
+ "group": "asset",
+ }
+ }
+ ]
+ },
+ "events": [{"id": 20, "uri": "s3://bucket/next-run-asset/1",
"lastUpdate": None}],
}
diff --git a/tests/assets/test_manager.py b/tests/assets/test_manager.py
index aa8fbb03624..b716056e814 100644
--- a/tests/assets/test_manager.py
+++ b/tests/assets/test_manager.py
@@ -112,7 +112,7 @@ def create_mock_dag():
class TestAssetManager:
def test_register_asset_change_asset_doesnt_exist(self,
mock_task_instance):
- asset = Asset(uri="asset_doesnt_exist")
+ asset = Asset(uri="asset_doesnt_exist", name="not exist")
mock_session = mock.Mock()
# Gotta mock up the query results
@@ -131,12 +131,12 @@ class TestAssetManager:
def test_register_asset_change(self, session, dag_maker,
mock_task_instance):
asset_manager = AssetManager()
- asset = Asset(uri="test_asset_uri")
+ asset = Asset(uri="test://asset1", name="test_asset_uri",
group="asset")
dag1 = DagModel(dag_id="dag1", is_active=True)
dag2 = DagModel(dag_id="dag2", is_active=True)
session.add_all([dag1, dag2])
- asm = AssetModel(uri="test_asset_uri")
+ asm = AssetModel(uri="test://asset1/", name="test_asset_uri",
group="asset")
session.add(asm)
asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag.dag_id) for
dag in (dag1, dag2)]
session.execute(delete(AssetDagRunQueue))
@@ -155,10 +155,10 @@ class TestAssetManager:
consumer_dag_2 = DagModel(dag_id="conumser_2", is_active=True,
fileloc="dag2.py")
session.add_all([consumer_dag_1, consumer_dag_2])
- asm = AssetModel(uri="test_asset_uri")
+ asm = AssetModel(uri="test://asset1/", name="test_asset_uri",
group="asset")
session.add(asm)
- asam = AssetAliasModel(name="test_alias_name")
+ asam = AssetAliasModel(name="test_alias_name", group="test")
session.add(asam)
asam.consuming_dags = [
DagScheduleAssetAliasReference(alias_id=asam.id, dag_id=dag.dag_id)
@@ -167,8 +167,8 @@ class TestAssetManager:
session.execute(delete(AssetDagRunQueue))
session.flush()
- asset = Asset(uri="test_asset_uri")
- asset_alias = AssetAlias(name="test_alias_name")
+ asset = Asset(uri="test://asset1", name="test_asset_uri")
+ asset_alias = AssetAlias(name="test_alias_name", group="test")
asset_manager = AssetManager()
asset_manager.register_asset_change(
task_instance=mock_task_instance,
@@ -187,8 +187,8 @@ class TestAssetManager:
def test_register_asset_change_no_downstreams(self, session,
mock_task_instance):
asset_manager = AssetManager()
- asset = Asset(uri="never_consumed")
- asm = AssetModel(uri="never_consumed")
+ asset = Asset(uri="test://asset1", name="never_consumed")
+ asm = AssetModel(uri="test://asset1/", name="never_consumed",
group="asset")
session.add(asm)
session.execute(delete(AssetDagRunQueue))
session.flush()
@@ -205,11 +205,11 @@ class TestAssetManager:
asset_listener.clear()
get_listener_manager().add_listener(asset_listener)
- asset = Asset(uri="test_asset_uri_2")
+ asset = Asset(uri="test://asset1", name="test_asset_1")
dag1 = DagModel(dag_id="dag3")
session.add(dag1)
- asm = AssetModel(uri="test_asset_uri_2")
+ asm = AssetModel(uri="test://asset1/", name="test_asset_1",
group="asset")
session.add(asm)
asm.consuming_dags = [DagScheduleAssetReference(dag_id=dag1.dag_id)]
session.flush()
@@ -226,7 +226,7 @@ class TestAssetManager:
asset_listener.clear()
get_listener_manager().add_listener(asset_listener)
- asset = Asset(uri="test_asset_uri_3")
+ asset = Asset(uri="test://asset1", name="test_asset_1")
asms = asset_manager.create_assets([asset], session=session)
diff --git a/tests/dags/test_assets.py b/tests/dags/test_assets.py
index 1fbc67a18d3..6a0b08f9ba6 100644
--- a/tests/dags/test_assets.py
+++ b/tests/dags/test_assets.py
@@ -25,8 +25,8 @@ from airflow.providers.standard.operators.bash import
BashOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset
-skip_task_dag_asset = Asset("s3://dag_with_skip_task/output_1.txt",
extra={"hi": "bye"})
-fail_task_dag_asset = Asset("s3://dag_with_fail_task/output_1.txt",
extra={"hi": "bye"})
+skip_task_dag_asset = Asset(uri="s3://dag_with_skip_task/output_1.txt",
name="skip", extra={"hi": "bye"})
+fail_task_dag_asset = Asset(uri="s3://dag_with_fail_task/output_1.txt",
name="fail", extra={"hi": "bye"})
def raise_skip_exc():
diff --git a/tests/dags/test_only_empty_tasks.py
b/tests/dags/test_only_empty_tasks.py
index e5152f1f9ad..92c54649824 100644
--- a/tests/dags/test_only_empty_tasks.py
+++ b/tests/dags/test_only_empty_tasks.py
@@ -56,4 +56,6 @@ with dag:
EmptyOperator(task_id="test_task_on_success", on_success_callback=lambda
*args, **kwargs: None)
- EmptyOperator(task_id="test_task_outlets", outlets=[Asset("hello")])
+ EmptyOperator(
+ task_id="test_task_outlets", outlets=[Asset(name="hello",
uri="test://asset1", group="test-group")]
+ )
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index a90bccafa41..b53a379e861 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -975,12 +975,13 @@ def test_task_decorator_asset(dag_maker, session):
result = None
uri = "s3://bucket/name"
+ asset_name = "test_asset"
with dag_maker(session=session) as dag:
@dag.task()
def up1() -> Asset:
- return Asset(uri)
+ return Asset(uri=uri, name=asset_name)
@dag.task()
def up2(src: Asset) -> str:
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index fd9844bc4bc..29e67ca8464 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -405,7 +405,7 @@ class TestFs:
p = "s3"
f = "bucket/object"
- i = Asset(uri=f"{p}://{f}", extra={"foo": "bar"})
+ i = Asset(uri=f"{p}://{f}", name="test-asset", extra={"foo": "bar"})
o = ObjectStoragePath(i)
assert o.protocol == p
assert o.path == f
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 4ba0b6febf8..6b413135bfc 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -3979,8 +3979,8 @@ class TestSchedulerJob:
- That dag_model has next_dagrun
"""
- asset1 = Asset(uri="ds1")
- asset2 = Asset(uri="ds2")
+ asset1 = Asset(uri="test://asset1", name="test_asset",
group="test_group")
+ asset2 = Asset(uri="test://asset2", name="test_asset_2",
group="test_group")
with dag_maker(dag_id="assets-1", start_date=timezone.utcnow(),
session=session):
BashOperator(task_id="task", bash_command="echo 1",
outlets=[asset1])
@@ -4075,15 +4075,14 @@ class TestSchedulerJob:
],
)
def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker,
disable, enable):
- ds = Asset("ds")
- with dag_maker(dag_id="consumer", schedule=[ds], session=session):
+ asset = Asset(uri="test://asset_1", name="test_asset_1",
group="test_group")
+ with dag_maker(dag_id="consumer", schedule=[asset], session=session):
pass
with dag_maker(dag_id="producer", schedule="@daily", session=session):
- BashOperator(task_id="task", bash_command="echo 1", outlets=ds)
+ BashOperator(task_id="task", bash_command="echo 1", outlets=asset)
asset_manger = AssetManager()
- asset_id =
session.scalars(select(AssetModel.id).filter_by(uri=ds.uri)).one()
-
+ asset_id =
session.scalars(select(AssetModel.id).filter_by(uri=asset.uri,
name=asset.name)).one()
ase_q = select(AssetEvent).where(AssetEvent.asset_id ==
asset_id).order_by(AssetEvent.timestamp)
adrq_q = select(AssetDagRunQueue).where(
AssetDagRunQueue.asset_id == asset_id,
AssetDagRunQueue.target_dag_id == "consumer"
@@ -4096,7 +4095,7 @@ class TestSchedulerJob:
dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
asset_manger.register_asset_change(
task_instance=dr1.get_task_instance("task", session=session),
- asset=ds,
+ asset=asset,
session=session,
)
session.flush()
@@ -4110,7 +4109,7 @@ class TestSchedulerJob:
dr2: DagRun = dag_maker.create_dagrun_after(dr1,
run_type=DagRunType.SCHEDULED)
asset_manger.register_asset_change(
task_instance=dr2.get_task_instance("task", session=session),
- asset=ds,
+ asset=asset,
session=session,
)
session.flush()
@@ -6187,11 +6186,11 @@ class TestSchedulerJob:
def test_asset_orphaning(self, dag_maker, session):
self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)
- asset1 = Asset(uri="ds1")
- asset2 = Asset(uri="ds2")
- asset3 = Asset(uri="ds3")
- asset4 = Asset(uri="ds4")
- asset5 = Asset(uri="ds5")
+ asset1 = Asset(uri="test://asset_1", name="test_asset_1",
group="test_group")
+ asset2 = Asset(uri="test://asset_2", name="test_asset_2",
group="test_group")
+ asset3 = Asset(uri="test://asset_3", name="test_asset_3",
group="test_group")
+ asset4 = Asset(uri="test://asset_4", name="test_asset_4",
group="test_group")
+ asset5 = Asset(uri="test://asset_5", name="test_asset_5",
group="test_group")
with dag_maker(dag_id="assets-1", schedule=[asset1, asset2],
session=session):
BashOperator(task_id="task", bash_command="echo 1",
outlets=[asset3, asset4])
@@ -6230,7 +6229,7 @@ class TestSchedulerJob:
def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session):
self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)
- asset1 = Asset(uri="ds1")
+ asset1 = Asset(uri="test://asset_1", name="test_asset_1",
group="test_group")
with dag_maker(dag_id="assets-1", schedule=[asset1], session=session):
BashOperator(task_id="task", bash_command="echo 1")
@@ -6303,11 +6302,13 @@ class TestSchedulerJob:
asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1",
extra=asset_extra)
asset1_1 = Asset(name=asset1_name, uri="it's duplicate",
extra=asset_extra)
- dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1,
asset1_1])
+ asset1_2 = Asset(name="it's also a duplicate",
uri="s3://bucket/key/1", extra=asset_extra)
+ dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1,
asset1_1, asset1_2])
DAG.bulk_write_to_db([dag1], session=session)
asset_models = session.scalars(select(AssetModel)).all()
+ assert len(asset_models) == 3
SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
session.flush()
@@ -6318,8 +6319,10 @@ class TestSchedulerJob:
)
)
assert dag_warning.message == (
- "Cannot activate asset AssetModel(name='asset1', uri=\"it's
duplicate\", extra={'foo': 'bar'}); "
- "name is already associated to 's3://bucket/key/1'"
+ 'Cannot activate asset AssetModel(name="it\'s also a duplicate",'
+ " uri='s3://bucket/key/1', extra={'foo': 'bar'}); uri is already
associated to 'asset1'\n"
+ "Cannot activate asset AssetModel(name='asset1', uri"
+ "=\"it's duplicate\", extra={'foo': 'bar'}); name is already
associated to 's3://bucket/key/1'"
)
def test_activate_referenced_assets_with_existing_warnings(self, session):
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index ec6390c77a5..f66f6c2bf9f 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -46,13 +46,26 @@ class TestHookLineageCollector:
assert self.collector.collected_assets == HookLineage()
input_hook = BaseHook()
output_hook = BaseHook()
- self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file")
- self.collector.add_output_asset(output_hook, uri="postgres://example.com:5432/database/default/table")
+ self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file", name="asset-1", group="test")
+ self.collector.add_output_asset(
+ output_hook,
+ uri="postgres://example.com:5432/database/default/table",
+ )
assert self.collector.collected_assets == HookLineage(
- [AssetLineageInfo(asset=Asset("s3://in_bucket/file"), count=1,
context=input_hook)],
[
AssetLineageInfo(
- asset=Asset("postgres://example.com:5432/database/default/table"),
+ asset=Asset(uri="s3://in_bucket/file", name="asset-1", group="test"),
+ count=1,
+ context=input_hook,
+ )
+ ],
+ [
+ AssetLineageInfo(
+ asset=Asset(
+ uri="postgres://example.com:5432/database/default/table",
+ name="postgres://example.com:5432/database/default/table",
+ group="asset",
+ ),
count=1,
context=output_hook,
)
@@ -68,7 +81,7 @@ class TestHookLineageCollector:
self.collector.add_input_asset(hook, uri="test_uri")
assert next(iter(self.collector._inputs.values())) == (asset, hook)
- mock_asset.assert_called_once_with(uri="test_uri", extra=None)
+ mock_asset.assert_called_once_with(uri="test_uri")
def test_grouping_assets(self):
hook_1 = MagicMock()
@@ -95,18 +108,29 @@ class TestHookLineageCollector:
@patch("airflow.lineage.hook.ProvidersManager")
def test_create_asset(self, mock_providers_manager):
def create_asset(arg1, arg2="default", extra=None):
- return Asset(uri=f"myscheme://{arg1}/{arg2}", extra=extra or {})
+ return Asset(
+ uri=f"myscheme://{arg1}/{arg2}", name=f"asset-{arg1}",
group="test", extra=extra or {}
+ )
mock_providers_manager.return_value.asset_factories = {"myscheme":
create_asset}
assert self.collector.create_asset(
- scheme="myscheme", uri=None, asset_kwargs={"arg1": "value_1"},
asset_extra=None
- ) == Asset("myscheme://value_1/default")
+ scheme="myscheme",
+ uri=None,
+ name=None,
+ group=None,
+ asset_kwargs={"arg1": "value_1"},
+ asset_extra=None,
+ ) == Asset(uri="myscheme://value_1/default", name="asset-value_1",
group="test")
assert self.collector.create_asset(
scheme="myscheme",
uri=None,
+ name=None,
+ group=None,
asset_kwargs={"arg1": "value_1", "arg2": "value_2"},
asset_extra={"key": "value"},
- ) == Asset("myscheme://value_1/value_2", extra={"key": "value"})
+ ) == Asset(
+ uri="myscheme://value_1/value_2", name="asset-value_1",
group="test", extra={"key": "value"}
+ )
@patch("airflow.lineage.hook.ProvidersManager")
def test_create_asset_no_factory(self, mock_providers_manager):
@@ -117,7 +141,12 @@ class TestHookLineageCollector:
assert (
self.collector.create_asset(
- scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, asset_extra=None
+ scheme=test_scheme,
+ uri=None,
+ name=None,
+ group=None,
+ asset_kwargs=test_kwargs,
+ asset_extra=None,
)
is None
)
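As the hunks above show, HookLineageCollector.create_asset now accepts name and group alongside uri, with callers passing None when a value is unknown. A sketch of the updated call shape, assuming collector is a HookLineageCollector instance as in these tests:

    asset = collector.create_asset(
        scheme="myscheme",  # resolved against ProvidersManager.asset_factories
        uri=None,
        name=None,   # new parameter in this commit
        group=None,  # new parameter in this commit
        asset_kwargs={"arg1": "value_1"},
        asset_extra=None,
    )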
diff --git a/tests/listeners/test_asset_listener.py b/tests/listeners/test_asset_listener.py
index 7acf122829d..ace800358f2 100644
--- a/tests/listeners/test_asset_listener.py
+++ b/tests/listeners/test_asset_listener.py
@@ -41,9 +41,11 @@ def clean_listener_manager():
@pytest.mark.db_test
@provide_session
def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_operator, session):
- asset_uri = "test_asset_uri"
- asset = Asset(uri=asset_uri)
- asset_model = AssetModel(uri=asset_uri)
+ asset_uri = "test://asset/"
+ asset_name = "test_asset_uri"
+ asset_group = "test-group"
+ asset = Asset(uri=asset_uri, name=asset_name, group=asset_group)
+ asset_model = AssetModel(uri=asset_uri, name=asset_name, group=asset_group)
session.add(asset_model)
session.flush()
@@ -59,3 +61,5 @@ def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_oper
assert len(asset_listener.changed) == 1
assert asset_listener.changed[0].uri == asset_uri
+ assert asset_listener.changed[0].name == asset_name
+ assert asset_listener.changed[0].group == asset_group
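The listener assertions above confirm that asset change events now expose name and group in addition to uri. A hypothetical listener sketch (the on_asset_changed hook name is inferred from the test name; the plugin wiring is an assumption):

    from airflow.listeners import hookimpl

    class CollectingAssetListener:
        def __init__(self):
            self.changed = []

        @hookimpl
        def on_asset_changed(self, asset):
            # The delivered asset carries uri, name and group.
            self.changed.append(asset)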
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index a651c7114d6..384d76c7548 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -857,15 +857,24 @@ class TestDag:
"""
dag_id1 = "test_asset_dag1"
dag_id2 = "test_asset_dag2"
+
task_id = "test_asset_task"
+
uri1 = "s3://asset/1"
- a1 = Asset(uri1, extra={"not": "used"})
- a2 = Asset("s3://asset/2")
- a3 = Asset("s3://asset/3")
+ a1 = Asset(uri=uri1, name="test_asset_1", extra={"not": "used"}, group="test-group")
+ a2 = Asset(uri="s3://asset/2", name="test_asset_2", group="test-group")
+ a3 = Asset(uri="s3://asset/3", name="test_asset-3", group="test-group")
+
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[a1])
EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2, a3])
+
dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None)
- EmptyOperator(task_id=task_id, dag=dag2, outlets=[Asset(uri1, extra={"should": "be used"})])
+ EmptyOperator(
+ task_id=task_id,
+ dag=dag2,
+ outlets=[Asset(uri=uri1, name="test_asset_1", extra={"should": "be used"}, group="test-group")],
+ )
+
session = settings.Session()
dag1.clear()
DAG.bulk_write_to_db([dag1, dag2], session=session)
@@ -934,10 +943,10 @@ class TestDag:
"""
# Create four assets - two that have references and two that are unreferenced and marked as
# orphans
- asset1 = Asset(uri="ds1")
- asset2 = Asset(uri="ds2")
- asset3 = Asset(uri="ds3")
- asset4 = Asset(uri="ds4")
+ asset1 = Asset(uri="test://asset1", name="asset1", group="test-group")
+ asset2 = Asset(uri="test://asset2", name="asset2", group="test-group")
+ asset3 = Asset(uri="test://asset3", name="asset3", group="test-group")
+ asset4 = Asset(uri="test://asset4", name="asset4", group="test-group")
dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE,
schedule=[asset1])
BashOperator(dag=dag1, task_id="task", bash_command="echo 1",
outlets=[asset3])
@@ -1407,8 +1416,11 @@ class TestDag:
assert dag.timetable.description == interval_description
def test_timetable_and_description_from_asset(self):
- dag = DAG("test_schedule_interval_arg", schedule=[Asset(uri="hello")],
start_date=TEST_DATE)
- assert dag.timetable == AssetTriggeredTimetable(Asset(uri="hello"))
+ uri = "test://asset"
+ dag = DAG(
+ "test_schedule_interval_arg", schedule=[Asset(uri=uri,
group="test-group")], start_date=TEST_DATE
+ )
+ assert dag.timetable == AssetTriggeredTimetable(Asset(uri=uri,
group="test-group"))
assert dag.timetable.description == "Triggered by assets"
@pytest.mark.parametrize(
@@ -2173,7 +2185,7 @@ class TestDagModel:
session.close()
def test_dags_needing_dagruns_assets(self, dag_maker, session):
- asset = Asset(uri="hello")
+ asset = Asset(uri="test://asset", group="test-group")
with dag_maker(
session=session,
dag_id="my_dag",
@@ -2405,8 +2417,8 @@ class TestDagModel:
@pytest.mark.need_serialized_dag
def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker):
- asset1 = Asset(uri="ds1")
- asset2 = Asset(uri="ds2")
+ asset1 = Asset(uri="test://asset1", group="test-group")
+ asset2 = Asset(uri="test://asset2", name="test_asset_2",
group="test-group")
for dag_id, asset in [("assets-1", asset1), ("assets-2", asset2)]:
with dag_maker(dag_id=dag_id, start_date=timezone.utcnow(), session=session):
@@ -2455,12 +2467,17 @@ class TestDagModel:
dag = DAG(
dag_id="test_dag_asset_expression",
schedule=AssetAny(
- Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}),
+ Asset(uri="s3://dag1/output_1.txt", extra={"hi": "bye"},
group="test-group"),
AssetAll(
- Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}),
- Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}),
+ Asset(
+ uri="s3://dag2/output_1.txt",
+ name="test_asset_2",
+ extra={"hi": "bye"},
+ group="test-group",
+ ),
+ Asset("s3://dag3/output_3.txt", extra={"hi": "bye"},
group="test-group"),
),
- AssetAlias(name="test_name"),
+ AssetAlias(name="test_name", group="test-group"),
),
start_date=datetime.datetime.min,
)
@@ -2469,9 +2486,32 @@ class TestDagModel:
expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one()
assert expression == {
"any": [
- "s3://dag1/output_1.txt",
- {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]},
- {"alias": "test_name"},
+ {
+ "asset": {
+ "uri": "s3://dag1/output_1.txt",
+ "name": "s3://dag1/output_1.txt",
+ "group": "test-group",
+ }
+ },
+ {
+ "all": [
+ {
+ "asset": {
+ "uri": "s3://dag2/output_1.txt",
+ "name": "test_asset_2",
+ "group": "test-group",
+ }
+ },
+ {
+ "asset": {
+ "uri": "s3://dag3/output_3.txt",
+ "name": "s3://dag3/output_3.txt",
+ "group": "test-group",
+ }
+ },
+ ]
+ },
+ {"alias": {"name": "test_name", "group": "test-group"}},
]
}
@@ -3026,9 +3066,9 @@ def test__time_restriction(dag_maker, dag_date, tasks_date, restrict):
@pytest.mark.need_serialized_dag
def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
- asset1 = Asset(uri="ds1")
- asset2 = Asset(uri="ds2")
- asset3 = Asset(uri="ds3")
+ asset1 = Asset(uri="test://asset1", name="test_asset1", group="test-group")
+ asset2 = Asset(uri="test://asset2", group="test-group")
+ asset3 = Asset(uri="test://asset3", group="test-group")
with dag_maker(dag_id="assets-1", schedule=[asset2]):
pass
dag1 = dag_maker.dag
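The asset_expression assertions above illustrate the new serialized shape: bare URI strings become nested {"asset": ...} objects that preserve uri, name and group, and aliases serialize both name and group. A sketch of building that structure, assuming as_expression matches the dicts asserted in the test (these helpers are illustrative, not the library implementation):

    def asset_to_expression(asset):
        return {"asset": {"uri": asset.uri, "name": asset.name, "group": asset.group}}

    def alias_to_expression(alias):
        return {"alias": {"name": alias.name, "group": alias.group}}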
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index 011e785626e..41632fe0458 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -243,16 +243,16 @@ class TestSerializedDagModel:
dag_id="example",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=[
- Asset("1"),
- Asset("2"),
- Asset("3"),
- Asset("4"),
- Asset("5"),
+ Asset(uri="test://asset1", name="1"),
+ Asset(uri="test://asset2", name="2"),
+ Asset(uri="test://asset3", name="3"),
+ Asset(uri="test://asset4", name="4"),
+ Asset(uri="test://asset5", name="5"),
],
) as dag6:
BashOperator(
task_id="any",
- outlets=[Asset("0*"), Asset("6*")],
+ outlets=[Asset(uri="test://asset0", name="0*"),
Asset(uri="test://asset6", name="6*")],
bash_command="sleep 5",
)
deps_order = [x["dependency_id"] for x in
SerializedDAG.serialize_dag(dag6)["dag_dependencies"]]
diff --git a/tests/serialization/test_dag_serialization.py
b/tests/serialization/test_dag_serialization.py
index d7dbf54c186..3955d17477b 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -88,7 +88,12 @@ from airflow.utils.task_group import TaskGroup
from airflow.utils.xcom import XCOM_RETURN_KEY
from tests_common.test_utils.compat import BaseOperatorLink
-from tests_common.test_utils.mock_operators import AirflowLink2, CustomOperator, GoogleLink, MockOperator
+from tests_common.test_utils.mock_operators import (
+ AirflowLink2,
+ CustomOperator,
+ GoogleLink,
+ MockOperator,
+)
from tests_common.test_utils.timetables import (
CustomSerializationTimetable,
cron_timetable,
@@ -105,7 +110,10 @@ executor_config_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(name="my-name"),
spec=k8s.V1PodSpec(
containers=[
- k8s.V1Container(name="base",
volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")])
+ k8s.V1Container(
+ name="base",
+ volume_mounts=[k8s.V1VolumeMount(name="my-vol", mount_path="/vol/")],
+ )
]
),
)
@@ -133,7 +141,10 @@ serialized_simple_dag_ground_truth = {
"task_group": {
"_group_id": None,
"prefix_group_id": True,
- "children": {"bash_task": ("operator", "bash_task"),
"custom_task": ("operator", "custom_task")},
+ "children": {
+ "bash_task": ("operator", "bash_task"),
+ "custom_task": ("operator", "custom_task"),
+ },
"tooltip": "",
"ui_color": "CornflowerBlue",
"ui_fgcolor": "#000",
@@ -161,7 +172,10 @@ serialized_simple_dag_ground_truth = {
"ui_fgcolor": "#000",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env", "cwd"],
- "template_fields_renderers": {"bash_command": "bash",
"env": "json"},
+ "template_fields_renderers": {
+ "bash_command": "bash",
+ "env": "json",
+ },
"bash_command": "echo {{ task.task_id }}",
"task_type": "BashOperator",
"_task_module":
"airflow.providers.standard.operators.bash",
@@ -223,7 +237,10 @@ serialized_simple_dag_ground_truth = {
"__var": {
"DAGs": {
"__type": "set",
- "__var": [permissions.ACTION_CAN_READ,
permissions.ACTION_CAN_EDIT],
+ "__var": [
+ permissions.ACTION_CAN_READ,
+ permissions.ACTION_CAN_EDIT,
+ ],
}
},
}
@@ -462,7 +479,10 @@ class TestStringifiedDAGs:
serialized_dag = SerializedDAG.to_dict(dag)
SerializedDAG.validate_schema(serialized_dag)
- assert serialized_dag["dag"]["access_control"] == {"__type": "dict",
"__var": {}}
+ assert serialized_dag["dag"]["access_control"] == {
+ "__type": "dict",
+ "__var": {},
+ }
@pytest.mark.db_test
def test_dag_serialization_unregistered_custom_timetable(self):
@@ -690,14 +710,21 @@ class TestStringifiedDAGs:
default_partial_kwargs = (
BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY,
strict=False).partial_kwargs
)
- serialized_partial_kwargs = {**default_partial_kwargs,
**serialized_task.partial_kwargs}
+ serialized_partial_kwargs = {
+ **default_partial_kwargs,
+ **serialized_task.partial_kwargs,
+ }
original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs}
assert serialized_partial_kwargs == original_partial_kwargs
@pytest.mark.parametrize(
"dag_start_date, task_start_date, expected_task_start_date",
[
- (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)),
+ (
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ None,
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ ),
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
datetime(2019, 8, 2, tzinfo=timezone.utc),
@@ -749,7 +776,11 @@ class TestStringifiedDAGs:
@pytest.mark.parametrize(
"dag_end_date, task_end_date, expected_task_end_date",
[
- (datetime(2019, 8, 1, tzinfo=timezone.utc), None, datetime(2019, 8, 1, tzinfo=timezone.utc)),
+ (
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ None,
+ datetime(2019, 8, 1, tzinfo=timezone.utc),
+ ),
(
datetime(2019, 8, 1, tzinfo=timezone.utc),
datetime(2019, 8, 2, tzinfo=timezone.utc),
@@ -763,7 +794,12 @@ class TestStringifiedDAGs:
],
)
def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date):
- dag = DAG(dag_id="simple_dag", schedule=None, start_date=datetime(2019, 8, 1), end_date=dag_end_date)
+ dag = DAG(
+ dag_id="simple_dag",
+ schedule=None,
+ start_date=datetime(2019, 8, 1),
+ end_date=dag_end_date,
+ )
BaseOperator(task_id="simple_task", dag=dag, end_date=task_end_date)
serialized_dag = SerializedDAG.to_dict(dag)
@@ -781,7 +817,10 @@ class TestStringifiedDAGs:
@pytest.mark.parametrize(
"serialized_timetable, expected_timetable",
[
- ({"__type": "airflow.timetables.simple.NullTimetable", "__var":
{}}, NullTimetable()),
+ (
+ {"__type": "airflow.timetables.simple.NullTimetable", "__var":
{}},
+ NullTimetable(),
+ ),
(
{
"__type":
"airflow.timetables.interval.CronDataIntervalTimetable",
@@ -789,7 +828,10 @@ class TestStringifiedDAGs:
},
cron_timetable("0 0 * * 0"),
),
- ({"__type": "airflow.timetables.simple.OnceTimetable", "__var":
{}}, OnceTimetable()),
+ (
+ {"__type": "airflow.timetables.simple.OnceTimetable", "__var":
{}},
+ OnceTimetable(),
+ ),
(
{
"__type":
"airflow.timetables.interval.DeltaDataIntervalTimetable",
@@ -848,12 +890,24 @@ class TestStringifiedDAGs:
@pytest.mark.parametrize(
"val, expected",
[
- (relativedelta(days=-1), {"__type": "relativedelta", "__var": {"days": -1}}),
- (relativedelta(month=1, days=-1), {"__type": "relativedelta", "__var": {"month": 1, "days": -1}}),
+ (
+ relativedelta(days=-1),
+ {"__type": "relativedelta", "__var": {"days": -1}},
+ ),
+ (
+ relativedelta(month=1, days=-1),
+ {"__type": "relativedelta", "__var": {"month": 1, "days": -1}},
+ ),
# Every friday
- (relativedelta(weekday=FR), {"__type": "relativedelta", "__var": {"weekday": [4]}}),
+ (
+ relativedelta(weekday=FR),
+ {"__type": "relativedelta", "__var": {"weekday": [4]}},
+ ),
# Every second friday
- (relativedelta(weekday=FR(2)), {"__type": "relativedelta", "__var": {"weekday": [4, 2]}}),
+ (
+ relativedelta(weekday=FR(2)),
+ {"__type": "relativedelta", "__var": {"weekday": [4, 2]}},
+ ),
],
)
def test_roundtrip_relativedelta(self, val, expected):
@@ -913,7 +967,11 @@ class TestStringifiedDAGs:
schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"}
super().__init__(default=path, schema=schema)
- dag = DAG(dag_id="simple_dag", schedule=None, params={"path":
S3Param("s3://my_bucket/my_path")})
+ dag = DAG(
+ dag_id="simple_dag",
+ schedule=None,
+ params={"path": S3Param("s3://my_bucket/my_path")},
+ )
with pytest.raises(SerializationError):
SerializedDAG.to_dict(dag)
@@ -968,11 +1026,21 @@ class TestStringifiedDAGs:
dag = DAG(dag_id="simple_dag", schedule=None)
if expected_val == ParamValidationError:
with pytest.raises(ParamValidationError):
- BaseOperator(task_id="simple_task", dag=dag, params=val,
start_date=datetime(2019, 8, 1))
+ BaseOperator(
+ task_id="simple_task",
+ dag=dag,
+ params=val,
+ start_date=datetime(2019, 8, 1),
+ )
# further tests not relevant
return
else:
- BaseOperator(task_id="simple_task", dag=dag, params=val,
start_date=datetime(2019, 8, 1))
+ BaseOperator(
+ task_id="simple_task",
+ dag=dag,
+ params=val,
+ start_date=datetime(2019, 8, 1),
+ )
serialized_dag = SerializedDAG.to_dict(dag)
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
@@ -1130,10 +1198,19 @@ class TestStringifiedDAGs:
("{{ task.task_id }}", "{{ task.task_id }}"),
(["{{ task.task_id }}", "{{ task.task_id }}"]),
({"foo": "{{ task.task_id }}"}, {"foo": "{{ task.task_id }}"}),
- ({"foo": {"bar": "{{ task.task_id }}"}}, {"foo": {"bar": "{{
task.task_id }}"}}),
(
- [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{
task.task_id }}"}}],
- [{"foo1": {"bar": "{{ task.task_id }}"}}, {"foo2": {"bar": "{{
task.task_id }}"}}],
+ {"foo": {"bar": "{{ task.task_id }}"}},
+ {"foo": {"bar": "{{ task.task_id }}"}},
+ ),
+ (
+ [
+ {"foo1": {"bar": "{{ task.task_id }}"}},
+ {"foo2": {"bar": "{{ task.task_id }}"}},
+ ],
+ [
+ {"foo1": {"bar": "{{ task.task_id }}"}},
+ {"foo2": {"bar": "{{ task.task_id }}"}},
+ ],
),
(
{"foo": {"bar": {"{{ task.task_id }}": ["sar"]}}},
@@ -1141,7 +1218,9 @@ class TestStringifiedDAGs:
),
(
ClassWithCustomAttributes(
- att1="{{ task.task_id }}", att2="{{ task.task_id }}",
template_fields=["att1"]
+ att1="{{ task.task_id }}",
+ att2="{{ task.task_id }}",
+ template_fields=["att1"],
),
"ClassWithCustomAttributes("
"{'att1': '{{ task.task_id }}', 'att2': '{{ task.task_id }}',
'template_fields': ['att1']})",
@@ -1149,10 +1228,14 @@ class TestStringifiedDAGs:
(
ClassWithCustomAttributes(
nested1=ClassWithCustomAttributes(
- att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
+ att1="{{ task.task_id }}",
+ att2="{{ task.task_id }}",
+ template_fields=["att1"],
),
nested2=ClassWithCustomAttributes(
- att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"]
+ att3="{{ task.task_id }}",
+ att4="{{ task.task_id }}",
+ template_fields=["att3"],
),
template_fields=["nested1"],
),
@@ -1172,7 +1255,11 @@ class TestStringifiedDAGs:
we want check that non-"basic" objects are turned in to strings after deserializing.
"""
- dag = DAG("test_serialized_template_fields", schedule=None,
start_date=datetime(2019, 8, 1))
+ dag = DAG(
+ "test_serialized_template_fields",
+ schedule=None,
+ start_date=datetime(2019, 8, 1),
+ )
with dag:
BashOperator(task_id="test", bash_command=templated_field)
@@ -1410,7 +1497,11 @@ class TestStringifiedDAGs:
"""
logical_date = datetime(2020, 1, 1)
- with DAG("test_task_group_setup_teardown_tasks", schedule=None,
start_date=logical_date) as dag:
+ with DAG(
+ "test_task_group_setup_teardown_tasks",
+ schedule=None,
+ start_date=logical_date,
+ ) as dag:
EmptyOperator(task_id="setup").as_setup()
EmptyOperator(task_id="teardown").as_teardown()
@@ -1580,7 +1671,11 @@ class TestStringifiedDAGs:
deps = frozenset([*BaseOperator.deps, CustomTestTriggerRule()])
logical_date = datetime(2020, 1, 1)
- with DAG(dag_id="test_serialize_custom_ti_deps", schedule=None,
start_date=logical_date) as dag:
+ with DAG(
+ dag_id="test_serialize_custom_ti_deps",
+ schedule=None,
+ start_date=logical_date,
+ ) as dag:
DummyTask(task_id="task1")
serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"])
@@ -1668,20 +1763,26 @@ class TestStringifiedDAGs:
"""
from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
- d1 = Asset("d1")
- d2 = Asset("d2")
- d3 = Asset("d3")
- d4 = Asset("d4")
+ asset1 = Asset(name="asset1", uri="test://asset1")
+ asset2 = Asset(name="asset2", uri="test://asset2")
+ asset3 = Asset(name="asset3", uri="test://asset3")
+ asset4 = Asset(name="asset4", uri="test://asset4")
logical_date = datetime(2020, 1, 1)
- with DAG(dag_id="test", start_date=logical_date, schedule=[d1, d1, d1,
d1, d1]) as dag:
+ with DAG(
+ dag_id="test", start_date=logical_date, schedule=[asset1, asset1,
asset1, asset1, asset1]
+ ) as dag:
ExternalTaskSensor(
task_id="task1",
external_dag_id="external_dag_id",
mode="reschedule",
)
- BashOperator(task_id="asset_writer", bash_command="echo hello",
outlets=[d2, d2, d2, d3])
+ BashOperator(
+ task_id="asset_writer",
+ bash_command="echo hello",
+ outlets=[asset2, asset2, asset2, asset3],
+ )
- @dag.task(outlets=[d4])
+ @dag.task(outlets=[asset4])
def other_asset_writer(x):
pass
@@ -1695,7 +1796,7 @@ class TestStringifiedDAGs:
"source": "test",
"target": "asset",
"dependency_type": "asset",
- "dependency_id": "d4",
+ "dependency_id": "asset4",
},
{
"source": "external_dag_id",
@@ -1707,40 +1808,40 @@ class TestStringifiedDAGs:
"source": "test",
"target": "asset",
"dependency_type": "asset",
- "dependency_id": "d3",
+ "dependency_id": "asset3",
},
{
"source": "test",
"target": "asset",
"dependency_type": "asset",
- "dependency_id": "d2",
+ "dependency_id": "asset2",
},
{
"source": "asset",
"target": "test",
"dependency_type": "asset",
- "dependency_id": "d1",
+ "dependency_id": "asset1",
},
{
- "dependency_id": "d1",
+ "dependency_id": "asset1",
"dependency_type": "asset",
"source": "asset",
"target": "test",
},
{
- "dependency_id": "d1",
+ "dependency_id": "asset1",
"dependency_type": "asset",
"source": "asset",
"target": "test",
},
{
- "dependency_id": "d1",
+ "dependency_id": "asset1",
"dependency_type": "asset",
"source": "asset",
"target": "test",
},
{
- "dependency_id": "d1",
+ "dependency_id": "asset1",
"dependency_type": "asset",
"source": "asset",
"target": "test",
@@ -1757,20 +1858,20 @@ class TestStringifiedDAGs:
"""
from airflow.providers.standard.sensors.external_task import ExternalTaskSensor
- d1 = Asset("d1")
- d2 = Asset("d2")
- d3 = Asset("d3")
- d4 = Asset("d4")
+ asset1 = Asset(name="asset1", uri="test://asset1")
+ asset2 = Asset(name="asset2", uri="test://asset2")
+ asset3 = Asset(name="asset3", uri="test://asset3")
+ asset4 = Asset(name="asset4", uri="test://asset4")
logical_date = datetime(2020, 1, 1)
- with DAG(dag_id="test", start_date=logical_date, schedule=[d1]) as dag:
+ with DAG(dag_id="test", start_date=logical_date, schedule=[asset1]) as
dag:
ExternalTaskSensor(
task_id="task1",
external_dag_id="external_dag_id",
mode="reschedule",
)
- BashOperator(task_id="asset_writer", bash_command="echo hello",
outlets=[d2, d3])
+ BashOperator(task_id="asset_writer", bash_command="echo hello",
outlets=[asset2, asset3])
- @dag.task(outlets=[d4])
+ @dag.task(outlets=[asset4])
def other_asset_writer(x):
pass
@@ -1784,7 +1885,7 @@ class TestStringifiedDAGs:
"source": "test",
"target": "asset",
"dependency_type": "asset",
- "dependency_id": "d4",
+ "dependency_id": "asset4",
},
{
"source": "external_dag_id",
@@ -1796,19 +1897,19 @@ class TestStringifiedDAGs:
"source": "test",
"target": "asset",
"dependency_type": "asset",
- "dependency_id": "d3",
+ "dependency_id": "asset3",
},
{
"source": "test",
"target": "asset",
"dependency_type": "asset",
- "dependency_id": "d2",
+ "dependency_id": "asset2",
},
{
"source": "asset",
"target": "test",
"dependency_type": "asset",
- "dependency_id": "d1",
+ "dependency_id": "asset1",
},
],
key=lambda x: tuple(x.values()),
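In both dependency tests above, dependency_id is now the asset's name rather than its uri, which is why the expected ids read asset1 through asset4. A short consumer-side sketch under that assumption, reusing the dag built above:

    deps = SerializedDAG.serialize_dag(dag)["dag_dependencies"]
    asset_ids = {d["dependency_id"] for d in deps if d["dependency_type"] == "asset"}
    # With Asset(name="asset1", uri="test://asset1"), the id is "asset1", not the uri.
    assert "asset1" in asset_ids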
@@ -1821,14 +1922,20 @@ class TestStringifiedDAGs:
Tests DAG dependency detection for operators, including derived classes
"""
from airflow.operators.empty import EmptyOperator
- from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
+ from airflow.providers.standard.operators.trigger_dagrun import (
+ TriggerDagRunOperator,
+ )
class DerivedOperator(TriggerDagRunOperator):
pass
logical_date = datetime(2020, 1, 1)
for class_ in [TriggerDagRunOperator, DerivedOperator]:
- with DAG(dag_id="test_derived_dag_deps_trigger", schedule=None,
start_date=logical_date) as dag:
+ with DAG(
+ dag_id="test_derived_dag_deps_trigger",
+ schedule=None,
+ start_date=logical_date,
+ ) as dag:
task1 = EmptyOperator(task_id="task1")
if mapped:
task2 = class_.partial(
@@ -1912,7 +2019,10 @@ class TestStringifiedDAGs:
assert upstream_group_ids == ["task_group_up1", "task_group_up2"]
upstream_task_ids = task_group_middle_dict["upstream_task_ids"]
- assert upstream_task_ids == ["task_group_up1.task_up1", "task_group_up2.task_up2"]
+ assert upstream_task_ids == [
+ "task_group_up1.task_up1",
+ "task_group_up2.task_up2",
+ ]
downstream_group_ids = task_group_middle_dict["downstream_group_ids"]
assert downstream_group_ids == ["task_group_down1", "task_group_down2"]
@@ -1930,7 +2040,11 @@ class TestStringifiedDAGs:
from airflow.operators.empty import EmptyOperator
from airflow.utils.edgemodifier import Label
- with DAG("test_edge_info_serialization", schedule=None,
start_date=datetime(2020, 1, 1)) as dag:
+ with DAG(
+ "test_edge_info_serialization",
+ schedule=None,
+ start_date=datetime(2020, 1, 1),
+ ) as dag:
task1 = EmptyOperator(task_id="task1")
task2 = EmptyOperator(task_id="task2")
task1 >> Label("test label") >> task2
@@ -2024,7 +2138,11 @@ class TestStringifiedDAGs:
When the callback is not set, has_on_failure_callback should not be stored in Serialized blob
and so default to False on de-serialization
"""
- dag = DAG(dag_id="test_dag_on_failure_callback_roundtrip",
schedule=None, **passed_failure_callback)
+ dag = DAG(
+ dag_id="test_dag_on_failure_callback_roundtrip",
+ schedule=None,
+ **passed_failure_callback,
+ )
BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019,
8, 1))
serialized_dag = SerializedDAG.to_dict(dag)
@@ -2116,7 +2234,12 @@ class TestStringifiedDAGs:
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
- "params": {"my_param": {"__class":
"airflow.models.param.Param", "default": "str"}},
+ "params": {
+ "my_param": {
+ "__class": "airflow.models.param.Param",
+ "default": "str",
+ }
+ },
},
}
dag = SerializedDAG.from_dict(serialized)
@@ -2265,7 +2388,10 @@ class TestStringifiedDAGs:
"__type": "START_TRIGGER_ARGS",
"trigger_cls":
"airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
# "trigger_kwargs": {"__type": "dict", "__var": {"delta":
{"__type": "timedelta", "__var": 2.0}}},
- "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type":
"timedelta", "__var": 2.0}}},
+ "trigger_kwargs": {
+ "__type": "dict",
+ "__var": {"delta": {"__type": "timedelta", "__var": 2.0}},
+ },
"next_method": "execute_complete",
"next_kwargs": None,
"timeout": None,
@@ -2400,7 +2526,12 @@ def test_operator_expand_xcomarg_serde():
"type": "dict-of-lists",
"value": {
"__type": "dict",
- "__var": {"arg2": {"__type": "xcomref", "__var": {"task_id":
"op1", "key": "return_value"}}},
+ "__var": {
+ "arg2": {
+ "__type": "xcomref",
+ "__var": {"task_id": "op1", "key": "return_value"},
+ }
+ },
},
},
"partial_kwargs": {},
@@ -2457,7 +2588,12 @@ def test_operator_expand_kwargs_literal_serde(strict):
{"__type": "dict", "__var": {"a": "x"}},
{
"__type": "dict",
- "__var": {"a": {"__type": "xcomref", "__var": {"task_id":
"op1", "key": "return_value"}}},
+ "__var": {
+ "a": {
+ "__type": "xcomref",
+ "__var": {"task_id": "op1", "key": "return_value"},
+ }
+ },
},
],
},
@@ -2481,12 +2617,18 @@ def test_operator_expand_kwargs_literal_serde(strict):
# The XComArg can't be deserialized before the DAG is.
expand_value = op.expand_input.value
- assert expand_value == [{"a": "x"}, {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}]
+ assert expand_value == [
+ {"a": "x"},
+ {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
+ ]
serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
resolved_expand_value = serialized_dag.task_dict["task_2"].expand_input.value
- resolved_expand_value == [{"a": "x"}, {"a": PlainXComArg(serialized_dag.task_dict["op1"])}]
+ resolved_expand_value == [
+ {"a": "x"},
+ {"a": PlainXComArg(serialized_dag.task_dict["op1"])},
+ ]
@pytest.mark.parametrize("strict", [True, False])
@@ -2508,7 +2650,10 @@ def test_operator_expand_kwargs_xcomarg_serde(strict):
"downstream_task_ids": [],
"expand_input": {
"type": "list-of-dicts",
- "value": {"__type": "xcomref", "__var": {"task_id": "op1", "key":
"return_value"}},
+ "value": {
+ "__type": "xcomref",
+ "__var": {"task_id": "op1", "key": "return_value"},
+ },
},
"partial_kwargs": {},
"task_id": "task_2",
@@ -2640,7 +2785,10 @@ def test_taskflow_expand_serde():
"__type": "dict",
"__var": {
"arg2": {"__type": "dict", "__var": {"a": 1, "b": 2}},
- "arg3": {"__type": "xcomref", "__var": {"task_id": "op1",
"key": "return_value"}},
+ "arg3": {
+ "__type": "xcomref",
+ "__var": {"task_id": "op1", "key": "return_value"},
+ },
},
},
},
@@ -2650,7 +2798,11 @@ def test_taskflow_expand_serde():
"task_id": "x",
"template_ext": [],
"template_fields": ["templates_dict", "op_args", "op_kwargs"],
- "template_fields_renderers": {"templates_dict": "json", "op_args":
"py", "op_kwargs": "py"},
+ "template_fields_renderers": {
+ "templates_dict": "json",
+ "op_args": "py",
+ "op_kwargs": "py",
+ },
"_disallow_kwargs_override": False,
"_expand_input_attr": "op_kwargs_expand_input",
"python_callable_name": qualname(x),
@@ -2666,7 +2818,10 @@ def test_taskflow_expand_serde():
assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
key="dict-of-lists",
- value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1",
"key": XCOM_RETURN_KEY})},
+ value={
+ "arg2": {"a": 1, "b": 2},
+ "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
+ },
)
assert deserialized.partial_kwargs == {
"is_setup": False,
@@ -2688,7 +2843,10 @@ def test_taskflow_expand_serde():
pickled = pickle.loads(pickle.dumps(deserialized))
assert pickled.op_kwargs_expand_input == _ExpandInputRef(
key="dict-of-lists",
- value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1",
"key": XCOM_RETURN_KEY})},
+ value={
+ "arg2": {"a": 1, "b": 2},
+ "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
+ },
)
assert pickled.partial_kwargs == {
"is_setup": False,
@@ -2753,7 +2911,11 @@ def test_taskflow_expand_kwargs_serde(strict):
"task_id": "x",
"template_ext": [],
"template_fields": ["templates_dict", "op_args", "op_kwargs"],
- "template_fields_renderers": {"templates_dict": "json", "op_args":
"py", "op_kwargs": "py"},
+ "template_fields_renderers": {
+ "templates_dict": "json",
+ "op_args": "py",
+ "op_kwargs": "py",
+ },
"_disallow_kwargs_override": strict,
"_expand_input_attr": "op_kwargs_expand_input",
}
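The expected blobs throughout this file wrap non-primitive values in the serializer's {"__type": ..., "__var": ...} envelope. A minimal round-trip sketch of that convention (behaviour inferred from the fixtures above):

    from airflow.serialization.serialized_objects import BaseSerialization

    encoded = BaseSerialization.serialize({"a": 1})
    assert encoded == {"__type": "dict", "__var": {"a": 1}}
    assert BaseSerialization.deserialize(encoded) == {"a": 1}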
diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py
index a3a946124ff..2fc8ad8d17b 100644
--- a/tests/serialization/test_serde.py
+++ b/tests/serialization/test_serde.py
@@ -365,7 +365,7 @@ class TestSerDe:
assert e["extra"] == {"hi": "bye"}
def test_encode_asset(self):
- asset = Asset("mytest://asset")
+ asset = Asset(uri="mytest://asset", name="test")
obj = deserialize(serialize(asset))
assert asset.uri == obj.uri
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 75ff736be87..3e8e8445288 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -187,7 +187,11 @@ class MockLazySelectSequence(LazySelectSequence):
(timezone.utcnow(), DAT.DATETIME, equal_time),
(timedelta(minutes=2), DAT.TIMEDELTA, equals),
(Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
- (relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours),
+ (
+ relativedelta.relativedelta(hours=+1),
+ DAT.RELATIVEDELTA,
+ lambda a, b: a.hours == b.hours,
+ ),
({"test": "dict", "test-1": 1}, None, equals),
(["array_item", 2], None, equals),
(("tuple_item", 3), DAT.TUPLE, equals),
@@ -195,7 +199,9 @@ class MockLazySelectSequence(LazySelectSequence):
(
k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
- name="test", annotations={"test": "annotation"},
creation_timestamp=timezone.utcnow()
+ name="test",
+ annotations={"test": "annotation"},
+ creation_timestamp=timezone.utcnow(),
)
),
DAT.POD,
@@ -214,7 +220,14 @@ class MockLazySelectSequence(LazySelectSequence):
),
(Resources(cpus=0.1, ram=2048), None, None),
(EmptyOperator(task_id="test-task"), None, None),
- (TaskGroup(group_id="test-group", dag=DAG(dag_id="test_dag",
start_date=datetime.now())), None, None),
+ (
+ TaskGroup(
+ group_id="test-group",
+ dag=DAG(dag_id="test_dag", start_date=datetime.now()),
+ ),
+ None,
+ None,
+ ),
(
Param("test", "desc"),
DAT.PARAM,
@@ -231,8 +244,12 @@ class MockLazySelectSequence(LazySelectSequence):
DAT.XCOM_REF,
None,
),
- (MockLazySelectSequence(), None, lambda a, b: len(a) == len(b) and isinstance(b, list)),
- (Asset(uri="test"), DAT.ASSET, equals),
+ (
+ MockLazySelectSequence(),
+ None,
+ lambda a, b: len(a) == len(b) and isinstance(b, list),
+ ),
+ (Asset(uri="test://asset1", name="test"), DAT.ASSET, equals),
(SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
(
Connection(conn_id="TEST_ID", uri="mysql://"),
@@ -240,16 +257,24 @@ class MockLazySelectSequence(LazySelectSequence):
lambda a, b: a.get_uri() == b.get_uri(),
),
(
- OutletEventAccessor(raw_key=Asset(uri="test"), extra={"key":
"value"}, asset_alias_events=[]),
+ OutletEventAccessor(
+ raw_key=Asset(uri="test://asset1", name="test",
group="test-group"),
+ extra={"key": "value"},
+ asset_alias_events=[],
+ ),
DAT.ASSET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
(
OutletEventAccessor(
- raw_key=AssetAlias(name="test_alias"),
+ raw_key=AssetAlias(name="test_alias",
group="test-alias-group"),
extra={"key": "value"},
asset_alias_events=[
- AssetAliasEvent(source_alias_name="test_alias",
dest_asset_uri="test_uri", extra={})
+ AssetAliasEvent(
+ source_alias_name="test_alias",
+ dest_asset_uri="test_uri",
+ extra={},
+ )
],
),
DAT.ASSET_EVENT_ACCESSOR,
@@ -295,7 +320,10 @@ def test_serialize_deserialize(input, encoded_type, cmp_func):
"conn_uri",
[
pytest.param("aws://", id="only-conn-type"),
- pytest.param("postgres://username:[email protected]:5432/the_database", id="all-non-extra"),
+ pytest.param(
+ "postgres://username:[email protected]:5432/the_database",
+ id="all-non-extra",
+ ),
pytest.param(
"///?__extra__=%7B%22foo%22%3A+%22bar%22%2C+%22answer%22%3A+42%2C+%22"
"nullable%22%3A+null%2C+%22empty%22%3A+%22%22%2C+%22zero%22%3A+0%7D",
@@ -307,7 +335,10 @@ def test_backcompat_deserialize_connection(conn_uri):
"""Test deserialize connection which serialised by previous serializer
implementation."""
from airflow.serialization.serialized_objects import BaseSerialization
- conn_obj = {Encoding.TYPE: DAT.CONNECTION, Encoding.VAR: {"conn_id": "TEST_ID", "uri": conn_uri}}
+ conn_obj = {
+ Encoding.TYPE: DAT.CONNECTION,
+ Encoding.VAR: {"conn_id": "TEST_ID", "uri": conn_uri},
+ }
deserialized = BaseSerialization.deserialize(conn_obj)
assert deserialized.get_uri() == conn_uri
@@ -323,10 +354,13 @@ sample_objects = {
is_paused=True,
),
LogTemplatePydantic: LogTemplate(
- id=1, filename="test_file", elasticsearch_id="test_id",
created_at=datetime.now()
+ id=1,
+ filename="test_file",
+ elasticsearch_id="test_id",
+ created_at=datetime.now(),
),
DagTagPydantic: DagTag(),
- AssetPydantic: Asset("uri", extra={}),
+ AssetPydantic: Asset(name="test", uri="test://asset1", extra={}),
AssetEventPydantic: AssetEvent(),
}
diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py
index 9d572295773..c8c889f603c 100644
--- a/tests/timetables/test_assets_timetable.py
+++ b/tests/timetables/test_assets_timetable.py
@@ -105,7 +105,7 @@ def test_timetable() -> MockTimetable:
@pytest.fixture
def test_assets() -> list[Asset]:
"""Pytest fixture for creating a list of Asset objects."""
- return [Asset("test_asset")]
+ return [Asset(name="test_asset", uri="test://asset")]
@pytest.fixture
@@ -134,7 +134,15 @@ def test_serialization(asset_timetable:
AssetOrTimeSchedule, monkeypatch: Any) -
"timetable": "mock_serialized_timetable",
"asset_condition": {
"__type": "asset_all",
- "objects": [{"__type": "asset", "uri": "test_asset", "name":
"test_asset", "extra": {}}],
+ "objects": [
+ {
+ "__type": "asset",
+ "name": "test_asset",
+ "uri": "test://asset/",
+ "group": "asset",
+ "extra": {},
+ }
+ ],
},
}
@@ -152,7 +160,15 @@ def test_deserialization(monkeypatch: Any) -> None:
"timetable": "mock_serialized_timetable",
"asset_condition": {
"__type": "asset_all",
- "objects": [{"__type": "asset", "name": "test_asset", "uri":
"test_asset", "extra": None}],
+ "objects": [
+ {
+ "__type": "asset",
+ "name": "test_asset",
+ "uri": "test://asset/",
+ "group": "asset",
+ "extra": None,
+ }
+ ],
},
}
deserialized = AssetOrTimeSchedule.deserialize(mock_serialized_data)
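Note that the fixture above declares uri="test://asset" while the expected payloads carry "test://asset/": Asset URIs are normalized on construction, and group defaults to "asset" when unset. A small sketch of that inferred behaviour:

    a = Asset(name="test_asset", uri="test://asset")
    assert a.uri == "test://asset/"  # normalized form expected by the payloads above
    assert a.group == "asset"        # default group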
diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py
index b99681c2231..d5d0cdb32e8 100644
--- a/tests/utils/test_json.py
+++ b/tests/utils/test_json.py
@@ -86,7 +86,7 @@ class TestXComEncoder:
)
def test_encode_xcom_asset(self):
- asset = Asset("mytest://asset")
+ asset = Asset(uri="mytest://asset", name="mytest")
s = json.dumps(asset, cls=utils_json.XComEncoder)
obj = json.loads(s, cls=utils_json.XComDecoder)
assert asset.uri == obj.uri
diff --git a/tests/www/views/test_views_asset.py b/tests/www/views/test_views_asset.py
index e4fda0aeac6..2e6668f134a 100644
--- a/tests/www/views/test_views_asset.py
+++ b/tests/www/views/test_views_asset.py
@@ -42,7 +42,10 @@ class TestAssetEndpoint:
@pytest.fixture
def create_assets(self, session):
def create(indexes):
- assets = [AssetModel(id=i, uri=f"s3://bucket/key/{i}") for i in indexes]
+ assets = [
+ AssetModel(id=i, uri=f"s3://bucket/key/{i}", name=f"asset-{i}", group="asset")
+ for i in indexes
+ ]
session.add_all(assets)
session.flush()
session.add_all(AssetActive.for_asset(a) for a in assets)
@@ -220,7 +223,7 @@ class TestGetAssets(TestAssetEndpoint):
@pytest.mark.need_serialized_dag
def test_correct_counts_update(self, admin_client, session, dag_maker, app, monkeypatch):
with monkeypatch.context() as m:
- assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2, 3, 4,
5]]
+ assets = [Asset(uri=f"s3://bucket/key/{i}", name=f"asset-{i}") for
i in range(1, 6)]
# DAG that produces asset #1
with dag_maker(dag_id="upstream", schedule=None, serialized=True,
session=session):
@@ -399,7 +402,9 @@ class TestGetAssetsEndpointPagination(TestAssetEndpoint):
class TestGetAssetNextRunSummary(TestAssetEndpoint):
def test_next_run_asset_summary(self, dag_maker, admin_client):
- with dag_maker(dag_id="upstream",
schedule=[Asset(uri="s3://bucket/key/1")], serialized=True):
+ with dag_maker(
+ dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1",
name="asset-1")], serialized=True
+ ):
EmptyOperator(task_id="task1")
response = admin_client.post("/next_run_assets_summary",
data={"dag_ids": ["upstream"]})
diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py
index 067ca9325f6..e2181aa702d 100644
--- a/tests/www/views/test_views_grid.py
+++ b/tests/www/views/test_views_grid.py
@@ -431,8 +431,8 @@ def test_has_outlet_asset_flag(admin_client, dag_maker, session, app, monkeypatc
lineagefile = File("/tmp/does_not_exist")
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2", outlets=[lineagefile])
- EmptyOperator(task_id="task3", outlets=[Asset("foo"), lineagefile])
- EmptyOperator(task_id="task4", outlets=[Asset("foo")])
+ EmptyOperator(task_id="task3", outlets=[Asset(name="foo",
uri="s3://bucket/key"), lineagefile])
+ EmptyOperator(task_id="task4", outlets=[Asset(name="foo",
uri="s3://bucket/key")])
m.setattr(app, "dag_bag", dag_maker.dagbag)
resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True)
@@ -471,7 +471,7 @@ def test_has_outlet_asset_flag(admin_client, dag_maker, session, app, monkeypatc
@pytest.mark.need_serialized_dag
def test_next_run_assets(admin_client, dag_maker, session, app, monkeypatch):
with monkeypatch.context() as m:
- assets = [Asset(uri=f"s3://bucket/key/{i}") for i in [1, 2]]
+ assets = [Asset(uri=f"s3://bucket/key/{i}", name=f"name_{i}",
group="test-group") for i in [1, 2]]
with dag_maker(dag_id=DAG_ID, schedule=assets, serialized=True,
session=session):
EmptyOperator(task_id="task1")
@@ -508,7 +508,12 @@ def test_next_run_assets(admin_client, dag_maker, session, app, monkeypatch):
assert resp.status_code == 200, resp.json
assert resp.json == {
- "asset_expression": {"all": ["s3://bucket/key/1",
"s3://bucket/key/2"]},
+ "asset_expression": {
+ "all": [
+ {"asset": {"uri": "s3://bucket/key/1", "name": "name_1",
"group": "test-group"}},
+ {"asset": {"uri": "s3://bucket/key/2", "name": "name_2",
"group": "test-group"}},
+ ]
+ },
"events": [
{"id": asset1_id, "uri": "s3://bucket/key/1", "lastUpdate":
"2022-08-02T02:00:00+00:00"},
{"id": asset2_id, "uri": "s3://bucket/key/2", "lastUpdate": None},