This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c4a10e718ae YAML first discovery for connection form metadata (#60410)
c4a10e718ae is described below
commit c4a10e718aef09a629cc3e87af5446be90e66938
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Feb 13 12:01:41 2026 +0530
YAML first discovery for connection form metadata (#60410)
Add declarative UI metadata for connection forms in provider.yaml
Load conn-fields and ui-field-behaviour from provider info instead of
importing hook classes.
---
.../core_api/services/ui/connections.py | 13 +-
.../src/airflow/cli/commands/provider_command.py | 4 +-
airflow-core/src/airflow/provider.yaml.schema.json | 100 ++++++++
airflow-core/src/airflow/providers_manager.py | 188 +++++++++++----
.../tests/unit/always/test_providers_manager.py | 229 ++++++++++++++++++
.../23_provider_hook_migration_to_yaml.rst | 110 +++++++++
.../fab/www/views/test_connection_form_fields.py | 90 +++++--
providers/google/provider.yaml | 66 ++++++
.../airflow/providers/google/get_provider_info.py | 51 ++++
providers/http/provider.yaml | 5 +
.../airflow/providers/http/get_provider_info.py | 7 +-
providers/smtp/provider.yaml | 4 +
.../airflow/providers/smtp/get_provider_info.py | 10 +-
scripts/ci/prek/check_core_imports_in_shared.py | 106 +++++++++
scripts/tools/generate_yaml_format_for_hooks.py | 262 +++++++++++++++++++++
15 files changed, 1172 insertions(+), 73 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
index 50ab6a7d69c..4568eb07278 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections.py
@@ -204,12 +204,19 @@ class HookMetaService:
result: dict[str, MutableMapping] = {}
for key, form_widget in form_widgets.items():
hook_key = key.split("__")[1]
- if isinstance(form_widget.field, HookMetaService.MockBaseField):
- hook_widgets = result.get(hook_key, {})
+ hook_widgets = result.get(hook_key, {})
+
+ if isinstance(form_widget.field, dict):
+                # yaml path, form widgets read from yaml and already present in SerializedParam.dump() format
+                hook_widgets[form_widget.field_name] = form_widget.field
+            elif isinstance(form_widget.field, HookMetaService.MockBaseField):
+                # legacy path, form widgets created using mocked WTForms fields, need to convert to SerializedParam.dump()
                 hook_widgets[form_widget.field_name] = form_widget.field.param.dump()
-            result[hook_key] = hook_widgets
             else:
                 log.error("Unknown form widget in %s: %s", hook_key, form_widget)
+ continue
+
+ result[hook_key] = hook_widgets
return result
@staticmethod
diff --git a/airflow-core/src/airflow/cli/commands/provider_command.py b/airflow-core/src/airflow/cli/commands/provider_command.py
index 81a96b2a9bc..645618fd852 100644
--- a/airflow-core/src/airflow/cli/commands/provider_command.py
+++ b/airflow-core/src/airflow/cli/commands/provider_command.py
@@ -137,7 +137,9 @@ def connection_form_widget_list(args):
"connection_parameter_name": x[0],
"class": x[1].hook_class_name,
"package_name": x[1].package_name,
- "field_type": x[1].field.field_class.__name__,
+ "field_type": x[1].field.get("schema", {}).get("type", "unknown")
+ if isinstance(x[1].field, dict)
+ else x[1].field.field_class.__name__,
},
)
diff --git a/airflow-core/src/airflow/provider.yaml.schema.json b/airflow-core/src/airflow/provider.yaml.schema.json
index 917d84d78ce..cbe3e8044be 100644
--- a/airflow-core/src/airflow/provider.yaml.schema.json
+++ b/airflow-core/src/airflow/provider.yaml.schema.json
@@ -359,6 +359,106 @@
"hook-class-name": {
"description": "Hook class name that implements the
connection type",
"type": "string"
+ },
+ "ui-field-behaviour": {
+                            "description": "Customizations for standard connection form fields",
+ "type": "object",
+ "properties": {
+ "hidden-fields": {
+                                    "description": "List of standard fields to hide in the UI",
+ "type": "array",
+ "items": {
+ "type": "string",
+                                        "enum": ["description", "host", "port", "login", "password", "schema", "extra"]
+ },
+ "default": []
+ },
+ "relabeling": {
+                                    "description": "Map of field names to custom labels",
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ },
+ "default": {}
+ },
+ "placeholders": {
+                                    "description": "Map of field names to placeholder text",
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ },
+ "default": {}
+ }
+ },
+ "additionalProperties": false
+ },
+ "conn-fields": {
+                            "description": "Custom connection fields stored in Connection.extra JSON",
+ "type": "object",
+ "additionalProperties": {
+ "type": "object",
+ "properties": {
+ "label": {
+ "type": "string",
+                                        "description": "Display label for the field"
+ },
+ "description": {
+ "type": "string",
+ "description": "Help text for the field"
+ },
+ "schema": {
+                                        "description": "JSON Schema definition for this field",
+ "type": "object",
+ "properties": {
+ "type": {
+ "description": "Field data type",
+ "oneOf": [
+                                                    {"type": "string", "enum": ["string", "integer", "boolean", "number", "object", "array"]},
+                                                    {"type": "array", "items": {"type": "string"}}
+ ]
+ },
+ "default": {
+                                                "description": "Default value for the field"
+ },
+ "format": {
+ "type": "string",
+                                                "enum": ["password", "multiline", "email", "url", "json", "date", "date-time", "time"],
+                                                "description": "Format hint for rendering"
+ },
+ "enum": {
+ "type": "array",
+                                                "description": "List of allowed values (creates dropdown)"
+ },
+ "minimum": {
+ "type": "number",
+                                                "description": "Minimum value for numbers"
+ },
+ "maximum": {
+ "type": "number",
+                                                "description": "Maximum value for numbers"
+ },
+ "pattern": {
+ "type": "string",
+                                                "description": "Regex pattern for validation"
+ },
+ "minLength": {
+ "type": "integer",
+ "minimum": 0,
+                                                "description": "Minimum string length"
+ },
+ "maxLength": {
+ "type": "integer",
+ "minimum": 1,
+                                                "description": "Maximum string length"
+ }
+ },
+ "required": ["type"],
+ "additionalProperties": true
+ }
+ },
+ "required": ["label", "schema"],
+ "additionalProperties": false
+ }
}
},
"required": [
diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py
index 956f52857bb..84cb95d656f 100644
--- a/airflow-core/src/airflow/providers_manager.py
+++ b/airflow-core/src/airflow/providers_manager.py
@@ -525,6 +525,7 @@ class ProvidersManager(LoggingMixin):
self._init_airflow_core_hooks()
self.initialize_providers_list()
self._discover_hooks()
+ self._load_ui_metadata()
         self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))
@provider_info_cache("filesystems")
@@ -902,6 +903,98 @@ class ProvidersManager(LoggingMixin):
return None
return getattr(obj, attr_name)
+    def _get_connection_type_config(self, provider_info: ProviderInfo, connection_type: str) -> dict | None:
+ """Get connection type config from provider.yaml if it exists."""
+ connection_types = provider_info.data.get("connection-types", [])
+ for conn_config in connection_types:
+ if conn_config.get("connection-type") == connection_type:
+ return conn_config
+ return None
+
+ def _to_api_format(self, field_name: str, field_def: dict) -> dict:
+ """Convert conn-fields definition to format expected by the API."""
+ schema_def = field_def.get("schema", {})
+
+        # build schema dict with label moved to `title` per jsonschema convention
+ schema = schema_def.copy()
+ if "label" in field_def:
+ schema["title"] = field_def.get("label")
+
+ return {
+ "value": schema_def.get("default"),
+ "schema": schema,
+ "description": field_def.get("description"),
+ "source": None,
+ }
+
+ def _add_widgets(
+        self, package_name: str, hook_class_name: str, connection_type: str, conn_fields: dict
+    ) -> None:
+        """Parse conn-fields from provider info and add to connection_form_widgets."""
+ for field_name, field_def in conn_fields.items():
+ field_data = self._to_api_format(field_name, field_def)
+
+ prefixed_name = f"extra__{connection_type}__{field_name}"
+ if prefixed_name in self._connection_form_widgets:
+ log.warning(
+ "Field %s for connection type %s already added, skipping",
+ field_name,
+ connection_type,
+ )
+ continue
+
+ schema_def = field_def.get("schema", {})
+            self._connection_form_widgets[prefixed_name] = ConnectionFormWidgetInfo(
+ hook_class_name=hook_class_name,
+ package_name=package_name,
+ field=field_data,
+ field_name=field_name,
+ is_sensitive=schema_def.get("format") == "password",
+ )
+
+    def _add_customized_fields(self, package_name: str, connection_type: str, behaviour: dict) -> None:
+        """Process ui-field-behaviour from provider info and add to field_behaviours."""
+ if connection_type in self._field_behaviours:
+ log.warning(
+                "Field behaviour for connection type %s already exists, skipping",
+ connection_type,
+ )
+ return
+
+ # convert kebab-case keys to python style
+ customized_fields = {
+ "hidden_fields": behaviour.get("hidden-fields", []),
+ "relabeling": behaviour.get("relabeling", {}),
+ "placeholders": behaviour.get("placeholders", {}),
+ }
+
+ try:
+            self._customized_form_fields_schema_validator.validate(customized_fields)
+            customized_fields = _ensure_prefix_for_placeholders(customized_fields, connection_type)
+ self._field_behaviours[connection_type] = customized_fields
+ except Exception as e:
+ log.warning(
+ "Failed to add field behaviour for %s in package %s: %s",
+ connection_type,
+ package_name,
+ e,
+ )
+
+ def _load_ui_metadata(self) -> None:
+        """Load connection form UI metadata from provider info without importing hooks."""
+ for package_name, provider in self._provider_dict.items():
+ for conn_config in provider.data.get("connection-types", []):
+ connection_type = conn_config.get("connection-type")
+ hook_class_name = conn_config.get("hook-class-name")
+ if not connection_type or not hook_class_name:
+ continue
+
+ if conn_fields := conn_config.get("conn-fields"):
+                    self._add_widgets(package_name, hook_class_name, connection_type, conn_fields)
+
+ if behaviour := conn_config.get("ui-field-behaviour"):
+                    self._add_customized_fields(package_name, connection_type, behaviour)
+
def _import_hook(
self,
connection_type: str | None,
@@ -942,48 +1035,59 @@ class ProvidersManager(LoggingMixin):
         hook_class: type[BaseHook] | None = _correctness_check(package_name, hook_class_name, provider_info)
         if hook_class is None:
             return None
-        try:
-            from wtforms import BooleanField, IntegerField, PasswordField, StringField
-
-            allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField]
-            module, class_name = hook_class_name.rsplit(".", maxsplit=1)
-            # Do not use attr here. We want to check only direct class fields not those
-            # inherited from parent hook. This way we add form fields only once for the whole
-            # hierarchy and we add it only from the parent hook that provides those!
-            if "get_connection_form_widgets" in hook_class.__dict__:
-                widgets = hook_class.get_connection_form_widgets()
-                if widgets:
-                    for widget in widgets.values():
-                        if widget.field_class not in allowed_field_classes:
-                            log.warning(
-                                "The hook_class '%s' uses field of unsupported class '%s'. "
-                                "Only '%s' field classes are supported",
-                                hook_class_name,
-                                widget.field_class,
-                                allowed_field_classes,
-                            )
-                            return None
-                    self._add_widgets(package_name, hook_class, widgets)
-            if "get_ui_field_behaviour" in hook_class.__dict__:
-                field_behaviours = hook_class.get_ui_field_behaviour()
-                if field_behaviours:
-                    self._add_customized_fields(package_name, hook_class, field_behaviours)
-        except ImportError as e:
-            if e.name in ["flask_appbuilder", "wtforms"]:
-                log.info(
-                    "The hook_class '%s' is not fully initialized (UI widgets will be missing), because "
-                    "the 'flask_appbuilder' package is not installed, however it is not required for "
-                    "Airflow components to work",
+
+        # Check if provider info already has UI metadata and skip Python hook methods
+        # to avoid duplicate initialization and unnecessary wtforms imports
+        ui_metadata_loaded = False
+        if provider_info and connection_type:
+            conn_config = self._get_connection_type_config(provider_info, connection_type)
+            ui_metadata_loaded = conn_config is not None and bool(
+                conn_config.get("conn-fields") or conn_config.get("ui-field-behaviour")
+            )
+
+        if not ui_metadata_loaded:
+            try:
+                from wtforms import BooleanField, IntegerField, PasswordField, StringField
+
+                allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField]
+                # Do not use attr here. We want to check only direct class fields not those
+                # inherited from parent hook. This way we add form fields only once for the whole
+                # hierarchy and we add it only from the parent hook that provides those!
+                if "get_connection_form_widgets" in hook_class.__dict__:
+                    widgets = hook_class.get_connection_form_widgets()
+                    if widgets:
+                        for widget in widgets.values():
+                            if widget.field_class not in allowed_field_classes:
+                                log.warning(
+                                    "The hook_class '%s' uses field of unsupported class '%s'. "
+                                    "Only '%s' field classes are supported",
+                                    hook_class_name,
+                                    widget.field_class,
+                                    allowed_field_classes,
+                                )
+                                return None
+                        self._add_widgets_from_hook(package_name, hook_class, widgets)
+                if "get_ui_field_behaviour" in hook_class.__dict__:
+                    field_behaviours = hook_class.get_ui_field_behaviour()
+                    if field_behaviours:
+                        self._add_customized_fields_from_hook(package_name, hook_class, field_behaviours)
+            except ImportError as e:
+                if e.name in ["flask_appbuilder", "wtforms"]:
+                    log.info(
+                        "The hook_class '%s' is not fully initialized (UI widgets will be missing), because "
+                        "the 'flask_appbuilder' package is not installed, however it is not required for "
+                        "Airflow components to work",
+                        hook_class_name,
+                    )
+            except Exception as e:
+                log.warning(
+                    "Exception when importing '%s' from '%s' package: %s",
                     hook_class_name,
+                    package_name,
+                    e,
                 )
-        except Exception as e:
-            log.warning(
-                "Exception when importing '%s' from '%s' package: %s",
-                hook_class_name,
-                package_name,
-                e,
-            )
-            return None
+                return None
+
hook_connection_type = self._get_attr(hook_class, "conn_type")
if connection_type:
if hook_connection_type != connection_type:
@@ -1019,7 +1123,7 @@ class ProvidersManager(LoggingMixin):
connection_testable=hasattr(hook_class, "test_connection"),
)
-    def _add_widgets(self, package_name: str, hook_class: type, widgets: dict[str, Any]):
+    def _add_widgets_from_hook(self, package_name: str, hook_class: type, widgets: dict[str, Any]):
conn_type = hook_class.conn_type # type: ignore
for field_identifier, field in widgets.items():
if field_identifier.startswith("extra__"):
@@ -1043,7 +1147,7 @@ class ProvidersManager(LoggingMixin):
and field.field_class.widget.input_type == "password",
)
-    def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: dict):
+    def _add_customized_fields_from_hook(self, package_name: str, hook_class: type, customized_fields: dict):
try:
connection_type = getattr(hook_class, "conn_type")
diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py
index 7d7cc0507dd..22bffa471a3 100644
--- a/airflow-core/tests/unit/always/test_providers_manager.py
+++ b/airflow-core/tests/unit/always/test_providers_manager.py
@@ -273,3 +273,232 @@ class TestWithoutCheckProviderManager:
mock_correctness_check.assert_not_called()
assert providers_manager._executor_without_check_set == result
+
+
+@pytest.mark.parametrize(
+ ("value", "expected_outputs"),
+ [
+ ("a", "a"),
+ (1, 1),
+ (None, None),
+ (lambda: 0, 0),
+ (lambda: None, None),
+ (lambda: "z", "z"),
+ ],
+)
+def test_lazy_cache_dict_resolving(value, expected_outputs):
+ lazy_cache_dict = LazyDictWithCache()
+ lazy_cache_dict["key"] = value
+ assert lazy_cache_dict["key"] == expected_outputs
+ # Retrieve it again to see if it is correctly returned again
+ assert lazy_cache_dict["key"] == expected_outputs
+
+
+def test_lazy_cache_dict_raises_error():
+ def raise_method():
+ raise RuntimeError("test")
+
+ lazy_cache_dict = LazyDictWithCache()
+ lazy_cache_dict["key"] = raise_method
+ with pytest.raises(RuntimeError, match="test"):
+ _ = lazy_cache_dict["key"]
+
+
+def test_lazy_cache_dict_del_item():
+ lazy_cache_dict = LazyDictWithCache()
+
+ def answer():
+ return 42
+
+ lazy_cache_dict["spam"] = answer
+ assert "spam" in lazy_cache_dict._raw_dict
+    assert "spam" not in lazy_cache_dict._resolved  # Not resolved yet
+ assert lazy_cache_dict["spam"] == 42
+ assert "spam" in lazy_cache_dict._resolved
+ del lazy_cache_dict["spam"]
+ assert "spam" not in lazy_cache_dict._raw_dict
+ assert "spam" not in lazy_cache_dict._resolved
+
+ lazy_cache_dict["foo"] = answer
+ assert lazy_cache_dict["foo"] == 42
+ assert "foo" in lazy_cache_dict._resolved
+    # Emulate some mess in data, e.g. value from `_raw_dict` deleted but not from `_resolved`
+ del lazy_cache_dict._raw_dict["foo"]
+ assert "foo" in lazy_cache_dict._resolved
+ with pytest.raises(KeyError):
+        # Error expected here, but we still expect the record to also be removed from `_resolved`
+ del lazy_cache_dict["foo"]
+ assert "foo" not in lazy_cache_dict._resolved
+
+ lazy_cache_dict["baz"] = answer
+ # Key in `_resolved` not created yet
+ assert "baz" in lazy_cache_dict._raw_dict
+ assert "baz" not in lazy_cache_dict._resolved
+ del lazy_cache_dict._raw_dict["baz"]
+ assert "baz" not in lazy_cache_dict._raw_dict
+ assert "baz" not in lazy_cache_dict._resolved
+
+
+def test_lazy_cache_dict_clear():
+ def answer():
+ return 42
+
+ lazy_cache_dict = LazyDictWithCache()
+ assert len(lazy_cache_dict) == 0
+ lazy_cache_dict["spam"] = answer
+ lazy_cache_dict["foo"] = answer
+ lazy_cache_dict["baz"] = answer
+
+ assert len(lazy_cache_dict) == 3
+ assert len(lazy_cache_dict._raw_dict) == 3
+ assert not lazy_cache_dict._resolved
+ assert lazy_cache_dict["spam"] == 42
+ assert len(lazy_cache_dict._resolved) == 1
+    # Emulate some mess in the data by adding an entry directly to `_resolved`
+ lazy_cache_dict._resolved.add("biz")
+ assert len(lazy_cache_dict) == 3
+ assert len(lazy_cache_dict._resolved) == 2
+ # And finally cleanup everything
+ lazy_cache_dict.clear()
+ assert len(lazy_cache_dict) == 0
+ assert not lazy_cache_dict._raw_dict
+ assert not lazy_cache_dict._resolved
+
+
+class TestProvidersMetadataLoading:
+ @pytest.mark.parametrize(
+ ("field_name", "field_def", "expected_title", "expected_checks"),
+ [
+ pytest.param(
+ "api_url",
+ {
+ "label": "API URL",
+ "description": "The API endpoint URL",
+ "schema": {
+ "type": "string",
+ "default": "https://api.example.com",
+ },
+ },
+ "API URL",
+ lambda x: (
+                    x["description"] == "The API endpoint URL" and x["value"] == "https://api.example.com"
+ ),
+ id="string_field",
+ ),
+ pytest.param(
+ "timeout",
+ {
+ "label": "Timeout",
+ "description": "Connection timeout in seconds",
+ "schema": {
+ "type": "integer",
+ "minimum": 1,
+ "maximum": 300,
+ "default": 30,
+ },
+ },
+ "Timeout",
+ lambda x: x["value"] == 30,
+ id="integer_field",
+ ),
+ pytest.param(
+ "use_ssl",
+ {
+ "label": "Use SSL",
+ "schema": {
+ "type": "boolean",
+ "default": True,
+ },
+ },
+ "Use SSL",
+ lambda x: x["value"] is True,
+ id="boolean_field",
+ ),
+ pytest.param(
+ "api_key",
+ {
+ "label": "API Key",
+ "sensitive": True,
+ "schema": {
+ "type": "string",
+ "format": "password",
+ },
+ },
+ "API Key",
+ lambda x: x["schema"].get("format") == "password",
+ id="password_field",
+ ),
+ pytest.param(
+ "ssl_mode",
+ {
+ "label": "SSL Mode",
+ "description": "SSL connection mode",
+ "schema": {
+ "type": "string",
+                        "enum": ["disable", "prefer", "require", "verify-full"],
+ "default": "prefer",
+ },
+ },
+ "SSL Mode",
+ lambda x: (
+ x["value"] == "prefer"
+ and "enum" in x["schema"]
+                    and x["schema"]["enum"] == ["disable", "prefer", "require", "verify-full"]
+ ),
+ id="enum_field",
+ ),
+ ],
+ )
+    def test_to_api_format(self, field_name, field_def, expected_title, expected_checks):
+ """Test converting field definitions to API format."""
+ pm = ProvidersManager()
+ x = pm._to_api_format(field_name, field_def)
+
+ assert x is not None
+ assert isinstance(x, dict)
+ assert x["schema"]["title"] == expected_title
+ assert expected_checks(x)
+
+ def test_add_customized_fields(self):
+ """Test adding customized field behaviour from provider info."""
+ pm = ProvidersManager()
+ pm.initialize_providers_list()
+
+ behaviour = {
+ "hidden-fields": ["schema", "extra"],
+ "relabeling": {"login": "Email Address"},
+ "placeholders": {"host": "smtp.gmail.com", "port": "587"},
+ }
+
+ pm._add_customized_fields(
+            package_name="test-provider", connection_type="test_conn", behaviour=behaviour
+ )
+
+ assert "test_conn" in pm._field_behaviours
+ behaviour = pm._field_behaviours["test_conn"]
+ assert behaviour["hidden_fields"] == ["schema", "extra"]
+ assert behaviour["relabeling"] == {"login": "Email Address"}
+ assert behaviour["placeholders"]["host"] == "smtp.gmail.com"
+
+ def test_load_ui_for_http_provider(self):
+ """Test that HTTP provider ui metadata is loaded from provider info."""
+ pm = ProvidersManager()
+ pm.initialize_providers_hooks()
+
+ assert "http" in pm._field_behaviours
+ behaviour = pm._field_behaviours["http"]
+
+ assert "hidden_fields" in behaviour
+ assert "relabeling" in behaviour
+ assert "placeholders" in behaviour
+
+ def test_ui_metadata_loading_without_hook_import(self):
+        """Test that UI metadata loads from provider info without importing hook classes."""
+ with patch("airflow.providers_manager.import_string") as mock_import:
+ pm = ProvidersManager()
+ pm.initialize_providers_hooks()
+
+ assert "http" in pm._field_behaviours
+
+        # assert that HttpHook was not imported during initialization, which means yaml path was taken
+        assert len([call for call in mock_import.call_args_list if "HttpHook" in str(call)]) == 0
diff --git a/contributing-docs/23_provider_hook_migration_to_yaml.rst b/contributing-docs/23_provider_hook_migration_to_yaml.rst
new file mode 100644
index 00000000000..ce9a3fd8dd5
--- /dev/null
+++ b/contributing-docs/23_provider_hook_migration_to_yaml.rst
@@ -0,0 +1,110 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Provider hook to YAML Migration
+===============================
+
+Connection form metadata can now be defined declaratively in a provider's ``provider.yaml`` instead of in Python hook code,
+reducing dependencies and improving API server startup performance.
+
+**The outline for this document in GitHub is available at top-right corner button (with 3-dots and 3 lines).**
+
+Background
+----------
+
+Previously, connection form UI metadata was defined in a provider's hook code using:
+
+* ``get_connection_form_widgets()`` - Custom form fields
+* ``get_ui_field_behaviour()`` - Field customizations (hidden, relabeling, placeholders)
+
+These methods required importing the heavy ``flask_appbuilder`` and ``wtforms`` packages, adding unnecessary dependencies
+to the API server and forcing it to load all of the provider hook code just to display a static form.
+The new YAML approach allows the metadata to be loaded without importing hook classes.
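+
+For reference, a legacy hook typically exposed this metadata along the lines of the sketch below.
+This is an illustrative example only; ``MyServiceHook`` and its field names are made up and the
+base-hook import is omitted:
+
+.. code-block:: python
+
+    class MyServiceHook(BaseHook):
+        conn_type = "my_service"
+
+        @classmethod
+        def get_connection_form_widgets(cls) -> dict:
+            from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+            from flask_babel import lazy_gettext
+            from wtforms import StringField
+
+            # each widget surfaces as an "extra__<conn_type>__<field>" form field
+            return {"api_token": StringField(lazy_gettext("API Token"), widget=BS3TextFieldWidget())}
+
+        @classmethod
+        def get_ui_field_behaviour(cls) -> dict:
+            return {
+                "hidden_fields": ["schema", "extra"],
+                "relabeling": {"login": "Username"},
+                "placeholders": {"host": "https://example.com"},
+            }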
+
+
+YAML Schema Structure
+---------------------
+
+Connection metadata is defined under ``connection-types`` in a provider's ``provider.yaml``:
+
+ui-field-behaviour
+~~~~~~~~~~~~~~~~~~
+
+Customizations for standard connection fields:
+
+.. code-block:: yaml
+
+ ui-field-behaviour:
+ hidden-fields:
+ - schema
+ - extra
+ relabeling:
+ host: Registry URL
+ login: Username
+ placeholders:
+ port: '5432'
+
+conn-fields
+~~~~~~~~~~~
+
+Custom fields that are stored in ``Connection.extra``. The ``schema`` property uses
+`JSON Schema <https://json-schema.org/>`_ to define field types and validation. For details on
+supported field options, see
+`Use Params to Provide a Trigger UI Form <https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/params.html#use-params-to-provide-a-trigger-ui-form>`_.
+
+.. code-block:: yaml
+
+ conn-fields:
+ keyfile_dict:
+ label: "Keyfile JSON"
+ description: "Service account JSON key"
+ schema:
+ type: string
+ format: password
+ project:
+ label: "Project Id"
+ schema:
+ type: string
+ default: "my-project"
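+
+For orientation, the sketch below shows roughly how one ``conn-fields`` entry is turned into the
+serialized parameter returned by the connection form API, mirroring the ``_to_api_format``
+conversion added in this change (the field values here are illustrative):
+
+.. code-block:: python
+
+    field_def = {
+        "label": "Project Id",
+        "description": "Example description",
+        "schema": {"type": "string", "default": "my-project"},
+    }
+
+    # the label becomes the JSON Schema "title"; the schema default becomes the field "value"
+    schema = field_def["schema"].copy()
+    schema["title"] = field_def["label"]
+    serialized = {
+        "value": field_def["schema"].get("default"),
+        "schema": schema,
+        "description": field_def.get("description"),
+        "source": None,
+    }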
+
+Migration Tool
+--------------
+
+The ``scripts/tools/generate_yaml_format_for_hooks.py`` script extracts metadata from existing Python hook code.
+Run the script from the Airflow virtual environment.
+
+Basic Usage
+~~~~~~~~~~~
+
+Extract from a provider:
+
+.. code-block:: bash
+
+    python scripts/tools/generate_yaml_format_for_hooks.py --provider docker
+
+Extract from a specific hook (some providers can have many hooks):
+
+.. code-block:: bash
+
+    python scripts/tools/generate_yaml_format_for_hooks.py \
+ --hook-class airflow.providers.docker.hooks.docker.DockerHook
+
+Update provider.yaml directly:
+
+.. code-block:: bash
+
+    python scripts/tools/generate_yaml_format_for_hooks.py --provider docker --update-yaml
diff --git a/providers/fab/tests/unit/fab/www/views/test_connection_form_fields.py b/providers/fab/tests/unit/fab/www/views/test_connection_form_fields.py
index c0f673a3f83..fc9271a791f 100644
--- a/providers/fab/tests/unit/fab/www/views/test_connection_form_fields.py
+++ b/providers/fab/tests/unit/fab/www/views/test_connection_form_fields.py
@@ -67,11 +67,19 @@ def
test_connection_form__add_widgets_prefix_backcompat(scenario, cleanup_provid
else:
raise ValueError("unexpected")
- provider_manager._add_widgets(
- package_name="abc",
- hook_class=MyHook,
- widgets=widgets,
- )
+ if hasattr(provider_manager, "_add_widgets_from_hook"):
+ provider_manager._add_widgets_from_hook(
+ package_name="abc",
+ hook_class=MyHook,
+ widgets=widgets,
+ )
+ else:
+ # backcompat for airflow < 3.2
+ provider_manager._add_widgets(
+ package_name="abc",
+ hook_class=MyHook,
+ widgets=widgets,
+ )
     assert provider_manager.connection_form_widgets["extra__test__my_param"].field == widget_field
@@ -88,11 +96,19 @@ def
test_connection_field_behaviors_placeholders_prefix(cleanup_providers_manage
}
provider_manager = ProvidersManager()
- provider_manager._add_customized_fields(
- package_name="abc",
- hook_class=MyHook,
- customized_fields=MyHook.get_ui_field_behaviour(),
- )
+ if hasattr(provider_manager, "_add_customized_fields_from_hook"):
+ provider_manager._add_customized_fields_from_hook(
+ package_name="abc",
+ hook_class=MyHook,
+ customized_fields=MyHook.get_ui_field_behaviour(),
+ )
+ else:
+ # backcompat for airflow < 3.2
+ provider_manager._add_customized_fields(
+ package_name="abc",
+ hook_class=MyHook,
+ customized_fields=MyHook.get_ui_field_behaviour(),
+ )
expected = {
"extra__test__abc": "hi", # prefix should be added, since `abc` is
not reserved
"extra__anything": "n/a", # no change since starts with extra
@@ -114,11 +130,20 @@ def
test_connection_form_widgets_fields_order(cleanup_providers_manager):
provider_manager = ProvidersManager()
provider_manager._connection_form_widgets = {}
- provider_manager._add_widgets(
- package_name="mock",
- hook_class=TestHook,
-        widgets={f: BooleanField(lazy_gettext("Dummy param")) for f in expected_field_names_order},
- )
+    widgets = {f: BooleanField(lazy_gettext("Dummy param")) for f in expected_field_names_order}
+ if hasattr(provider_manager, "_add_widgets_from_hook"):
+ provider_manager._add_widgets_from_hook(
+ package_name="mock",
+ hook_class=TestHook,
+ widgets=widgets,
+ )
+ else:
+ # backcompat for airflow < 3.2
+ provider_manager._add_widgets(
+ package_name="mock",
+ hook_class=TestHook,
+ widgets=widgets,
+ )
actual_field_names_order = tuple(
         key for key in provider_manager.connection_form_widgets.keys() if key.startswith(field_prefix)
)
@@ -147,16 +172,31 @@ def
test_connection_form_widgets_fields_order_multiple_hooks(cleanup_providers_m
provider_manager = ProvidersManager()
provider_manager._connection_form_widgets = {}
- provider_manager._add_widgets(
- package_name="mock",
- hook_class=TestHook1,
-        widgets={f"{field_prefix}{f}": BooleanField(lazy_gettext("Dummy param")) for f in field_names_hook_1},
- )
- provider_manager._add_widgets(
- package_name="another_mock",
- hook_class=TestHook2,
-        widgets={f"{field_prefix}{f}": BooleanField(lazy_gettext("Dummy param")) for f in field_names_hook_2},
- )
+    widgets_1 = {f"{field_prefix}{f}": BooleanField(lazy_gettext("Dummy param")) for f in field_names_hook_1}
+    widgets_2 = {f"{field_prefix}{f}": BooleanField(lazy_gettext("Dummy param")) for f in field_names_hook_2}
+ if hasattr(provider_manager, "_add_widgets_from_hook"):
+ provider_manager._add_widgets_from_hook(
+ package_name="mock",
+ hook_class=TestHook1,
+ widgets=widgets_1,
+ )
+ provider_manager._add_widgets_from_hook(
+ package_name="another_mock",
+ hook_class=TestHook2,
+ widgets=widgets_2,
+ )
+ else:
+ # backcompat for airflow < 3.2
+ provider_manager._add_widgets(
+ package_name="mock",
+ hook_class=TestHook1,
+ widgets=widgets_1,
+ )
+ provider_manager._add_widgets(
+ package_name="another_mock",
+ hook_class=TestHook2,
+ widgets=widgets_2,
+ )
actual_field_names_order = tuple(
         key for key in provider_manager.connection_form_widgets.keys() if key.startswith(field_prefix)
)
diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml
index e1b0fd8faae..56cd3f1d0e6 100644
--- a/providers/google/provider.yaml
+++ b/providers/google/provider.yaml
@@ -1123,6 +1123,72 @@ transfers:
connection-types:
   - hook-class-name: airflow.providers.google.common.hooks.base_google.GoogleBaseHook
connection-type: google_cloud_platform
+ ui-field-behaviour:
+ hidden-fields: ["host", "schema", "login", "password", "port", "extra"]
+ relabeling: {}
+ placeholders: {}
+ conn-fields:
+ project:
+ label: "Project Id"
+ schema:
+ type: string
+ key_path:
+ label: "Keyfile Path"
+ schema:
+ type: string
+ keyfile_dict:
+ label: "Keyfile JSON"
+ schema:
+ type: string
+ format: password
+ credential_config_file:
+ label: "Credential Configuration File"
+ schema:
+ type: string
+ scope:
+ label: "Scopes (comma separated)"
+ schema:
+ type: string
+ key_secret_name:
+ label: "Keyfile Secret Name (in GCP Secret Manager)"
+ schema:
+ type: string
+ key_secret_project_id:
+ label: "Keyfile Secret Project Id (in GCP Secret Manager)"
+ schema:
+ type: string
+ num_retries:
+ label: "Number of Retries"
+ schema:
+ type: integer
+ minimum: 0
+ default: 5
+ impersonation_chain:
+ label: "Impersonation Chain"
+ schema:
+ type: string
+ idp_issuer_url:
+ label: "IdP Token Issue URL (Client Credentials Grant Flow)"
+ schema:
+ type: string
+ client_id:
+ label: "Client ID (Client Credentials Grant Flow)"
+ schema:
+ type: string
+ client_secret:
+ label: "Client Secret (Client Credentials Grant Flow)"
+ schema:
+ type: string
+ format: password
+ idp_extra_parameters:
+ label: "IdP Extra Request Parameters"
+ schema:
+ type: string
+ is_anonymous:
+ label: "Anonymous credentials (ignores all other settings)"
+ schema:
+ type: boolean
+ default: false
   - hook-class-name: airflow.providers.google.cloud.hooks.dataprep.GoogleDataprepHook
     connection-type: dataprep
   - hook-class-name: airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook
diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py
index 6c0a0a76d16..af4b9a8392a 100644
--- a/providers/google/src/airflow/providers/google/get_provider_info.py
+++ b/providers/google/src/airflow/providers/google/get_provider_info.py
@@ -1367,6 +1367,57 @@ def get_provider_info():
{
"hook-class-name":
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook",
"connection-type": "google_cloud_platform",
+ "ui-field-behaviour": {
+                "hidden-fields": ["host", "schema", "login", "password", "port", "extra"],
+ "relabeling": {},
+ "placeholders": {},
+ },
+ "conn-fields": {
+                "project": {"label": "Project Id", "schema": {"type": "string"}},
+                "key_path": {"label": "Keyfile Path", "schema": {"type": "string"}},
+ "keyfile_dict": {
+ "label": "Keyfile JSON",
+ "schema": {"type": "string", "format": "password"},
+ },
+ "credential_config_file": {
+ "label": "Credential Configuration File",
+ "schema": {"type": "string"},
+ },
+                "scope": {"label": "Scopes (comma separated)", "schema": {"type": "string"}},
+ "key_secret_name": {
+ "label": "Keyfile Secret Name (in GCP Secret Manager)",
+ "schema": {"type": "string"},
+ },
+ "key_secret_project_id": {
+                    "label": "Keyfile Secret Project Id (in GCP Secret Manager)",
+ "schema": {"type": "string"},
+ },
+ "num_retries": {
+ "label": "Number of Retries",
+                    "schema": {"type": "integer", "minimum": 0, "default": 5},
+ },
+                "impersonation_chain": {"label": "Impersonation Chain", "schema": {"type": "string"}},
+ "idp_issuer_url": {
+                    "label": "IdP Token Issue URL (Client Credentials Grant Flow)",
+ "schema": {"type": "string"},
+ },
+ "client_id": {
+ "label": "Client ID (Client Credentials Grant Flow)",
+ "schema": {"type": "string"},
+ },
+ "client_secret": {
+                    "label": "Client Secret (Client Credentials Grant Flow)",
+ "schema": {"type": "string", "format": "password"},
+ },
+ "idp_extra_parameters": {
+ "label": "IdP Extra Request Parameters",
+ "schema": {"type": "string"},
+ },
+ "is_anonymous": {
+                    "label": "Anonymous credentials (ignores all other settings)",
+ "schema": {"type": "boolean", "default": False},
+ },
+ },
},
{
"hook-class-name":
"airflow.providers.google.cloud.hooks.dataprep.GoogleDataprepHook",
diff --git a/providers/http/provider.yaml b/providers/http/provider.yaml
index a5ad9baa1db..d100c3094e8 100644
--- a/providers/http/provider.yaml
+++ b/providers/http/provider.yaml
@@ -115,3 +115,8 @@ triggers:
connection-types:
- hook-class-name: airflow.providers.http.hooks.http.HttpHook
connection-type: http
+ ui-field-behaviour:
+ hidden-fields: []
+ relabeling: {}
+ placeholders: {}
+ conn-fields: {}
diff --git a/providers/http/src/airflow/providers/http/get_provider_info.py b/providers/http/src/airflow/providers/http/get_provider_info.py
index 37b48c286f0..1c95246e12d 100644
--- a/providers/http/src/airflow/providers/http/get_provider_info.py
+++ b/providers/http/src/airflow/providers/http/get_provider_info.py
@@ -61,6 +61,11 @@ def get_provider_info():
}
],
"connection-types": [
-        {"hook-class-name": "airflow.providers.http.hooks.http.HttpHook", "connection-type": "http"}
+ {
+            "hook-class-name": "airflow.providers.http.hooks.http.HttpHook",
+ "connection-type": "http",
+            "ui-field-behaviour": {"hidden-fields": [], "relabeling": {}, "placeholders": {}},
+ "conn-fields": {},
+ }
],
}
diff --git a/providers/smtp/provider.yaml b/providers/smtp/provider.yaml
index a545c762132..f617c9a1978 100644
--- a/providers/smtp/provider.yaml
+++ b/providers/smtp/provider.yaml
@@ -81,6 +81,10 @@ hooks:
connection-types:
- hook-class-name: airflow.providers.smtp.hooks.smtp.SmtpHook
connection-type: smtp
+ ui-field-behaviour:
+ hidden-fields: ['schema', 'extra']
+ relabeling: {}
+ placeholders: {}
notifications:
- airflow.providers.smtp.notifications.smtp.SmtpNotifier
diff --git a/providers/smtp/src/airflow/providers/smtp/get_provider_info.py b/providers/smtp/src/airflow/providers/smtp/get_provider_info.py
index cffe1895670..bcfaaffc53e 100644
--- a/providers/smtp/src/airflow/providers/smtp/get_provider_info.py
+++ b/providers/smtp/src/airflow/providers/smtp/get_provider_info.py
@@ -47,7 +47,15 @@ def get_provider_info():
}
],
"connection-types": [
-        {"hook-class-name": "airflow.providers.smtp.hooks.smtp.SmtpHook", "connection-type": "smtp"}
+ {
+            "hook-class-name": "airflow.providers.smtp.hooks.smtp.SmtpHook",
+ "connection-type": "smtp",
+ "ui-field-behaviour": {
+ "hidden-fields": ["schema", "extra"],
+ "relabeling": {},
+ "placeholders": {},
+ },
+ }
],
"notifications":
["airflow.providers.smtp.notifications.smtp.SmtpNotifier"],
}
diff --git a/scripts/ci/prek/check_core_imports_in_shared.py b/scripts/ci/prek/check_core_imports_in_shared.py
new file mode 100644
index 00000000000..f7b16153410
--- /dev/null
+++ b/scripts/ci/prek/check_core_imports_in_shared.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10,<3.11"
+# dependencies = [
+# "rich>=13.6.0",
+# ]
+# ///
+from __future__ import annotations
+
+import argparse
+import ast
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.resolve()))
+from common_prek_utils import console
+
+
+def check_file_for_prohibited_imports(file_path: Path) -> list[tuple[int, str]]:
+ """
+ Check file for airflow-core and airflow.sdk imports.
+ Shared libraries should not depend on either core or task-sdk.
+ Returns list of (line_num, import_statement).
+ """
+ try:
+ source = file_path.read_text(encoding="utf-8")
+ tree = ast.parse(source, filename=str(file_path))
+ except (OSError, UnicodeDecodeError, SyntaxError):
+ return []
+
+ violations = []
+
+ for node in ast.walk(tree):
+ # Check `from airflow.x import y` statements
+ if isinstance(node, ast.ImportFrom):
+ if node.module and node.module.startswith("airflow."):
+                # Allow airflow_shared imports (which show as airflow._shared at runtime)
+ if not (
+                    node.module.startswith("airflow_shared") or node.module.startswith("airflow._shared")
+ ):
+                    import_names = ", ".join(alias.name for alias in node.names)
+ statement = f"from {node.module} import {import_names}"
+ violations.append((node.lineno, statement))
+
+ # Check `import airflow.x` statements
+ elif isinstance(node, ast.Import):
+ for alias in node.names:
+ if alias.name.startswith("airflow."):
+                    # Allow airflow_shared imports (which show as airflow._shared at runtime)
+ if not (
+                        alias.name.startswith("airflow_shared") or alias.name.startswith("airflow._shared")
+ ):
+ statement = f"import {alias.name}"
+ if alias.asname:
+ statement += f" as {alias.asname}"
+ violations.append((node.lineno, statement))
+
+ return violations
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Check for core/task-sdk imports in shared library files")
+ parser.add_argument("files", nargs="*", help="Files to check")
+ args = parser.parse_args()
+
+ if not args.files:
+ return
+
+ total_violations = 0
+
+ for file_path in [Path(f) for f in args.files]:
+ violations = check_file_for_prohibited_imports(file_path)
+ if violations:
+ console.print(f"[red]{file_path}[/red]:")
+ for line_num, statement in violations:
+                console.print(f"  [yellow]Line {line_num}[/yellow]: {statement}")
+ total_violations += len(violations)
+
+ if total_violations:
+ console.print()
+        console.print(f"[red]Found {total_violations} prohibited import(s) in shared library files[/red]")
+        console.print("[yellow]Shared libraries must not import from airflow-core or airflow.sdk[/yellow]")
+        console.print("[yellow]Only airflow_shared (airflow._shared at runtime) imports are allowed[/yellow]")
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
+ sys.exit(0)
diff --git a/scripts/tools/generate_yaml_format_for_hooks.py b/scripts/tools/generate_yaml_format_for_hooks.py
new file mode 100755
index 00000000000..ecb3c112f1d
--- /dev/null
+++ b/scripts/tools/generate_yaml_format_for_hooks.py
@@ -0,0 +1,262 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Generate conn-fields and ui-field-behaviour for provider.yaml from Python hooks.
+
+This script requires the Airflow development environment with all workspace
+modules available.
+
+Usage:
+    python scripts/tools/generate_yaml_format_for_hooks.py --provider google
+    python scripts/tools/generate_yaml_format_for_hooks.py --provider google --update-yaml
+    python scripts/tools/generate_yaml_format_for_hooks.py --hook-class airflow.providers.http.hooks.http.HttpHook
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+from pathlib import Path
+
+import yaml
+
+AIRFLOW_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(AIRFLOW_ROOT))
+sys.path.insert(0, str(AIRFLOW_ROOT / "dev" / "breeze" / "src"))
+
+# import after extending sys.path so the breeze sources are importable
+from airflow_breeze.utils.console import get_console
+
+console = get_console()
+
+
+def extract_conn_fields(widgets, connection_type):
+ """Get conn-fields from hook widgets."""
+ conn_fields = {}
+ prefix = f"extra__{connection_type}__"
+
+ for field_key, field_widget in widgets.items():
+        field_name = field_key[len(prefix) :] if field_key.startswith(prefix) else field_key
+
+ if not hasattr(field_widget, "param"):
+ continue
+
+ field_data = field_widget.param.dump()
+ schema = field_data.get("schema", {}).copy()
+ label = schema.pop("title", field_name.replace("_", " ").title())
+
+ field_class_name = getattr(field_widget, "field_class", None)
+        is_password_field = field_class_name and "Password" in field_class_name.__name__
+
+ if is_password_field and schema.get("format") != "password":
+ schema["format"] = "password"
+
+ if field_data.get("value") is not None:
+ schema["default"] = field_data["value"]
+
+ yaml_field = {"label": label, "schema": schema}
+
+ if field_data.get("description"):
+ yaml_field["description"] = field_data["description"]
+
+ conn_fields[field_name] = yaml_field
+
+ return conn_fields
+
+
+def extract_ui_behaviour(hook_class):
+ """Get ui-field-behaviour from hook."""
+ if not hasattr(hook_class, "get_ui_field_behaviour"):
+ return None
+ if "get_ui_field_behaviour" not in hook_class.__dict__:
+ return None
+
+ behaviour = hook_class.get_ui_field_behaviour()
+ if not behaviour:
+ return None
+
+ yaml_behaviour = {}
+ if behaviour.get("hidden_fields"):
+ yaml_behaviour["hidden-fields"] = behaviour["hidden_fields"]
+ if behaviour.get("relabeling"):
+ yaml_behaviour["relabeling"] = behaviour["relabeling"]
+ if behaviour.get("placeholders"):
+ yaml_behaviour["placeholders"] = behaviour["placeholders"]
+
+ return yaml_behaviour or None
+
+
+def extract_from_hook(hook_class_name):
+ """Get metadata from hook class."""
+ from airflow._shared.module_loading import import_string
+
+ try:
+ hook_class = import_string(hook_class_name)
+ except Exception as e:
+ console.print(f"Error importing {hook_class_name}: {e}")
+ return None, None, None
+
+ connection_type = getattr(hook_class, "conn_type", None)
+ if not connection_type:
+ return None, None, None
+
+ conn_fields = None
+ if (
+ hasattr(hook_class, "get_connection_form_widgets")
+ and "get_connection_form_widgets" in hook_class.__dict__
+ ):
+ try:
+ widgets = hook_class.get_connection_form_widgets()
+ if widgets:
+ conn_fields = extract_conn_fields(widgets, connection_type)
+ except Exception as e:
+            console.print(f"Error extracting widgets from {hook_class_name}: {e}")
+
+ ui_behaviour = extract_ui_behaviour(hook_class)
+
+ return conn_fields, ui_behaviour, connection_type
+
+
+def find_provider_yaml(provider_name):
+ """Find provider.yaml file."""
+ path = AIRFLOW_ROOT / "providers" / provider_name / "provider.yaml"
+ return path if path.exists() else None
+
+
+def get_hooks_from_provider(provider_name):
+ """Get hook classes from provider.yaml."""
+ provider_yaml = find_provider_yaml(provider_name)
+ if not provider_yaml:
+ return []
+
+ with open(provider_yaml) as f:
+ data = yaml.safe_load(f)
+
+    return [conn["hook-class-name"] for conn in data.get("connection-types", []) if "hook-class-name" in conn]
+
+
+def update_provider_yaml(provider_name, hook_metadata):
+ """Update provider.yaml with extracted metadata."""
+ provider_yaml = find_provider_yaml(provider_name)
+ if not provider_yaml:
+ console.print(f"Provider yaml not found for {provider_name}")
+ return False
+
+ with open(provider_yaml) as f:
+ data = yaml.safe_load(f)
+
+ updated = False
+ for conn in data.get("connection-types", []):
+ hook_class = conn.get("hook-class-name")
+ if hook_class not in hook_metadata:
+ continue
+
+ conn_fields, ui_behaviour, _ = hook_metadata[hook_class]
+
+ if ui_behaviour:
+ conn["ui-field-behaviour"] = ui_behaviour
+ updated = True
+
+ if conn_fields:
+ conn["conn-fields"] = conn_fields
+ updated = True
+
+ if updated:
+ with open(provider_yaml, "w") as f:
+            yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
+ console.print(f"Updated {provider_yaml}")
+ return True
+
+ return False
+
+
+def format_output(hook_class_name, conn_fields, ui_behaviour, connection_type):
+ """Format extracted metadata for display."""
+ lines = [f"# {hook_class_name} (connection_type: {connection_type})"]
+
+ if ui_behaviour:
+ lines.append("ui-field-behaviour:")
+        ui_yaml = yaml.dump(ui_behaviour, default_flow_style=False, sort_keys=False)
+ lines.extend(f" {line}" for line in ui_yaml.rstrip().split("\n"))
+ lines.append("")
+
+ if conn_fields:
+ lines.append("conn-fields:")
+        fields_yaml = yaml.dump(conn_fields, default_flow_style=False, sort_keys=False)
+ lines.extend(f" {line}" for line in fields_yaml.rstrip().split("\n"))
+
+ return "\n".join(lines)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate yaml from hook metadata")
+ group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--provider", help="Provider name (e.g., google, docker)")
+ group.add_argument("--hook-class", help="Full hook class name")
+ parser.add_argument(
+        "--update-yaml", action="store_true", help="Update provider.yaml (only with --provider)"
+ )
+
+ args = parser.parse_args()
+
+ if args.hook_class:
+ hooks = [args.hook_class]
+ provider_name = None
+ else:
+ hooks = get_hooks_from_provider(args.provider)
+ provider_name = args.provider
+
+ if not hooks:
+ console.print(f"No hooks found for provider {args.provider}")
+ return 1
+
+ hook_metadata = {}
+ output_lines = []
+
+ for hook_class_name in hooks:
+ console.print(f"\nProcessing: {hook_class_name}")
+        conn_fields, ui_behaviour, connection_type = extract_from_hook(hook_class_name)
+
+ if conn_fields or ui_behaviour:
+            hook_metadata[hook_class_name] = (conn_fields, ui_behaviour, connection_type)
+            output = format_output(hook_class_name, conn_fields, ui_behaviour, connection_type)
+ output_lines.append(output)
+ else:
+ console.print("No metadata found")
+
+ if not output_lines:
+ console.print("\nNo metadata extracted")
+ return 1
+
+    full_output = "\n\n".join(output_lines)
+
+ if args.update_yaml:
+ if not provider_name:
+ console.print("\n--update-yaml only works with --provider")
+ return 1
+ if update_provider_yaml(provider_name, hook_metadata):
+ console.print("\nProvider yaml updated")
+ else:
+ console.print("\nFailed to update provider yaml")
+ else:
+ console.print(full_output)
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())