uranusjr commented on a change in pull request #19806:
URL: https://github.com/apache/airflow/pull/19806#discussion_r773539930



##########
File path: airflow/providers/microsoft/psrp/hooks/psrp.py
##########
@@ -16,103 +16,222 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from time import sleep
+import re
+from contextlib import contextmanager
+from logging import DEBUG, ERROR, INFO, WARNING
+from typing import Any, Dict, Optional
+from weakref import WeakKeyDictionary
 
-from pypsrp.messages import ErrorRecord, InformationRecord, ProgressRecord
+from pypsrp.messages import MessageType
 from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool
 from pypsrp.wsman import WSMan
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+INFORMATIONAL_RECORD_LEVEL_MAP = {
+    MessageType.DEBUG_RECORD: DEBUG,
+    MessageType.ERROR_RECORD: ERROR,
+    MessageType.VERBOSE_RECORD: INFO,
+    MessageType.WARNING_RECORD: WARNING,
+}
+
 
 class PSRPHook(BaseHook):
     """
     Hook for PowerShell Remoting Protocol execution.
 
-    The hook must be used as a context manager.
+    When used as a context manager, the runspace pool is reused between shell
+    sessions.

Review comment:
       How important is it to you to support the non-context manager use case? 
I feel it’s kind of unnecessarily complicating the implementation.

##########
File path: airflow/providers/microsoft/psrp/operators/psrp.py
##########
@@ -16,51 +16,129 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
+
+from jinja2.nativetypes import NativeEnvironment
+from pypsrp.serializer import TaggedValue
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.psrp.hooks.psrp import PSRPHook
+from airflow.settings import json
 
 
 class PSRPOperator(BaseOperator):
     """PowerShell Remoting Protocol operator.
 
-    :param psrp_conn_id: connection id
+    Use one of the 'command', 'cmdlet', or 'powershell' arguments.
+
+    The 'securestring' template filter can be used to tag a value for
+    serialization into a `System.Security.SecureString` (applicable only
+    for DAGs which have `render_template_as_native_obj=True`).
+
+    The command output is converted to JSON by PowerShell such that the 
operator
+    return value is serializable to an XCom value.
+
+    :param psrp_conn_id: Connection id
     :type psrp_conn_id: str
-    :param command: command to execute on remote host. (templated)
+    :param command: Command to execute on remote host (templated).
     :type command: str
-    :param powershell: powershell to execute on remote host. (templated)
+    :param powershell: Powershell to execute on remote host (templated)
     :type powershell: str
+    :param cmdlet:
+        Cmdlet to execute on remote host (templated). Also used as the default
+        value for `task_id`.
+    :type cmdlet: str
+    :param parameters:
+        Parameters to provide to cmdlet (templated). This is allowed only if
+        the `cmdlet` parameter is also given.
+    :type parameters: dict
+    :param logging: If true (default), log command output and streams during 
execution.
+    :type logging: bool
+    :param runspace_options:
+        Optional dictionary which is passed when creating the runspace pool. 
See
+        :py:class:`~pypsrp.powershell.RunspacePool` for a description of the
+        available options.
+    :type runspace_options: dict
+    :param wsman_options:
+        Optional dictionary which is passed when creating the `WSMan` client. 
See
+        :py:class:`~pypsrp.wsman.WSMan` for a description of the available 
options.
+    :type wsman_options: dict
     """
 
     template_fields = (
+        "cmdlet",
         "command",
+        "parameters",
         "powershell",
     )
     template_fields_renderers = {"command": "powershell", "powershell": 
"powershell"}
-    ui_color = "#901dd2"
+    ui_color = "#c2e2ff"
 
     def __init__(
         self,
         *,
         psrp_conn_id: str,
         command: Optional[str] = None,
         powershell: Optional[str] = None,
+        cmdlet: Optional[str] = None,
+        parameters: Optional[Dict[str, str]] = None,
+        logging: bool = True,
+        runspace_options: Optional[Dict[str, Any]] = None,
+        wsman_options: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
+        args = {command, powershell, cmdlet}
+        if not len(list(filter(None, args))) == 1:
+            raise ValueError("Must provide either 'command', 'powershell', or 
'cmdlet'")

Review comment:
       Can this not use the same logic as `exactly_one`?
   
   In either case you want to change `not ... ==` to `!=`. Also the message 
should say “exactly one” instead of “either” (which implies it’s allowed to 
pass multiple of them).

##########
File path: airflow/providers/microsoft/psrp/hooks/psrp.py
##########
@@ -16,103 +16,222 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from time import sleep
+import re
+from contextlib import contextmanager
+from logging import DEBUG, ERROR, INFO, WARNING
+from typing import Any, Dict, Optional
+from weakref import WeakKeyDictionary
 
-from pypsrp.messages import ErrorRecord, InformationRecord, ProgressRecord
+from pypsrp.messages import MessageType
 from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool
 from pypsrp.wsman import WSMan
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+INFORMATIONAL_RECORD_LEVEL_MAP = {
+    MessageType.DEBUG_RECORD: DEBUG,
+    MessageType.ERROR_RECORD: ERROR,
+    MessageType.VERBOSE_RECORD: INFO,
+    MessageType.WARNING_RECORD: WARNING,
+}
+
 
 class PSRPHook(BaseHook):
     """
     Hook for PowerShell Remoting Protocol execution.
 
-    The hook must be used as a context manager.
+    When used as a context manager, the runspace pool is reused between shell
+    sessions.
+
+    :param psrp_conn_id: Required. The name of the PSRP connection.
+    :type psrp_conn_id: str
+    :param logging: If true (default), log command output and streams during 
execution.
+    :type logging: bool
+    :param operation_timeout: Override the default WSMan timeout when polling 
the pipeline.
+    :type operation_timeout: float
+    :param runspace_options:
+        Optional dictionary which is passed when creating the runspace pool. 
See
+        :py:class:`~pypsrp.powershell.RunspacePool` for a description of the
+        available options.
+    :type runspace_options: dict
+    :param wsman_options:
+        Optional dictionary which is passed when creating the `WSMan` client. 
See
+        :py:class:`~pypsrp.wsman.WSMan` for a description of the available 
options.
+    :type wsman_options: dict
+    :param exchange_keys:
+        If true (default), automatically initiate a session key exchange when 
the
+        hook is used as a context manager.
+    :type exchange_keys: bool
+
+    You can provide an alternative `configuration_name` using either 
`runspace_options`
+    or by setting this key as the extra fields of your connection.
     """
 
-    _client = None
-    _poll_interval = 1
-
-    def __init__(self, psrp_conn_id: str):
+    _conn = None
+    _configuration_name = None
+    _wsman_ref: "WeakKeyDictionary[RunspacePool, WSMan]" = WeakKeyDictionary()
+
+    def __init__(
+        self,
+        psrp_conn_id: str,
+        logging: bool = True,
+        operation_timeout: Optional[float] = None,
+        runspace_options: Optional[Dict[str, Any]] = None,
+        wsman_options: Optional[Dict[str, Any]] = None,
+        exchange_keys: bool = True,
+    ):
         self.conn_id = psrp_conn_id
+        self._logging = logging
+        self._operation_timeout = operation_timeout
+        self._runspace_options = runspace_options or {}
+        self._wsman_options = wsman_options or {}
+        self._exchange_keys = exchange_keys
 
     def __enter__(self):
-        conn = self.get_connection(self.conn_id)
-
-        self.log.info("Establishing WinRM connection %s to host: %s", 
self.conn_id, conn.host)
-        self._client = WSMan(
-            conn.host,
-            ssl=True,
-            auth="ntlm",
-            encryption="never",
-            username=conn.login,
-            password=conn.password,
-            cert_validation=False,
-        )
-        self._client.__enter__()
+        conn = self.get_conn()
+        self._wsman_ref[conn].__enter__()
+        conn.__enter__()
+        if self._exchange_keys:
+            conn.exchange_keys()
+        self._conn = conn
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         try:
-            self._client.__exit__(exc_type, exc_value, traceback)
+            self._conn.__exit__(exc_type, exc_value, traceback)
+            self._wsman_ref[self._conn].__exit__(exc_type, exc_value, 
traceback)
         finally:
-            self._client = None
+            del self._conn
 
-    def invoke_powershell(self, script: str) -> PowerShell:
-        with RunspacePool(self._client) as pool:
-            ps = PowerShell(pool)
-            ps.add_script(script)
+    def get_conn(self) -> RunspacePool:
+        """
+        Returns a runspace pool.
+
+        The returned object must be used as a context manager.
+        """
+        conn = self.get_connection(self.conn_id)
+        self.log.info("Establishing WinRM connection %s to host: %s", 
self.conn_id, conn.host)
+
+        extra = conn.extra_dejson.copy()
+
+        def apply_extra(d, keys):
+            d = d.copy()
+            for key in keys:
+                value = extra.pop(key, None)
+                if value is not None:
+                    d[key] = value
+            return d
+
+        wsman_options = apply_extra(
+            self._wsman_options,
+            (
+                "auth",
+                "cert_validation",
+                "connection_timeout",
+                "locale",
+                "read_timeout",
+                "reconnection_retries",
+                "reconnection_backoff",
+                "ssl",
+            ),
+        )
+        wsman = WSMan(conn.host, username=conn.login, password=conn.password, 
**wsman_options)
+        runspace_options = apply_extra(self._runspace_options, 
("configuration_name",))
+
+        if extra:
+            raise AirflowException(f"Unexpected extra configuration keys: {', 
'.join(sorted(extra))}")
+        pool = RunspacePool(wsman, **runspace_options)
+        self._wsman_ref[pool] = wsman
+        return pool
+
+    @contextmanager
+    def invoke(self) -> PowerShell:
+        """
+        Context manager that yields a PowerShell object to which commands can 
be
+        added. Upon exit, the commands will be invoked.
+        """
+        local_context = self._conn is None
+        if local_context:
+            self.__enter__()
+        try:
+            ps = PowerShell(self._conn)
+            yield ps
             ps.begin_invoke()
-            streams = [
-                (ps.output, self._log_output),
-                (ps.streams.debug, self._log_record),
-                (ps.streams.information, self._log_record),
-                (ps.streams.error, self._log_record),
-            ]
-            offsets = [0 for _ in streams]
-
-            # We're using polling to make sure output and streams are
-            # handled while the process is running.
-            while ps.state == PSInvocationState.RUNNING:
-                sleep(self._poll_interval)
-                ps.poll_invoke()
-
-                for (i, (stream, handler)) in enumerate(streams):
-                    offset = offsets[i]
-                    while len(stream) > offset:
-                        handler(stream[offset])
-                        offset += 1
-                    offsets[i] = offset
+            if self._logging:
+                streams = [
+                    (ps.streams.debug, self._log_record),
+                    (ps.streams.error, self._log_record),
+                    (ps.streams.information, self._log_record),
+                    (ps.streams.progress, self._log_record),
+                    (ps.streams.verbose, self._log_record),
+                    (ps.streams.warning, self._log_record),
+                ]
+                offsets = [0 for _ in streams]
+
+                # We're using polling to make sure output and streams are
+                # handled while the process is running.
+                while ps.state == PSInvocationState.RUNNING:
+                    ps.poll_invoke(timeout=self._operation_timeout)
+
+                    for (i, (stream, handler)) in enumerate(streams):
+                        offset = offsets[i]
+                        while len(stream) > offset:
+                            handler(stream[offset])
+                            offset += 1
+                        offsets[i] = offset
 
             # For good measure, we'll make sure the process has
-            # stopped running.
+            # stopped running in any case.
             ps.end_invoke()
 
+            self.log.info("Invocation state: %s", 
str(PSInvocationState(ps.state)))
             if ps.streams.error:
                 raise AirflowException("Process had one or more errors")
+        finally:
+            if local_context:
+                self.__exit__(None, None, None)
 
-            self.log.info("Invocation state: %s", 
str(PSInvocationState(ps.state)))
-            return ps
+    def invoke_cmdlet(self, name: str, use_local_scope=None, **parameters: 
Dict[str, str]) -> PowerShell:
+        """Invoke a PowerShell cmdlet and return session."""
+        with self.invoke() as ps:
+            ps.add_cmdlet(name, use_local_scope=use_local_scope)
+            ps.add_parameters(parameters)
+        return ps
 
-    def _log_output(self, message: str):
-        self.log.info("%s", message)
+    def invoke_powershell(self, script: str) -> PowerShell:
+        """Invoke a PowerShell script and return session."""
+        with self.invoke() as ps:
+            ps.add_script(script)
+        return ps
 
     def _log_record(self, record):
-        # TODO: Consider translating some or all of these records into
-        # normal logging levels, using `log(level, msg, *args)`.
-        if isinstance(record, ErrorRecord):
-            self.log.info("Error: %s", record)
-            return
-
-        if isinstance(record, InformationRecord):
-            self.log.info("Information: %s", record.message_data)
-            return
-
-        if isinstance(record, ProgressRecord):
+        message_type = getattr(record, "MESSAGE_TYPE", None)
+
+        # There seems to be a problem with some record types; we'll assume
+        # that the class name matches a message type.
+        if message_type is None:
+            message_type = getattr(
+                MessageType, re.sub('(?!^)([A-Z]+)', r'_\1', 
type(record).__name__).upper()
+            )
+
+        if message_type == MessageType.ERROR_RECORD:
+            self.log.info("%s: %s", record.reason, record)
+            if record.script_stacktrace:
+                for trace in record.script_stacktrace.split('\r\n'):

Review comment:
       `str.splitlines()`

##########
File path: airflow/providers/microsoft/psrp/hooks/psrp.py
##########
@@ -16,103 +16,222 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from time import sleep
+import re
+from contextlib import contextmanager
+from logging import DEBUG, ERROR, INFO, WARNING
+from typing import Any, Dict, Optional
+from weakref import WeakKeyDictionary
 
-from pypsrp.messages import ErrorRecord, InformationRecord, ProgressRecord
+from pypsrp.messages import MessageType
 from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool
 from pypsrp.wsman import WSMan
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+INFORMATIONAL_RECORD_LEVEL_MAP = {
+    MessageType.DEBUG_RECORD: DEBUG,
+    MessageType.ERROR_RECORD: ERROR,
+    MessageType.VERBOSE_RECORD: INFO,
+    MessageType.WARNING_RECORD: WARNING,
+}
+
 
 class PSRPHook(BaseHook):
     """
     Hook for PowerShell Remoting Protocol execution.
 
-    The hook must be used as a context manager.
+    When used as a context manager, the runspace pool is reused between shell
+    sessions.
+
+    :param psrp_conn_id: Required. The name of the PSRP connection.
+    :type psrp_conn_id: str
+    :param logging: If true (default), log command output and streams during 
execution.
+    :type logging: bool
+    :param operation_timeout: Override the default WSMan timeout when polling 
the pipeline.
+    :type operation_timeout: float
+    :param runspace_options:
+        Optional dictionary which is passed when creating the runspace pool. 
See
+        :py:class:`~pypsrp.powershell.RunspacePool` for a description of the
+        available options.
+    :type runspace_options: dict
+    :param wsman_options:
+        Optional dictionary which is passed when creating the `WSMan` client. 
See
+        :py:class:`~pypsrp.wsman.WSMan` for a description of the available 
options.
+    :type wsman_options: dict
+    :param exchange_keys:
+        If true (default), automatically initiate a session key exchange when 
the
+        hook is used as a context manager.
+    :type exchange_keys: bool
+
+    You can provide an alternative `configuration_name` using either 
`runspace_options`
+    or by setting this key as the extra fields of your connection.
     """
 
-    _client = None
-    _poll_interval = 1
-
-    def __init__(self, psrp_conn_id: str):
+    _conn = None
+    _configuration_name = None
+    _wsman_ref: "WeakKeyDictionary[RunspacePool, WSMan]" = WeakKeyDictionary()
+
+    def __init__(
+        self,
+        psrp_conn_id: str,
+        logging: bool = True,
+        operation_timeout: Optional[float] = None,
+        runspace_options: Optional[Dict[str, Any]] = None,
+        wsman_options: Optional[Dict[str, Any]] = None,
+        exchange_keys: bool = True,
+    ):
         self.conn_id = psrp_conn_id
+        self._logging = logging
+        self._operation_timeout = operation_timeout
+        self._runspace_options = runspace_options or {}
+        self._wsman_options = wsman_options or {}
+        self._exchange_keys = exchange_keys
 
     def __enter__(self):
-        conn = self.get_connection(self.conn_id)
-
-        self.log.info("Establishing WinRM connection %s to host: %s", 
self.conn_id, conn.host)
-        self._client = WSMan(
-            conn.host,
-            ssl=True,
-            auth="ntlm",
-            encryption="never",
-            username=conn.login,
-            password=conn.password,
-            cert_validation=False,
-        )
-        self._client.__enter__()
+        conn = self.get_conn()
+        self._wsman_ref[conn].__enter__()
+        conn.__enter__()
+        if self._exchange_keys:
+            conn.exchange_keys()
+        self._conn = conn
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         try:
-            self._client.__exit__(exc_type, exc_value, traceback)
+            self._conn.__exit__(exc_type, exc_value, traceback)
+            self._wsman_ref[self._conn].__exit__(exc_type, exc_value, 
traceback)
         finally:
-            self._client = None
+            del self._conn
 
-    def invoke_powershell(self, script: str) -> PowerShell:
-        with RunspacePool(self._client) as pool:
-            ps = PowerShell(pool)
-            ps.add_script(script)
+    def get_conn(self) -> RunspacePool:
+        """
+        Returns a runspace pool.
+
+        The returned object must be used as a context manager.
+        """
+        conn = self.get_connection(self.conn_id)
+        self.log.info("Establishing WinRM connection %s to host: %s", 
self.conn_id, conn.host)
+
+        extra = conn.extra_dejson.copy()
+
+        def apply_extra(d, keys):
+            d = d.copy()
+            for key in keys:
+                value = extra.pop(key, None)
+                if value is not None:
+                    d[key] = value
+            return d
+
+        wsman_options = apply_extra(
+            self._wsman_options,
+            (
+                "auth",
+                "cert_validation",
+                "connection_timeout",
+                "locale",
+                "read_timeout",
+                "reconnection_retries",
+                "reconnection_backoff",
+                "ssl",
+            ),
+        )
+        wsman = WSMan(conn.host, username=conn.login, password=conn.password, 
**wsman_options)
+        runspace_options = apply_extra(self._runspace_options, 
("configuration_name",))
+
+        if extra:
+            raise AirflowException(f"Unexpected extra configuration keys: {', 
'.join(sorted(extra))}")
+        pool = RunspacePool(wsman, **runspace_options)
+        self._wsman_ref[pool] = wsman
+        return pool
+
+    @contextmanager
+    def invoke(self) -> PowerShell:
+        """
+        Context manager that yields a PowerShell object to which commands can 
be
+        added. Upon exit, the commands will be invoked.
+        """
+        local_context = self._conn is None
+        if local_context:
+            self.__enter__()
+        try:
+            ps = PowerShell(self._conn)
+            yield ps
             ps.begin_invoke()
-            streams = [
-                (ps.output, self._log_output),
-                (ps.streams.debug, self._log_record),
-                (ps.streams.information, self._log_record),
-                (ps.streams.error, self._log_record),
-            ]
-            offsets = [0 for _ in streams]
-
-            # We're using polling to make sure output and streams are
-            # handled while the process is running.
-            while ps.state == PSInvocationState.RUNNING:
-                sleep(self._poll_interval)
-                ps.poll_invoke()
-
-                for (i, (stream, handler)) in enumerate(streams):
-                    offset = offsets[i]
-                    while len(stream) > offset:
-                        handler(stream[offset])
-                        offset += 1
-                    offsets[i] = offset
+            if self._logging:
+                streams = [
+                    (ps.streams.debug, self._log_record),
+                    (ps.streams.error, self._log_record),
+                    (ps.streams.information, self._log_record),
+                    (ps.streams.progress, self._log_record),
+                    (ps.streams.verbose, self._log_record),
+                    (ps.streams.warning, self._log_record),
+                ]
+                offsets = [0 for _ in streams]
+
+                # We're using polling to make sure output and streams are
+                # handled while the process is running.
+                while ps.state == PSInvocationState.RUNNING:
+                    ps.poll_invoke(timeout=self._operation_timeout)
+
+                    for (i, (stream, handler)) in enumerate(streams):
+                        offset = offsets[i]
+                        while len(stream) > offset:
+                            handler(stream[offset])
+                            offset += 1
+                        offsets[i] = offset
 
             # For good measure, we'll make sure the process has
-            # stopped running.
+            # stopped running in any case.
             ps.end_invoke()
 
+            self.log.info("Invocation state: %s", 
str(PSInvocationState(ps.state)))
             if ps.streams.error:
                 raise AirflowException("Process had one or more errors")
+        finally:
+            if local_context:
+                self.__exit__(None, None, None)
 
-            self.log.info("Invocation state: %s", 
str(PSInvocationState(ps.state)))
-            return ps
+    def invoke_cmdlet(self, name: str, use_local_scope=None, **parameters: 
Dict[str, str]) -> PowerShell:
+        """Invoke a PowerShell cmdlet and return session."""
+        with self.invoke() as ps:
+            ps.add_cmdlet(name, use_local_scope=use_local_scope)
+            ps.add_parameters(parameters)
+        return ps
 
-    def _log_output(self, message: str):
-        self.log.info("%s", message)
+    def invoke_powershell(self, script: str) -> PowerShell:
+        """Invoke a PowerShell script and return session."""
+        with self.invoke() as ps:
+            ps.add_script(script)
+        return ps
 
     def _log_record(self, record):
-        # TODO: Consider translating some or all of these records into
-        # normal logging levels, using `log(level, msg, *args)`.
-        if isinstance(record, ErrorRecord):
-            self.log.info("Error: %s", record)
-            return
-
-        if isinstance(record, InformationRecord):
-            self.log.info("Information: %s", record.message_data)
-            return
-
-        if isinstance(record, ProgressRecord):
+        message_type = getattr(record, "MESSAGE_TYPE", None)
+
+        # There seems to be a problem with some record types; we'll assume
+        # that the class name matches a message type.
+        if message_type is None:
+            message_type = getattr(
+                MessageType, re.sub('(?!^)([A-Z]+)', r'_\1', 
type(record).__name__).upper()
+            )

Review comment:
       Should we provide some sort of fallback when even this fails? I’m not 
that confident about this class name — message type relation.

##########
File path: airflow/providers/microsoft/psrp/hooks/psrp.py
##########
@@ -16,103 +16,222 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from time import sleep
+import re
+from contextlib import contextmanager
+from logging import DEBUG, ERROR, INFO, WARNING
+from typing import Any, Dict, Optional
+from weakref import WeakKeyDictionary
 
-from pypsrp.messages import ErrorRecord, InformationRecord, ProgressRecord
+from pypsrp.messages import MessageType
 from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool
 from pypsrp.wsman import WSMan
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+INFORMATIONAL_RECORD_LEVEL_MAP = {
+    MessageType.DEBUG_RECORD: DEBUG,
+    MessageType.ERROR_RECORD: ERROR,
+    MessageType.VERBOSE_RECORD: INFO,
+    MessageType.WARNING_RECORD: WARNING,
+}
+
 
 class PSRPHook(BaseHook):

Review comment:
       Hmm, we’ve recently been renaming things in the AWS provider; maybe we 
should also take the chance to rename this to `PsrpHook` to fit the naming 
convention (operator as well).

##########
File path: airflow/providers/microsoft/psrp/hooks/psrp.py
##########
@@ -16,103 +16,222 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from time import sleep
+import re
+from contextlib import contextmanager
+from logging import DEBUG, ERROR, INFO, WARNING
+from typing import Any, Dict, Optional
+from weakref import WeakKeyDictionary
 
-from pypsrp.messages import ErrorRecord, InformationRecord, ProgressRecord
+from pypsrp.messages import MessageType
 from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool
 from pypsrp.wsman import WSMan
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+INFORMATIONAL_RECORD_LEVEL_MAP = {
+    MessageType.DEBUG_RECORD: DEBUG,
+    MessageType.ERROR_RECORD: ERROR,
+    MessageType.VERBOSE_RECORD: INFO,
+    MessageType.WARNING_RECORD: WARNING,
+}
+
 
 class PSRPHook(BaseHook):
     """
     Hook for PowerShell Remoting Protocol execution.
 
-    The hook must be used as a context manager.
+    When used as a context manager, the runspace pool is reused between shell
+    sessions.
+
+    :param psrp_conn_id: Required. The name of the PSRP connection.
+    :type psrp_conn_id: str
+    :param logging: If true (default), log command output and streams during 
execution.
+    :type logging: bool
+    :param operation_timeout: Override the default WSMan timeout when polling 
the pipeline.
+    :type operation_timeout: float
+    :param runspace_options:
+        Optional dictionary which is passed when creating the runspace pool. 
See
+        :py:class:`~pypsrp.powershell.RunspacePool` for a description of the
+        available options.
+    :type runspace_options: dict
+    :param wsman_options:
+        Optional dictionary which is passed when creating the `WSMan` client. 
See
+        :py:class:`~pypsrp.wsman.WSMan` for a description of the available 
options.
+    :type wsman_options: dict
+    :param exchange_keys:
+        If true (default), automatically initiate a session key exchange when 
the
+        hook is used as a context manager.
+    :type exchange_keys: bool
+
+    You can provide an alternative `configuration_name` using either 
`runspace_options`
+    or by setting this key as the extra fields of your connection.
     """
 
-    _client = None
-    _poll_interval = 1
-
-    def __init__(self, psrp_conn_id: str):
+    _conn = None
+    _configuration_name = None
+    _wsman_ref: "WeakKeyDictionary[RunspacePool, WSMan]" = WeakKeyDictionary()
+
+    def __init__(
+        self,
+        psrp_conn_id: str,
+        logging: bool = True,
+        operation_timeout: Optional[float] = None,
+        runspace_options: Optional[Dict[str, Any]] = None,
+        wsman_options: Optional[Dict[str, Any]] = None,
+        exchange_keys: bool = True,
+    ):
         self.conn_id = psrp_conn_id
+        self._logging = logging
+        self._operation_timeout = operation_timeout
+        self._runspace_options = runspace_options or {}
+        self._wsman_options = wsman_options or {}
+        self._exchange_keys = exchange_keys
 
     def __enter__(self):
-        conn = self.get_connection(self.conn_id)
-
-        self.log.info("Establishing WinRM connection %s to host: %s", 
self.conn_id, conn.host)
-        self._client = WSMan(
-            conn.host,
-            ssl=True,
-            auth="ntlm",
-            encryption="never",
-            username=conn.login,
-            password=conn.password,
-            cert_validation=False,
-        )
-        self._client.__enter__()
+        conn = self.get_conn()
+        self._wsman_ref[conn].__enter__()
+        conn.__enter__()
+        if self._exchange_keys:
+            conn.exchange_keys()
+        self._conn = conn
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         try:
-            self._client.__exit__(exc_type, exc_value, traceback)
+            self._conn.__exit__(exc_type, exc_value, traceback)
+            self._wsman_ref[self._conn].__exit__(exc_type, exc_value, 
traceback)
         finally:
-            self._client = None
+            del self._conn

Review comment:
       ```suggestion
               self._conn = None
   ```
   
   It feels unnecessarily complicated to delete the instance-bound variable. An 
extra entry in `__dict__` is not that big a deal. This also prevents the 
class-bound `_conn` from being accidentally deleted if `__exit__` is called 
without `__enter__` (not that it’s likely in any way, but it’s a risk we don’t 
need to take in the first place).




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to