sc250072 commented on code in PR #56675:
URL: https://github.com/apache/airflow/pull/56675#discussion_r2450452195


##########
providers/teradata/src/airflow/providers/teradata/utils/tpt_util.py:
##########
@@ -0,0 +1,521 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
from __future__ import annotations

import logging
import os
import shlex
import shutil
import subprocess
import uuid
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from paramiko import SSHClient

from airflow.exceptions import AirflowException
+
+
class TPTConfig:
    """Shared constants used by the TPT helper functions in this module."""

    # Seconds granted to short helper operations (the local ``shred`` call and
    # the wait after ``terminate()`` on a subprocess).
    DEFAULT_TIMEOUT = 5

    # Octal file mode applied to local files that must be owner-read-only.
    FILE_PERMISSIONS_READ_ONLY = 0o400

    # Fallback temporary directories for remote hosts, by platform.
    TEMP_DIR_UNIX = "/tmp"
    TEMP_DIR_WINDOWS = "C:\\Windows\\Temp"
+
+
def execute_remote_command(ssh_client: SSHClient, command: str) -> tuple[int, str, str]:
    """
    Execute a command on the remote host and always release its SSH streams.

    Both output streams are read to EOF *before* the exit status is requested.
    Asking for the exit status first can block indefinitely when the command
    produces more output than the channel can buffer, because the remote side
    cannot finish until somebody consumes that output.

    :param ssh_client: SSH client connection
    :param command: Command to execute
    :return: Tuple of (exit_status, stdout, stderr); both outputs are decoded and stripped
    """
    stdin, stdout, stderr = ssh_client.exec_command(command)
    try:
        # Drain the output first, then collect the exit status (see docstring).
        stdout_data = stdout.read().decode().strip()
        stderr_data = stderr.read().decode().strip()
        exit_status = stdout.channel.recv_exit_status()
        return exit_status, stdout_data, stderr_data
    finally:
        # Close every stream even if closing one of them fails.
        for stream in (stdin, stdout, stderr):
            try:
                stream.close()
            except Exception:
                logging.getLogger(__name__).debug("Failed to close SSH stream", exc_info=True)
+
+
def write_file(path: str, content: str) -> None:
    """
    Write ``content`` to ``path`` as UTF-8 text, replacing any existing file.

    :param path: Destination file path
    :param content: Text to store in the file
    """
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(content)
+
+
def secure_delete(file_path: str, logger: logging.Logger | None = None) -> None:
    """
    Delete a local file, preferring ``shred`` and falling back to ``os.remove``.

    The deletion is best-effort: failures are logged as warnings and never raised.
    If ``shred`` is present but fails or times out, the file is still removed with
    ``os.remove`` so that it is not silently left behind on disk.

    :param file_path: Path to the file to be deleted
    :param logger: Optional logger instance
    """
    logger = logger or logging.getLogger(__name__)
    if not os.path.exists(file_path):
        return

    if shutil.which("shred") is not None:
        try:
            subprocess.run(["shred", "--remove", file_path], check=True, timeout=TPTConfig.DEFAULT_TIMEOUT)
            logger.info("Securely removed file using shred: %s", file_path)
            return
        except (OSError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
            logger.warning("shred failed for %s (%s); falling back to regular deletion", file_path, e)

    # Either shred is unavailable or it failed: fall back to a plain removal.
    try:
        os.remove(file_path)
        logger.info("Removed file: %s", file_path)
    except FileNotFoundError:
        # A failed shred may still have unlinked the file; nothing left to do.
        return
    except OSError as e:
        logger.warning("Failed to remove file %s: %s", file_path, e)
+
+
def remote_secure_delete(
    ssh_client: SSHClient, remote_files: list[str], logger: logging.Logger | None = None
) -> None:
    """
    Delete remote files via SSH on a best-effort basis.

    On Windows hosts ``del`` is used. On other hosts ``shred --remove`` is used when
    available; otherwise the file is overwritten with zeros via ``dd`` and then removed.
    Failures are logged as warnings and never raised.

    :param ssh_client: SSH client connection
    :param remote_files: List of remote file paths to delete
    :param logger: Optional logger instance
    """
    logger = logger or logging.getLogger(__name__)
    if not ssh_client or not remote_files:
        return

    try:
        windows_remote = get_remote_os(ssh_client, logger) == "windows"

        # shred is only probed for on non-Windows hosts.
        shred_available = False
        if not windows_remote:
            exit_status, output, _ = execute_remote_command(ssh_client, "command -v shred")
            shred_available = exit_status == 0 and output.strip() != ""

        for file_path in remote_files:
            try:
                if windows_remote:
                    replace_slash = file_path.replace("/", "\\")
                    command = f'if exist "{replace_slash}" del /f /q "{replace_slash}"'
                else:
                    # Quote the path so spaces or shell metacharacters in it cannot
                    # alter the command that runs on the remote host.
                    quoted = shlex.quote(file_path)
                    if shred_available:
                        command = f"shred --remove {quoted}"
                    else:
                        command = (
                            f"if [ -f {quoted} ]; then "
                            f"dd if=/dev/zero of={quoted} bs=4096 "
                            f"count=$(($(stat -c '%s' {quoted})/4096+1)) 2>/dev/null; "
                            f"rm -f {quoted}; fi"
                        )

                exit_status, _, stderr_data = execute_remote_command(ssh_client, command)
                if exit_status != 0:
                    logger.warning(
                        "Remote deletion of %s exited with status %s: %s",
                        file_path,
                        exit_status,
                        stderr_data if stderr_data else "N/A",
                    )
            except Exception as e:
                logger.warning("Failed to process remote file %s: %s", file_path, e)

        logger.info("Processed remote files: %s", ", ".join(remote_files))
    except Exception as e:
        logger.warning("Failed to remove remote files: %s", e)
+
+
def terminate_subprocess(sp: subprocess.Popen | None, logger: logging.Logger | None = None) -> None:
    """
    Stop a running subprocess, escalating from ``terminate()`` to ``kill()``.

    Nothing happens when ``sp`` is missing or has already exited. Errors raised
    while stopping the process are logged, not propagated.

    :param sp: Subprocess to terminate
    :param logger: Optional logger instance
    """
    log = logger or logging.getLogger(__name__)

    still_running = bool(sp) and sp.poll() is None
    if not still_running:
        return

    log.info("Terminating subprocess (PID: %s)", sp.pid)

    try:
        # First ask politely and give the process a bounded time to exit.
        sp.terminate()
        sp.wait(timeout=TPTConfig.DEFAULT_TIMEOUT)
        log.info("Subprocess terminated gracefully")
    except subprocess.TimeoutExpired:
        log.warning(
            "Subprocess did not terminate gracefully within %d seconds, killing it", TPTConfig.DEFAULT_TIMEOUT
        )
        try:
            # Escalate: force-kill and wait briefly for the process to be reaped.
            sp.kill()
            sp.wait(timeout=2)
            log.info("Subprocess killed successfully")
        except Exception as kill_error:
            log.error("Error killing subprocess: %s", str(kill_error))
    except Exception as terminate_error:
        log.error("Error terminating subprocess: %s", str(terminate_error))
+
+
def get_remote_os(ssh_client: SSHClient, logger: logging.Logger | None = None) -> str:
    """
    Classify the remote host as ``'windows'`` or ``'unix'`` via SSH.

    The probe runs ``echo %OS%`` and reports ``'windows'`` only when the output
    contains ``Windows``; every other outcome, including a missing client or a
    failed probe, is reported as ``'unix'``.

    :param ssh_client: SSH client connection
    :param logger: Optional logger instance
    :return: Operating system type as string ('windows' or 'unix')
    """
    log = logger or logging.getLogger(__name__)

    if not ssh_client:
        log.warning("No SSH client provided for OS detection")
        return "unix"

    try:
        # Only the probe's stdout matters; exit status and stderr are ignored.
        _, probe_output, _ = execute_remote_command(ssh_client, "echo %OS%")
        detected = "windows" if "Windows" in probe_output else "unix"
    except Exception as e:
        log.error("Error detecting remote OS: %s", str(e))
        return "unix"

    return detected
+
+
def set_local_file_permissions(local_file_path: str, logger: logging.Logger | None = None) -> None:
    """
    Set permissions for a local file to be read-only for the owner.

    An empty path is logged and ignored; a path that does not exist is an error.

    :param local_file_path: Path to the local file
    :param logger: Optional logger instance
    :raises AirflowException: If the file does not exist or permission setting fails
    """
    logger = logger or logging.getLogger(__name__)

    if not local_file_path:
        logger.warning("No file path provided for permission setting")
        return

    if not os.path.exists(local_file_path):
        raise AirflowException(f"File does not exist: {local_file_path}")

    try:
        # Set file permission to read-only for the owner (400)
        os.chmod(local_file_path, TPTConfig.FILE_PERMISSIONS_READ_ONLY)
    except OSError as e:
        # PermissionError is a subclass of OSError, so one clause covers both.
        # Chain the original error so its traceback stays attached.
        raise AirflowException(f"Error setting permissions for local file {local_file_path}: {e}") from e

    logger.info("Set read-only permissions for file %s", local_file_path)
+
+
def _set_windows_file_permissions(
    ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger
) -> None:
    """
    Restrict a remote Windows file by running ``icacls`` over SSH.

    :param ssh_client: SSH client connection
    :param remote_file_path: Path to the remote file
    :param logger: Logger instance
    :raises AirflowException: If ``icacls`` exits with a non-zero status
    """
    icacls_cmd = f'icacls "{remote_file_path}" /inheritance:r /grant:r "%USERNAME%":R'
    status, _, err_text = execute_remote_command(ssh_client, icacls_cmd)

    if status == 0:
        logger.info("Set restrictive permissions (owner read-only) for Windows remote file %s", remote_file_path)
        return

    raise AirflowException(
        f"Failed to set restrictive permissions on Windows remote file {remote_file_path}. "
        f"Exit status: {status}, Error: {err_text if err_text else 'N/A'}"
    )
+
+
def _set_unix_file_permissions(ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger) -> None:
    """
    Set read-only (400) permissions on a Unix/Linux remote file via ``chmod``.

    :param ssh_client: SSH client connection
    :param remote_file_path: Path to the remote file
    :param logger: Logger instance
    :raises AirflowException: If ``chmod`` exits with a non-zero status
    """
    # Quote the path so spaces or shell metacharacters cannot change the command.
    command = f"chmod 400 {shlex.quote(remote_file_path)}"

    exit_status, _, stderr_data = execute_remote_command(ssh_client, command)

    if exit_status != 0:
        raise AirflowException(
            f"Failed to set permissions (400) on remote file {remote_file_path}. "
            f"Exit status: {exit_status}, Error: {stderr_data if stderr_data else 'N/A'}"
        )
    logger.info("Set read-only permissions for remote file %s", remote_file_path)
+
+
def set_remote_file_permissions(
    ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger | None = None
) -> None:
    """
    Set permissions for a remote file to be read-only for the owner.

    The remote operating system is detected first and the matching
    platform-specific helper is used. A missing client or path is logged
    and ignored.

    :param ssh_client: SSH client connection
    :param remote_file_path: Path to the remote file
    :param logger: Optional logger instance
    :raises AirflowException: If permission setting fails
    """
    logger = logger or logging.getLogger(__name__)

    if not ssh_client or not remote_file_path:
        logger.warning(
            "Invalid parameters: ssh_client=%s, remote_file_path=%s", bool(ssh_client), remote_file_path
        )
        return

    try:
        if get_remote_os(ssh_client, logger) == "windows":
            _set_windows_file_permissions(ssh_client, remote_file_path, logger)
        else:
            _set_unix_file_permissions(ssh_client, remote_file_path, logger)
    except AirflowException:
        # Already carries a specific message; re-raise as-is.
        raise
    except Exception as e:
        # Chain the original error so its traceback stays attached.
        raise AirflowException(f"Error setting permissions for remote file {remote_file_path}: {e}") from e
+
+
def get_remote_temp_directory(ssh_client: SSHClient, logger: logging.Logger | None = None) -> str:
    """
    Return the temporary directory to use on the remote host.

    Non-Windows hosts always get ``TPTConfig.TEMP_DIR_UNIX``. Windows hosts are
    asked for ``%TEMP%``; if that cannot be resolved, ``TPTConfig.TEMP_DIR_WINDOWS``
    is used. Any unexpected error falls back to the Unix default.

    :param ssh_client: SSH client connection
    :param logger: Optional logger instance
    :return: Path to the remote temporary directory
    """
    log = logger or logging.getLogger(__name__)

    try:
        if get_remote_os(ssh_client, log) != "windows":
            return TPTConfig.TEMP_DIR_UNIX

        status, temp_dir, _ = execute_remote_command(ssh_client, "echo %TEMP%")

        # An unexpanded "%TEMP%" means the variable was not resolved remotely.
        resolved = status == 0 and temp_dir and temp_dir != "%TEMP%"
        if resolved:
            return temp_dir

        log.warning("Could not get TEMP directory, using default: %s", TPTConfig.TEMP_DIR_WINDOWS)
        return TPTConfig.TEMP_DIR_WINDOWS
    except Exception as e:
        log.warning("Error getting remote temp directory: %s", str(e))
        return TPTConfig.TEMP_DIR_UNIX
+
+
def verify_tpt_utility_installed(utility: str) -> None:
    """
    Ensure a TPT utility (e.g., tbuild) can be found on the local PATH.

    :param utility: Name of the utility to look up
    :raises AirflowException: If the utility cannot be located
    """
    resolved_path = shutil.which(utility)
    if resolved_path is not None:
        return

    raise AirflowException(
        f"TPT utility '{utility}' is not installed or not available in the system's PATH"
    )
+
+
def verify_tpt_utility_on_remote_host(
    ssh_client: SSHClient, utility: str, logger: logging.Logger | None = None
) -> None:
    """
    Verify if a TPT utility (tbuild) is installed on the remote host via SSH.

    Uses ``where`` on Windows hosts and ``which`` on all others.

    :param ssh_client: SSH client connection
    :param utility: Name of the utility to verify
    :param logger: Optional logger instance
    :raises AirflowException: If utility is not found or verification fails
    """
    logger = logger or logging.getLogger(__name__)

    try:
        if get_remote_os(ssh_client, logger) == "windows":
            command = f"where {utility}"
        else:
            # Quote the name so it is always passed to ``which`` as one literal argument.
            command = f"which {shlex.quote(utility)}"

        exit_status, output, error = execute_remote_command(ssh_client, command)

        if exit_status != 0 or not output:
            raise AirflowException(
                f"TPT utility '{utility}' is not installed or not available in PATH on the remote host. "
                f"Command: {command}, Exit status: {exit_status}, "
                f"stderr: {error if error else 'N/A'}"
            )

        # The lookup may print several matches; report the first one.
        logger.info("TPT utility '%s' found at: %s", utility, output.split("\n")[0])

    except AirflowException:
        # Already carries a specific message; re-raise as-is.
        raise
    except Exception as e:
        # Chain the original error so its traceback stays attached.
        raise AirflowException(f"Failed to verify TPT utility '{utility}' on remote host: {e}") from e
+
+
def _escape_tpt_literal(value: Any) -> str:
    """Return ``value`` as text with single quotes doubled, for use inside a quoted TPT literal."""
    return str(value).replace("'", "''")


def prepare_tpt_ddl_script(
    sql: list[str],
    error_list: list[int] | None,
    source_conn: dict[str, Any],
    job_name: str | None = None,
) -> str:
    """
    Prepare a TPT script for executing DDL statements.

    This method generates a TPT script that defines a DDL operator and applies the provided SQL statements.
    It also supports specifying a list of error codes to handle during the operation.

    Single quotes in the SQL statements and in the connection values (host, login,
    password) are doubled, so a quote inside a value cannot terminate the surrounding
    script literal.

    :param sql: A list of DDL statements to execute.
    :param error_list: A list of error codes to handle during the operation.
    :param source_conn: Connection details for the source database; must contain
        the keys ``host``, ``login`` and ``password``.
    :param job_name: The name of the TPT job. Defaults to unique name if None.
    :return: A formatted TPT script as a string.
    :raises ValueError: If the SQL statement list is empty or holds no usable statement.
    """
    if not sql or not isinstance(sql, list):
        raise ValueError("SQL statement list must be a non-empty list")

    # Drop blank/non-string entries, strip trailing semicolons and escape quotes.
    sql_statements = [
        _escape_tpt_literal(stmt.strip().rstrip(";"))
        for stmt in sql
        if stmt and isinstance(stmt, str) and stmt.strip()
    ]

    if not sql_statements:
        raise ValueError("No valid SQL statements found in the provided input")

    # Format for TPT APPLY block, indenting after the first line
    apply_sql = ",\n".join(
        f"('{stmt};')" if i == 0 else f"            ('{stmt};')" for i, stmt in enumerate(sql_statements)
    )

    if job_name is None:
        job_name = f"airflow_tptddl_{uuid.uuid4().hex}"

    # Format error list for inclusion in the TPT script
    if not error_list:
        error_list_stmt = "ErrorList = ['']"
    else:
        error_list_str = ", ".join(f"'{error}'" for error in error_list)
        error_list_stmt = f"ErrorList = [{error_list_str}]"

    # Connection values end up inside single-quoted literals, so escape them the
    # same way as the SQL statements above.
    host = _escape_tpt_literal(source_conn["host"])
    login = _escape_tpt_literal(source_conn["login"])
    password = _escape_tpt_literal(source_conn["password"])

    tpt_script = f"""
        DEFINE JOB {job_name}
        DESCRIPTION 'TPT DDL Operation'
        (
        APPLY
            {apply_sql}
        TO OPERATOR ( $DDL ()
            ATTR
            (
                TdpId = '{host}',
                UserName = '{login}',
                UserPassword = '{password}',
                {error_list_stmt}
            )
        );
        );
        """
    return tpt_script
+
+
def decrypt_remote_file(
    ssh_client: SSHClient,
    remote_enc_file: str,
    remote_dec_file: str,
    password: str,
    logger: logging.Logger | None = None,
) -> int:
    """
    Decrypt a remote file using OpenSSL.

    :param ssh_client: SSH client connection
    :param remote_enc_file: Path to the encrypted file
    :param remote_dec_file: Path for the decrypted file
    :param password: Decryption password
    :param logger: Optional logger instance
    :return: Exit status of the decryption command
    :raises AirflowException: If decryption fails
    """
    logger = logger or logging.getLogger(__name__)

    windows_remote = get_remote_os(ssh_client, logger) == "windows"

    # NOTE(review): the password is passed on the command line in both branches and
    # may therefore be visible to other users of the remote host while openssl runs.
    # Consider a non-argv password source if the execution helper allows it.
    if windows_remote:
        # Windows - double any embedded double quotes inside the quoted arguments
        password_escaped = password.replace('"', '""')
        decrypt_cmd = (
            f'openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:"{password_escaped}" '
            f'-in "{remote_enc_file}" -out "{remote_dec_file}"'
        )
    else:
        # Unix/Linux - quote the password and both paths for the POSIX shell so that
        # spaces or metacharacters in any of them cannot alter the command.
        decrypt_cmd = (
            f"openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:{shlex.quote(password)} "
            f"-in {shlex.quote(remote_enc_file)} -out {shlex.quote(remote_dec_file)}"
        )

    exit_status, _, stderr_data = execute_remote_command(ssh_client, decrypt_cmd)

    if exit_status != 0:
        raise AirflowException(
            f"Decryption failed with exit status {exit_status}. Error: {stderr_data if stderr_data else 'N/A'}"
        )

    logger.info("Successfully decrypted remote file %s to %s", remote_enc_file, remote_dec_file)
    return exit_status
+
+
+def transfer_file_sftp(
+    ssh_client: SSHClient, local_path: str, remote_path: str, logger: 
logging.Logger | None = None
+) -> None:
+    """
+    Transfer a file from local to remote host using SFTP.
+
+    :param ssh_client: SSH client connection
+    :param local_path: Local file path
+    :param remote_path: Remote file path
+    :param logger: Optional logger instance
+    :raises AirflowException: If file transfer fails
+    """
+    logger = logger or logging.getLogger(__name__)
+
+    if not os.path.exists(local_path):
+        raise AirflowException(f"Local file does not exist: {local_path}")
+
+    sftp = None
+    try:
+        sftp = ssh_client.open_sftp()
+        sftp.put(local_path, remote_path)
+        logger.info("Successfully transferred file from %s to %s", local_path, 
remote_path)
+    except Exception as e:
+        raise AirflowException(f"Failed to transfer file from {local_path} to 
{remote_path}: {str(e)}")

Review Comment:
   Sure @potiuk, I will address this in all places and raise a new PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to