ashb commented on code in PR #45300: URL: https://github.com/apache/airflow/pull/45300#discussion_r1988811690
########## airflow/cli/api/cli_api_client.py: ########## Review Comment: Nit: can we call this file `airflow/cli/api/client.py` -- repeating cli api twice is just noise I feel ########## airflow/cli/api/client.py: ########## @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import json +import os +import sys +from typing import TYPE_CHECKING, Any + +import httpx +import keyring +import rich +import structlog +from platformdirs import user_config_path +from uuid6 import uuid7 + +from airflow.cli.api.operations import ( + AssetsOperations, + BackfillsOperations, + ConfigOperations, + ConnectionsOperations, + DagOperations, + DagRunOperations, + JobsOperations, + PoolsOperations, + ProvidersOperations, + ServerResponseError, + VariablesOperations, + VersionOperations, +) +from airflow.version import version + +if TYPE_CHECKING: + # # methodtools doesn't have typestubs, so give a stub + def lru_cache(maxsize: int | None = 128): + def wrapper(f): + return f + + return wrapper +else: + from methodtools import lru_cache + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "Client", + "Credentials", +] + + +def add_correlation_id(request: httpx.Request): + request.headers["correlation-id"] = str(uuid7()) + + +def get_json_error(response: httpx.Response): + """Raise a ServerResponseError if we can extract error info from the error.""" + err = ServerResponseError.from_response(response) + if err: + log.warning("Server error ", extra=dict(err.response.json())) + raise err + + +def raise_on_4xx_5xx(response: httpx.Response): + return get_json_error(response) or response.raise_for_status() + + +# Credentials for the API +class Credentials: + """Credentials for the API.""" + + api_url: str | None + api_token: str | None + api_environment: str + + def __init__( + self, + api_url: str | None = None, + api_token: str | None = None, + api_environment: str = "production", + ): + self.api_url = api_url + self.api_token = api_token + self.api_environment = os.getenv("APACHE_AIRFLOW_CLI_ENVIRONMENT") or api_environment + + @property + def input_cli_config_file(self) -> str: + """Generate path and always generate that path but let's not world readable.""" Review Comment: I dont think we do anything 
here about "world readable", so remove that bit? ########## airflow/cli/api/client.py: ########## @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import os +import sys +from typing import TYPE_CHECKING, Any + +import httpx +import keyring +import rich +import structlog +from platformdirs import user_config_path +from uuid6 import uuid7 + +from airflow.cli.api.operations import ( + AssetsOperations, + BackfillsOperations, + ConfigOperations, + ConnectionsOperations, + DagOperations, + DagRunOperations, + JobsOperations, + PoolsOperations, + ProvidersOperations, + ServerResponseError, + VariablesOperations, + VersionOperations, +) +from airflow.version import version + +if TYPE_CHECKING: + # # methodtools doesn't have typestubs, so give a stub + def lru_cache(maxsize: int | None = 128): + def wrapper(f): + return f + + return wrapper +else: + from methodtools import lru_cache + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "Client", + "Credentials", +] + + +def add_correlation_id(request: httpx.Request): + request.headers["correlation-id"] = str(uuid7()) + + +def get_json_error(response: httpx.Response): + """Raise a ServerResponseError if we 
can extract error info from the error.""" + err = ServerResponseError.from_response(response) + if err: + log.warning("Server error ", extra=dict(err.response.json())) + raise err + + +def raise_on_4xx_5xx(response: httpx.Response): + return get_json_error(response) or response.raise_for_status() + + +# Credentials for the API +class Credentials: + """Credentials for the API.""" + + api_url: str | None + api_token: str | None + api_environment: str + + def __init__( + self, + api_url: str | None = None, + api_token: str | None = None, + api_environment: str = "production", + ): + self.api_url = api_url + self.api_token = api_token + self.api_environment = os.getenv("APACHE_AIRFLOW_CLI_ENVIRONMENT") or api_environment Review Comment: No other env vars have the `APACHE_` prefix ```suggestion self.api_environment = os.getenv("AIRFLOW_CLI_ENVIRONMENT") or api_environment ``` ########## airflow/cli/commands/remote_commands/auth_command.py: ########## @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import os +import sys + +import rich + +from airflow.cli.api.cli_api_client import Credentials +from airflow.utils import cli as cli_utils + + +@cli_utils.action_cli +def login(args) -> None: + """Login to a provider.""" + if not args.api_token and not os.environ.get("APACHE_AIRFLOW_CLI_TOKEN"): + # Exit + rich.print("[red]No token found.") + rich.print( + "[green]Please pass:[/green] [blue]--api-token[/blue] or set " + "[blue]APACHE_AIRFLOW_CLI_TOKEN[/blue] environment variable to login." + ) + sys.exit(1) + Credentials( + api_url=args.api_url, + api_token=args.api_token or os.getenv("APACHE_AIRFLOW_CLI_TOKEN"), Review Comment: ```suggestion api_token=token, ``` ########## airflow/cli/api/client.py: ########## @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import json +import os +import sys +from typing import TYPE_CHECKING, Any + +import httpx +import keyring +import rich +import structlog +from platformdirs import user_config_path +from uuid6 import uuid7 + +from airflow.cli.api.operations import ( + AssetsOperations, + BackfillsOperations, + ConfigOperations, + ConnectionsOperations, + DagOperations, + DagRunOperations, + JobsOperations, + PoolsOperations, + ProvidersOperations, + ServerResponseError, + VariablesOperations, + VersionOperations, +) +from airflow.version import version + +if TYPE_CHECKING: + # # methodtools doesn't have typestubs, so give a stub + def lru_cache(maxsize: int | None = 128): + def wrapper(f): + return f + + return wrapper +else: + from methodtools import lru_cache + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "Client", + "Credentials", +] + + +def add_correlation_id(request: httpx.Request): + request.headers["correlation-id"] = str(uuid7()) + + +def get_json_error(response: httpx.Response): + """Raise a ServerResponseError if we can extract error info from the error.""" + err = ServerResponseError.from_response(response) + if err: + log.warning("Server error ", extra=dict(err.response.json())) + raise err + + +def raise_on_4xx_5xx(response: httpx.Response): + return get_json_error(response) or response.raise_for_status() + + +# Credentials for the API +class Credentials: + """Credentials for the API.""" + + api_url: str | None + api_token: str | None + api_environment: str + + def __init__( + self, + api_url: str | None = None, + api_token: str | None = None, + api_environment: str = "production", + ): + self.api_url = api_url + self.api_token = api_token + self.api_environment = os.getenv("APACHE_AIRFLOW_CLI_ENVIRONMENT") or api_environment + + @property + def input_cli_config_file(self) -> str: + """Generate path and always generate that path but let's not world readable.""" + return f"{self.api_environment}.json" + + 
def save(self): + """Save the credentials to keyring and URL to disk as a file.""" + default_config_dir = user_config_path("airflow", "Apache Software Foundation") + if not os.path.exists(default_config_dir): + os.makedirs(default_config_dir) + with open(os.path.join(default_config_dir, self.input_cli_config_file), "w") as f: + json.dump({"api_url": self.api_url}, f) + keyring.set_password("airflow-cli", f"api_token-{self.api_environment}", self.api_token) + + def load(self) -> Credentials: + """Load the credentials from keyring and URL from disk file.""" + default_config_dir = user_config_path("airflow", "Apache Software Foundation") + if os.path.exists(default_config_dir): + with open(os.path.join(default_config_dir, self.input_cli_config_file)) as f: + credentials = json.load(f) + self.api_url = credentials["api_url"] + self.api_token = keyring.get_password("airflow-cli", f"api_token-{self.api_environment}") + return self + else: + rich.print("[red]No credentials found.") + rich.print("[green]Please run: [blue]airflow auth login") + sys.exit(1) + + +class BearerAuth(httpx.Auth): + def __init__(self, token: str): + self.token: str = token + + def auth_flow(self, request: httpx.Request): + if self.token: + request.headers["Authorization"] = "Bearer " + self.token + yield request + + +class Client(httpx.Client): + """Client for the Airflow REST API.""" + + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): + if (not base_url) ^ dry_run: + raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") + auth = BearerAuth(token) + + if dry_run: + # If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a + # real client, but just don't make any HTTP requests + kwargs["base_url"] = "dry-run://server" + else: + kwargs["base_url"] = f"{base_url}/public" Review Comment: ```suggestion def __init__(self, *, base_url: str, token: str, **kwargs: Any): auth = BearerAuth(token) 
kwargs["base_url"] = f"{base_url}/public" ``` i think ########## airflow/cli/commands/remote_commands/auth_command.py: ########## @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os +import sys + +import rich + +from airflow.cli.api.cli_api_client import Credentials +from airflow.utils import cli as cli_utils + + +@cli_utils.action_cli +def login(args) -> None: + """Login to a provider.""" + if not args.api_token and not os.environ.get("APACHE_AIRFLOW_CLI_TOKEN"): + # Exit + rich.print("[red]No token found.") + rich.print( + "[green]Please pass:[/green] [blue]--api-token[/blue] or set " + "[blue]APACHE_AIRFLOW_CLI_TOKEN[/blue] environment variable to login." Review Comment: ```suggestion "[blue]AIRFLOW_CLI_TOKEN[/blue] environment variable to login." ``` ########## airflow/cli/api/client.py: ########## @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import os +import sys +from typing import TYPE_CHECKING, Any + +import httpx +import keyring +import rich +import structlog +from platformdirs import user_config_path +from uuid6 import uuid7 + +from airflow.cli.api.operations import ( + AssetsOperations, + BackfillsOperations, + ConfigOperations, + ConnectionsOperations, + DagOperations, + DagRunOperations, + JobsOperations, + PoolsOperations, + ProvidersOperations, + ServerResponseError, + VariablesOperations, + VersionOperations, +) +from airflow.version import version + +if TYPE_CHECKING: + # # methodtools doesn't have typestubs, so give a stub + def lru_cache(maxsize: int | None = 128): + def wrapper(f): + return f + + return wrapper +else: + from methodtools import lru_cache + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "Client", + "Credentials", +] + + +def add_correlation_id(request: httpx.Request): + request.headers["correlation-id"] = str(uuid7()) + + +def get_json_error(response: httpx.Response): + """Raise a ServerResponseError if we can extract error info from the error.""" + err = ServerResponseError.from_response(response) + if err: + log.warning("Server error ", extra=dict(err.response.json())) + raise err + + +def raise_on_4xx_5xx(response: httpx.Response): + return get_json_error(response) or response.raise_for_status() + + +# Credentials for the API 
+class Credentials: + """Credentials for the API.""" + + api_url: str | None + api_token: str | None + api_environment: str + + def __init__( + self, + api_url: str | None = None, + api_token: str | None = None, + api_environment: str = "production", + ): + self.api_url = api_url + self.api_token = api_token + self.api_environment = os.getenv("APACHE_AIRFLOW_CLI_ENVIRONMENT") or api_environment + + @property + def input_cli_config_file(self) -> str: + """Generate path and always generate that path but let's not world readable.""" + return f"{self.api_environment}.json" + + def save(self): + """Save the credentials to keyring and URL to disk as a file.""" + default_config_dir = user_config_path("airflow", "Apache Software Foundation") + if not os.path.exists(default_config_dir): + os.makedirs(default_config_dir) + with open(os.path.join(default_config_dir, self.input_cli_config_file), "w") as f: + json.dump({"api_url": self.api_url}, f) + keyring.set_password("airflow-cli", f"api_token-{self.api_environment}", self.api_token) + + def load(self) -> Credentials: + """Load the credentials from keyring and URL from disk file.""" + default_config_dir = user_config_path("airflow", "Apache Software Foundation") + if os.path.exists(default_config_dir): + with open(os.path.join(default_config_dir, self.input_cli_config_file)) as f: + credentials = json.load(f) + self.api_url = credentials["api_url"] + self.api_token = keyring.get_password("airflow-cli", f"api_token-{self.api_environment}") + return self + else: + rich.print("[red]No credentials found.") + rich.print("[green]Please run: [blue]airflow auth login") + sys.exit(1) Review Comment: We shouldn't be printing here as this makes it harder to use this "properly". This should raise something, and the print and exit should be closer to the cli layer ########## airflow/cli/api/cli_api_client.py: ########## Review Comment: Wait, we have both.
Merge these fns in to the other file, its not clear from the outside which one to use otherwise ########## airflow/cli/api/operations.py: ########## @@ -0,0 +1,684 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime +import sys +from typing import TYPE_CHECKING, Any + +import httpx +import rich +import structlog + +from airflow.cli.api.datamodels._generated import ( + AssetAliasCollectionResponse, + AssetAliasResponse, + AssetCollectionResponse, + AssetResponse, + BackfillPostBody, + BackfillResponse, + Config, + ConnectionBody, + ConnectionBulkActionResponse, + ConnectionBulkBody, + ConnectionCollectionResponse, + ConnectionResponse, + ConnectionTestResponse, + DAGDetailsResponse, + DAGResponse, + DAGRunCollectionResponse, + DAGRunResponse, + JobCollectionResponse, + PoolBulkActionResponse, + PoolBulkBody, + PoolCollectionResponse, + PoolPatchBody, + PoolPostBody, + PoolResponse, + ProviderCollectionResponse, + TriggerDAGRunPostBody, + VariableBody, + VariableBulkActionResponse, + VariableBulkBody, + VariableCollectionResponse, + VariableResponse, + VersionInfo, +) + +if TYPE_CHECKING: + from airflow.cli.api.client import Client + from airflow.utils.state import DagRunState + 
+log = structlog.get_logger(logger_name=__name__) + + +# Generic Server Response Error +class ServerResponseError(httpx.HTTPStatusError): + """Server response error (Generic).""" + + def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response): + super().__init__(message, request=request, response=response) + + # def_ + + @classmethod + def from_response(cls, response: httpx.Response) -> ServerResponseError | None: + if response.status_code < 400: + return None + + if response.headers.get("content-type") != "application/json": + return None + + if 400 <= response.status_code < 500: + response.read() + return cls( + message=f"Client error message: {response.json()}", + request=response.request, + response=response, + ) + + msg = response.json() + + self = cls(message=msg, request=response.request, response=response) + return self + + +# Decorator to apply methods to all operations, this is initiated at __init_subclass__ on BaseOperations +SERVER_CONNECTION_REFUSED_ERROR: str = "API Server is not running, please contact your administrator." 
+ + +def _check_flag_and_exit_if_server_response_error(func): + """Return decorator to check for ServerResponseError and exit if the server is not running.""" + + def _exit_if_server_response_error(response: Any | ServerResponseError): + if isinstance(response, ServerResponseError): + rich.print(f"[bold red]Error:[/bold red] {response.response.json()}") + sys.exit(1) + return response + + def wrapped(self, *args, **kwargs): + try: + if self.exit_in_error: + return _exit_if_server_response_error(response=func(self, *args, **kwargs)) + else: + return func(self, *args, **kwargs) + except httpx.ConnectError as e: + rich.print(f"error: {e}") + rich.print(f"[bold red]{SERVER_CONNECTION_REFUSED_ERROR}[/bold red]") + sys.exit(1) Review Comment: Similar here, this shouldn't be handled so "deep" in the stack ########## airflow/cli/cli_config.py: ########## @@ -1496,15 +1527,14 @@ class GroupCommand(NamedTuple): name="test", help="Test a connection", func=lazy_load_command("airflow.cli.commands.remote_commands.connection_command.connections_test"), - args=(ARG_CONN_ID, ARG_VERBOSE), + args=(ARG_CONN_ID, ARG_CONN_TYPE_POSITIONAL, ARG_VERBOSE), Review Comment: Why didn't we need this before? Was it just missed?
########## airflow/cli/commands/remote_commands/connection_command.py: ########## @@ -64,35 +70,31 @@ def _connection_mapper(conn: Connection) -> dict[str, Any]: @suppress_logs_and_warning @providers_configuration_loaded -def connections_get(args): +@provide_cli_api_client +def connections_get(args, cli_api_client=NEW_CLI_API_CLIENT): """Get a connection.""" - try: - conn = BaseHook.get_connection(args.conn_id) - except AirflowNotFoundException: - raise SystemExit("Connection not found.") Review Comment: I think something like this is the error handling pattern we should keep ########## hatch_build.py: ########## @@ -225,6 +225,8 @@ "itsdangerous>=2.0", "jinja2>=3.0.0", "jsonschema>=4.18.0", + # Added for CLI + "keyring>=25.6.0", Review Comment: Do we recommend/require anywhere the native os backends for this? ########## hatch_build.py: ########## @@ -238,6 +240,8 @@ "opentelemetry-exporter-otlp>=1.24.0", "packaging>=23.2", "pathspec>=0.9.0", + # Added for CLI Review Comment: ```suggestion ``` ########## airflow/cli/commands/remote_commands/auth_command.py: ########## @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import os +import sys + +import rich + +from airflow.cli.api.cli_api_client import Credentials +from airflow.utils import cli as cli_utils + + +@cli_utils.action_cli +def login(args) -> None: + """Login to a provider.""" + if not args.api_token and not os.environ.get("APACHE_AIRFLOW_CLI_TOKEN"): Review Comment: ```suggestion if not (token := args.api_token or os.environ.get("AIRFLOW_CLI_TOKEN")): ``` ########## hatch_build.py: ########## @@ -225,6 +225,8 @@ "itsdangerous>=2.0", "jinja2>=3.0.0", "jsonschema>=4.18.0", + # Added for CLI Review Comment: ```suggestion ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
