This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 09d478ec67 Allow 'airflow variables export' to print to stdout (#33279)
09d478ec67 is described below

commit 09d478ec671f8017294d4e15d75db1f40b8cc404
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Aug 11 17:02:48 2023 +0800

    Allow 'airflow variables export' to print to stdout (#33279)
    
    Co-authored-by: vedantlodha <[email protected]>
---
 airflow/cli/cli_config.py                  | 10 ++++-
 airflow/cli/commands/connection_command.py | 54 ++++++++++++--------------
 airflow/cli/commands/variable_command.py   | 61 +++++++++++++-----------------
 airflow/cli/utils.py                       | 33 ++++++++++++++++
 4 files changed, 92 insertions(+), 66 deletions(-)

diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index f2248efe5e..8e48a55351 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -543,7 +543,11 @@ ARG_DEFAULT = Arg(
 ARG_DESERIALIZE_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true")
 ARG_SERIALIZE_JSON = Arg(("-j", "--json"), help="Serialize JSON variable", action="store_true")
 ARG_VAR_IMPORT = Arg(("file",), help="Import variables from JSON file")
-ARG_VAR_EXPORT = Arg(("file",), help="Export all variables to JSON file")
+ARG_VAR_EXPORT = Arg(
+    ("file",),
+    help="Export all variables to JSON file",
+    type=argparse.FileType("w", encoding="UTF-8"),
+)
 
 # kerberos
 ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs="?")
@@ -1521,6 +1525,10 @@ VARIABLES_COMMANDS = (
     ActionCommand(
         name="export",
         help="Export all variables",
+        description=(
+            "All variables can be exported in STDOUT using the following command:\n"
+            "airflow variables export -\n"
+        ),
         func=lazy_load_command("airflow.cli.commands.variable_command.variables_export"),
         args=(ARG_VAR_EXPORT, ARG_VERBOSE),
     ),
diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 3a6909bc15..f68dd83f02 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -17,7 +17,6 @@
 """Connection sub-commands."""
 from __future__ import annotations
 
-import io
 import json
 import os
 import sys
@@ -30,6 +29,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import exc
 
 from airflow.cli.simple_table import AirflowConsole
+from airflow.cli.utils import is_stdout
 from airflow.compat.functools import cache
 from airflow.configuration import conf
 from airflow.exceptions import AirflowNotFoundException
@@ -138,10 +138,6 @@ def _format_connections(conns: list[Connection], file_format: str, serialization
     return json.dumps(connections_dict)
 
 
-def _is_stdout(fileio: io.TextIOWrapper) -> bool:
-    return fileio.name == "<stdout>"
-
-
 def _valid_uri(uri: str) -> bool:
     """Check if a URI is valid, by checking if scheme (conn_type) provided."""
     return urlsplit(uri).scheme != ""
@@ -171,32 +167,30 @@ def connections_export(args):
     if args.format or args.file_format:
         provided_file_format = f".{(args.format or args.file_format).lower()}"
 
-    file_is_stdout = _is_stdout(args.file)
-    if file_is_stdout:
-        filetype = provided_file_format or default_format
-    elif provided_file_format:
-        filetype = provided_file_format
-    else:
-        filetype = Path(args.file.name).suffix
-        filetype = filetype.lower()
-        if filetype not in file_formats:
-            raise SystemExit(
-                f"Unsupported file format. The file must have the extension {', '.join(file_formats)}."
-            )
-
-    if args.serialization_format and not filetype == ".env":
-        raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")
-
-    with create_session() as session:
-        connections = session.scalars(select(Connection).order_by(Connection.conn_id)).all()
-
-    msg = _format_connections(
-        conns=connections,
-        file_format=filetype,
-        serialization_format=args.serialization_format or "uri",
-    )
-
     with args.file as f:
+        if file_is_stdout := is_stdout(f):
+            filetype = provided_file_format or default_format
+        elif provided_file_format:
+            filetype = provided_file_format
+        else:
+            filetype = Path(args.file.name).suffix.lower()
+            if filetype not in file_formats:
+                raise SystemExit(
+                    f"Unsupported file format. The file must have the extension {', '.join(file_formats)}."
+                )
+
+        if args.serialization_format and not filetype == ".env":
+            raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")
+
+        with create_session() as session:
+            connections = session.scalars(select(Connection).order_by(Connection.conn_id)).all()
+
+        msg = _format_connections(
+            conns=connections,
+            file_format=filetype,
+            serialization_format=args.serialization_format or "uri",
+        )
+
         f.write(msg)
 
     if file_is_stdout:
diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py
index 34b46530d5..78275666b6 100644
--- a/airflow/cli/commands/variable_command.py
+++ b/airflow/cli/commands/variable_command.py
@@ -20,11 +20,13 @@ from __future__ import annotations
 
 import json
 import os
+import sys
 from json import JSONDecodeError
 
 from sqlalchemy import select
 
 from airflow.cli.simple_table import AirflowConsole
+from airflow.cli.utils import is_stdout
 from airflow.models import Variable
 from airflow.utils import cli as cli_utils
 from airflow.utils.cli import suppress_logs_and_warning
@@ -76,44 +78,30 @@ def variables_delete(args):
 @providers_configuration_loaded
 def variables_import(args):
     """Imports variables from a given file."""
-    if os.path.exists(args.file):
-        _import_helper(args.file)
-    else:
+    if not os.path.exists(args.file):
         raise SystemExit("Missing variables file.")
+    with open(args.file) as varfile:
+        try:
+            var_json = json.load(varfile)
+        except JSONDecodeError:
+            raise SystemExit("Invalid variables file.")
+    suc_count = fail_count = 0
+    for k, v in var_json.items():
+        try:
+            Variable.set(k, v, serialize_json=not isinstance(v, str))
+        except Exception as e:
+            print(f"Variable import failed: {repr(e)}")
+            fail_count += 1
+        else:
+            suc_count += 1
+    print(f"{suc_count} of {len(var_json)} variables successfully updated.")
+    if fail_count:
+        print(f"{fail_count} variable(s) failed to be updated.")
 
 
 @providers_configuration_loaded
 def variables_export(args):
     """Exports all the variables to the file."""
-    _variable_export_helper(args.file)
-
-
-def _import_helper(filepath):
-    """Helps import variables from the file."""
-    with open(filepath) as varfile:
-        data = varfile.read()
-
-    try:
-        var_json = json.loads(data)
-    except JSONDecodeError:
-        raise SystemExit("Invalid variables file.")
-    else:
-        suc_count = fail_count = 0
-        for k, v in var_json.items():
-            try:
-                Variable.set(k, v, serialize_json=not isinstance(v, str))
-            except Exception as e:
-                print(f"Variable import failed: {repr(e)}")
-                fail_count += 1
-            else:
-                suc_count += 1
-        print(f"{suc_count} of {len(var_json)} variables successfully updated.")
-        if fail_count:
-            print(f"{fail_count} variable(s) failed to be updated.")
-
-
-def _variable_export_helper(filepath):
-    """Helps export all the variables to the file."""
     var_dict = {}
     with create_session() as session:
         qry = session.scalars(select(Variable))
@@ -126,6 +114,9 @@ def _variable_export_helper(filepath):
                 val = var.val
             var_dict[var.key] = val
 
-    with open(filepath, "w") as varfile:
-        varfile.write(json.dumps(var_dict, sort_keys=True, indent=4))
-    print(f"{len(var_dict)} variables successfully exported to {filepath}")
+    with args.file as varfile:
+        json.dump(var_dict, varfile, sort_keys=True, indent=4)
+        if is_stdout(varfile):
+            print("\nVariables successfully exported.", file=sys.stderr)
+        else:
+            print(f"Variables successfully exported to {varfile.name}.")
diff --git a/airflow/cli/utils.py b/airflow/cli/utils.py
new file mode 100644
index 0000000000..718d34a6eb
--- /dev/null
+++ b/airflow/cli/utils.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import io
+import sys
+
+
+def is_stdout(fileio: io.IOBase) -> bool:
+    """Check whether a file IO is stdout.
+
+    The intended use case for this helper is to check whether an argument parsed
+    with argparse.FileType points to stdout (by setting the path to ``-``). This
+    is why there is no equivalent for stderr; argparse does not allow using it.
+
+    .. warning:: *fileio* must be open for this check to be successful.
+    """
+    return fileio.fileno() == sys.stdout.fileno()

Reply via email to