sc250072 commented on code in PR #56675:
URL: https://github.com/apache/airflow/pull/56675#discussion_r2497445439


##########
providers/teradata/src/airflow/providers/teradata/utils/tpt_util.py:
##########
@@ -0,0 +1,521 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+import os
+import shutil
+import subprocess
+import uuid
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from paramiko import SSHClient
+
+from airflow.exceptions import AirflowException
+
+
+class TPTConfig:
+    """Configuration constants for TPT operations."""
+
+    # Timeout, in seconds, applied to local subprocess calls and waits made by this module.
+    DEFAULT_TIMEOUT = 5
+    # Octal mode 0o400: read permission for the file owner only.
+    FILE_PERMISSIONS_READ_ONLY = 0o400
+    # Per-platform temp directory paths.
+    # NOTE(review): not referenced in the visible part of this module; confirm usage in callers.
+    TEMP_DIR_WINDOWS = "C:\\Windows\\Temp"
+    TEMP_DIR_UNIX = "/tmp"
+
+
+def execute_remote_command(ssh_client: SSHClient, command: str) -> tuple[int, 
str, str]:
+    """
+    Execute a command on remote host and properly manage SSH channels.
+
+    :param ssh_client: SSH client connection
+    :param command: Command to execute
+    :return: Tuple of (exit_status, stdout, stderr)
+    """
+    stdin, stdout, stderr = ssh_client.exec_command(command)
+    try:
+        exit_status = stdout.channel.recv_exit_status()
+        stdout_data = stdout.read().decode().strip()
+        stderr_data = stderr.read().decode().strip()
+        return exit_status, stdout_data, stderr_data
+    finally:
+        stdin.close()
+        stdout.close()
+        stderr.close()
+
+
+def write_file(path: str, content: str) -> None:
+    with open(path, "w", encoding="utf-8") as f:
+        f.write(content)
+
+
+def secure_delete(file_path: str, logger: logging.Logger | None = None) -> 
None:
+    """
+    Securely delete a file using shred if available, otherwise use os.remove.
+
+    :param file_path: Path to the file to be deleted
+    :param logger: Optional logger instance
+    """
+    logger = logger or logging.getLogger(__name__)
+    if not os.path.exists(file_path):
+        return
+
+    try:
+        # Check if shred is available
+        if shutil.which("shred") is not None:
+            # Use shred to securely delete the file
+            subprocess.run(["shred", "--remove", file_path], check=True, 
timeout=TPTConfig.DEFAULT_TIMEOUT)
+            logger.info("Securely removed file using shred: %s", file_path)
+        else:
+            # Fall back to regular deletion
+            os.remove(file_path)
+            logger.info("Removed file: %s", file_path)
+
+    except (OSError, subprocess.CalledProcessError, subprocess.TimeoutExpired) 
as e:
+        logger.warning("Failed to remove file %s: %s", file_path, str(e))
+
+
+def remote_secure_delete(
+    ssh_client: SSHClient, remote_files: list[str], logger: logging.Logger | 
None = None
+) -> None:
+    """
+    Securely delete remote files via SSH. Attempts shred first, falls back to 
rm if shred is unavailable.
+
+    :param ssh_client: SSH client connection
+    :param remote_files: List of remote file paths to delete
+    :param logger: Optional logger instance
+    """
+    logger = logger or logging.getLogger(__name__)
+    if not ssh_client or not remote_files:
+        return
+
+    try:
+        # Detect remote OS
+        remote_os = get_remote_os(ssh_client, logger)
+        windows_remote = remote_os == "windows"
+
+        # Check if shred is available on remote system (UNIX/Linux)
+        shred_available = False
+        if not windows_remote:
+            exit_status, output, _ = execute_remote_command(ssh_client, 
"command -v shred")
+            shred_available = exit_status == 0 and output.strip() != ""
+
+        for file_path in remote_files:
+            try:
+                if windows_remote:
+                    # Windows remote host - use del command
+                    replace_slash = file_path.replace("/", "\\")
+                    execute_remote_command(
+                        ssh_client, f'if exist "{replace_slash}" del /f /q 
"{replace_slash}"'
+                    )
+                elif shred_available:
+                    # UNIX/Linux with shred
+                    execute_remote_command(ssh_client, f"shred --remove 
{file_path}")
+                else:
+                    # UNIX/Linux without shred - overwrite then delete
+                    execute_remote_command(
+                        ssh_client,
+                        f"if [ -f {file_path} ]; then "
+                        f"dd if=/dev/zero of={file_path} bs=4096 
count=$(($(stat -c '%s' {file_path})/4096+1)) 2>/dev/null; "
+                        f"rm -f {file_path}; fi",
+                    )
+            except Exception as e:
+                logger.warning("Failed to process remote file %s: %s", 
file_path, str(e))
+
+        logger.info("Processed remote files: %s", ", ".join(remote_files))
+    except Exception as e:
+        logger.warning("Failed to remove remote files: %s", str(e))
+
+
+def terminate_subprocess(sp: subprocess.Popen | None, logger: logging.Logger | 
None = None) -> None:
+    """
+    Terminate a subprocess gracefully with proper error handling.
+
+    :param sp: Subprocess to terminate
+    :param logger: Optional logger instance
+    """
+    logger = logger or logging.getLogger(__name__)
+
+    if not sp or sp.poll() is not None:
+        # Process is None or already terminated
+        return
+
+    logger.info("Terminating subprocess (PID: %s)", sp.pid)
+
+    try:
+        sp.terminate()  # Attempt to terminate gracefully
+        sp.wait(timeout=TPTConfig.DEFAULT_TIMEOUT)
+        logger.info("Subprocess terminated gracefully")
+    except subprocess.TimeoutExpired:
+        logger.warning(
+            "Subprocess did not terminate gracefully within %d seconds, 
killing it", TPTConfig.DEFAULT_TIMEOUT
+        )
+        try:
+            sp.kill()
+            sp.wait(timeout=2)  # Brief wait after kill
+            logger.info("Subprocess killed successfully")
+        except Exception as e:
+            logger.error("Error killing subprocess: %s", str(e))
+    except Exception as e:
+        logger.error("Error terminating subprocess: %s", str(e))
+
+
+def get_remote_os(ssh_client: SSHClient, logger: logging.Logger | None = None) 
-> str:
+    """
+    Detect the operating system of the remote host via SSH.
+
+    :param ssh_client: SSH client connection
+    :param logger: Optional logger instance
+    :return: Operating system type as string ('windows' or 'unix')
+    """
+    logger = logger or logging.getLogger(__name__)
+
+    if not ssh_client:
+        logger.warning("No SSH client provided for OS detection")
+        return "unix"
+
+    try:
+        # Check for Windows first
+        exit_status, stdout_data, stderr_data = 
execute_remote_command(ssh_client, "echo %OS%")
+
+        if "Windows" in stdout_data:
+            return "windows"
+
+        # All other systems are treated as Unix-like
+        return "unix"
+
+    except Exception as e:
+        logger.error("Error detecting remote OS: %s", str(e))
+        return "unix"
+
+
+def set_local_file_permissions(local_file_path: str, logger: logging.Logger | 
None = None) -> None:
+    """
+    Set permissions for a local file to be read-only for the owner.
+
+    :param local_file_path: Path to the local file
+    :param logger: Optional logger instance
+    :raises AirflowException: If permission setting fails
+    """
+    logger = logger or logging.getLogger(__name__)
+
+    if not local_file_path:
+        logger.warning("No file path provided for permission setting")
+        return
+
+    if not os.path.exists(local_file_path):
+        raise AirflowException(f"File does not exist: {local_file_path}")
+
+    try:
+        # Set file permission to read-only for the owner (400)
+        os.chmod(local_file_path, TPTConfig.FILE_PERMISSIONS_READ_ONLY)
+        logger.info("Set read-only permissions for file %s", local_file_path)
+    except (OSError, PermissionError) as e:
+        raise AirflowException(f"Error setting permissions for local file 
{local_file_path}: {str(e)}")
+
+
+def _set_windows_file_permissions(
+    ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger
+) -> None:
+    """Set restrictive permissions on Windows remote file."""
+    command = f'icacls "{remote_file_path}" /inheritance:r /grant:r 
"%USERNAME%":R'
+
+    exit_status, stdout_data, stderr_data = execute_remote_command(ssh_client, 
command)
+
+    if exit_status != 0:
+        raise AirflowException(

Review Comment:
   Updated the code to use specific exceptions instead of AirflowException, as 
per the discussion. Could you please review the changes and let us know if any 
further updates are needed?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to