potiuk commented on a change in pull request #19806:
URL: https://github.com/apache/airflow/pull/19806#discussion_r796608439



##########
File path: setup.py
##########
@@ -450,7 +450,7 @@ def write_version(filename: str = os.path.join(*[my_dir, 
"airflow", "git_version
     pandas_requirement,
 ]
 psrp = [
-    'pypsrp~=0.5',
+    'pypsrp~=0.8',

Review comment:
       Is 0.8 really necessary here? We usually limit the versions if we know 
earlier versions will not work - just to avoid too "strict" limits. Is there 
something that prevents us from using ~0.7 here? Also how about `> 0.8`? We have 
had a discussion about not limiting the upper bound because we do not really know 
the future and constraints are protecting our users, so limiting the upper bound 
(in this case <1.0) might not be justified at all.

##########
File path: airflow/providers/microsoft/psrp/hooks/psrp.py
##########
@@ -16,103 +16,222 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from time import sleep
+import re
+from contextlib import contextmanager
+from logging import DEBUG, ERROR, INFO, WARNING
+from typing import Any, Dict, Optional
+from weakref import WeakKeyDictionary
 
-from pypsrp.messages import ErrorRecord, InformationRecord, ProgressRecord
+from pypsrp.messages import MessageType
 from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool
 from pypsrp.wsman import WSMan
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+INFORMATIONAL_RECORD_LEVEL_MAP = {
+    MessageType.DEBUG_RECORD: DEBUG,
+    MessageType.ERROR_RECORD: ERROR,
+    MessageType.VERBOSE_RECORD: INFO,
+    MessageType.WARNING_RECORD: WARNING,
+}
+
 
 class PSRPHook(BaseHook):
     """
     Hook for PowerShell Remoting Protocol execution.
 
-    The hook must be used as a context manager.
+    When used as a context manager, the runspace pool is reused between shell
+    sessions.
+
+    :param psrp_conn_id: Required. The name of the PSRP connection.
+    :type psrp_conn_id: str
+    :param logging: If true (default), log command output and streams during 
execution.
+    :type logging: bool
+    :param operation_timeout: Override the default WSMan timeout when polling 
the pipeline.
+    :type operation_timeout: float
+    :param runspace_options:
+        Optional dictionary which is passed when creating the runspace pool. 
See
+        :py:class:`~pypsrp.powershell.RunspacePool` for a description of the
+        available options.
+    :type runspace_options: dict
+    :param wsman_options:
+        Optional dictionary which is passed when creating the `WSMan` client. 
See
+        :py:class:`~pypsrp.wsman.WSMan` for a description of the available 
options.
+    :type wsman_options: dict
+    :param exchange_keys:
+        If true (default), automatically initiate a session key exchange when 
the
+        hook is used as a context manager.
+    :type exchange_keys: bool
+
+    You can provide an alternative `configuration_name` using either 
`runspace_options`
+    or by setting this key as the extra fields of your connection.
     """
 
-    _client = None
-    _poll_interval = 1
-
-    def __init__(self, psrp_conn_id: str):
+    _conn = None
+    _configuration_name = None
+    _wsman_ref: "WeakKeyDictionary[RunspacePool, WSMan]" = WeakKeyDictionary()
+
+    def __init__(
+        self,
+        psrp_conn_id: str,
+        logging: bool = True,
+        operation_timeout: Optional[float] = None,
+        runspace_options: Optional[Dict[str, Any]] = None,
+        wsman_options: Optional[Dict[str, Any]] = None,
+        exchange_keys: bool = True,
+    ):
         self.conn_id = psrp_conn_id
+        self._logging = logging
+        self._operation_timeout = operation_timeout
+        self._runspace_options = runspace_options or {}
+        self._wsman_options = wsman_options or {}
+        self._exchange_keys = exchange_keys
 
     def __enter__(self):
-        conn = self.get_connection(self.conn_id)
-
-        self.log.info("Establishing WinRM connection %s to host: %s", 
self.conn_id, conn.host)
-        self._client = WSMan(
-            conn.host,
-            ssl=True,
-            auth="ntlm",
-            encryption="never",
-            username=conn.login,
-            password=conn.password,
-            cert_validation=False,
-        )
-        self._client.__enter__()
+        conn = self.get_conn()
+        self._wsman_ref[conn].__enter__()
+        conn.__enter__()
+        if self._exchange_keys:
+            conn.exchange_keys()
+        self._conn = conn
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         try:
-            self._client.__exit__(exc_type, exc_value, traceback)
+            self._conn.__exit__(exc_type, exc_value, traceback)
+            self._wsman_ref[self._conn].__exit__(exc_type, exc_value, 
traceback)
         finally:
-            self._client = None
+            del self._conn
 
-    def invoke_powershell(self, script: str) -> PowerShell:
-        with RunspacePool(self._client) as pool:
-            ps = PowerShell(pool)
-            ps.add_script(script)
+    def get_conn(self) -> RunspacePool:
+        """
+        Returns a runspace pool.
+
+        The returned object must be used as a context manager.
+        """
+        conn = self.get_connection(self.conn_id)
+        self.log.info("Establishing WinRM connection %s to host: %s", 
self.conn_id, conn.host)
+
+        extra = conn.extra_dejson.copy()
+
+        def apply_extra(d, keys):
+            d = d.copy()
+            for key in keys:
+                value = extra.pop(key, None)
+                if value is not None:
+                    d[key] = value
+            return d
+
+        wsman_options = apply_extra(
+            self._wsman_options,
+            (
+                "auth",
+                "cert_validation",
+                "connection_timeout",
+                "locale",
+                "read_timeout",
+                "reconnection_retries",
+                "reconnection_backoff",
+                "ssl",
+            ),
+        )
+        wsman = WSMan(conn.host, username=conn.login, password=conn.password, 
**wsman_options)
+        runspace_options = apply_extra(self._runspace_options, 
("configuration_name",))
+
+        if extra:
+            raise AirflowException(f"Unexpected extra configuration keys: {', 
'.join(sorted(extra))}")
+        pool = RunspacePool(wsman, **runspace_options)
+        self._wsman_ref[pool] = wsman
+        return pool
+
+    @contextmanager
+    def invoke(self) -> PowerShell:
+        """
+        Context manager that yields a PowerShell object to which commands can 
be
+        added. Upon exit, the commands will be invoked.
+        """
+        local_context = self._conn is None
+        if local_context:
+            self.__enter__()
+        try:
+            ps = PowerShell(self._conn)
+            yield ps
             ps.begin_invoke()
-            streams = [
-                (ps.output, self._log_output),
-                (ps.streams.debug, self._log_record),
-                (ps.streams.information, self._log_record),
-                (ps.streams.error, self._log_record),
-            ]
-            offsets = [0 for _ in streams]
-
-            # We're using polling to make sure output and streams are
-            # handled while the process is running.
-            while ps.state == PSInvocationState.RUNNING:
-                sleep(self._poll_interval)
-                ps.poll_invoke()
-
-                for (i, (stream, handler)) in enumerate(streams):
-                    offset = offsets[i]
-                    while len(stream) > offset:
-                        handler(stream[offset])
-                        offset += 1
-                    offsets[i] = offset
+            if self._logging:
+                streams = [
+                    (ps.streams.debug, self._log_record),
+                    (ps.streams.error, self._log_record),
+                    (ps.streams.information, self._log_record),
+                    (ps.streams.progress, self._log_record),
+                    (ps.streams.verbose, self._log_record),
+                    (ps.streams.warning, self._log_record),
+                ]
+                offsets = [0 for _ in streams]
+
+                # We're using polling to make sure output and streams are
+                # handled while the process is running.
+                while ps.state == PSInvocationState.RUNNING:
+                    ps.poll_invoke(timeout=self._operation_timeout)
+
+                    for (i, (stream, handler)) in enumerate(streams):
+                        offset = offsets[i]
+                        while len(stream) > offset:
+                            handler(stream[offset])
+                            offset += 1
+                        offsets[i] = offset
 
             # For good measure, we'll make sure the process has
-            # stopped running.
+            # stopped running in any case.
             ps.end_invoke()
 
+            self.log.info("Invocation state: %s", 
str(PSInvocationState(ps.state)))
             if ps.streams.error:
                 raise AirflowException("Process had one or more errors")
+        finally:
+            if local_context:
+                self.__exit__(None, None, None)
 
-            self.log.info("Invocation state: %s", 
str(PSInvocationState(ps.state)))
-            return ps
+    def invoke_cmdlet(self, name: str, use_local_scope=None, **parameters: 
Dict[str, str]) -> PowerShell:
+        """Invoke a PowerShell cmdlet and return session."""
+        with self.invoke() as ps:
+            ps.add_cmdlet(name, use_local_scope=use_local_scope)
+            ps.add_parameters(parameters)
+        return ps
 
-    def _log_output(self, message: str):
-        self.log.info("%s", message)
+    def invoke_powershell(self, script: str) -> PowerShell:
+        """Invoke a PowerShell script and return session."""
+        with self.invoke() as ps:
+            ps.add_script(script)
+        return ps
 
     def _log_record(self, record):
-        # TODO: Consider translating some or all of these records into
-        # normal logging levels, using `log(level, msg, *args)`.
-        if isinstance(record, ErrorRecord):
-            self.log.info("Error: %s", record)
-            return
-
-        if isinstance(record, InformationRecord):
-            self.log.info("Information: %s", record.message_data)
-            return
-
-        if isinstance(record, ProgressRecord):
+        message_type = getattr(record, "MESSAGE_TYPE", None)
+
+        # There seems to be a problem with some record types; we'll assume
+        # that the class name matches a message type.
+        if message_type is None:
+            message_type = getattr(
+                MessageType, re.sub('(?!^)([A-Z]+)', r'_\1', 
type(record).__name__).upper()
+            )
+
+        if message_type == MessageType.ERROR_RECORD:
+            self.log.info("%s: %s", record.reason, record)
+            if record.script_stacktrace:
+                for trace in record.script_stacktrace.split('\r\n'):

Review comment:
       I think using `splitlines()` is indeed better; at least it raises an 
eyebrow for me - why not `splitlines()`?
   
   The PSRP <> Windows association is not at all obvious to someone who just 
reviews the code. 
   I think we should either explain in a comment why we are doing it, or 
(better IMHO) use `splitlines()` and forget about it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to