This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b254a9f4be Add oracledb thick mode support for oracle provider (#26576)
b254a9f4be is described below

commit b254a9f4bead4e5d4f74c633446da38550f8e0a1
Author: Paul Williams <[email protected]>
AuthorDate: Wed Sep 28 02:14:46 2022 -0400

    Add oracledb thick mode support for oracle provider (#26576)
---
 airflow/providers/oracle/hooks/oracle.py           |  96 +++++++++++++++++
 airflow/utils/db.py                                |  12 +++
 .../connections/oracle.rst                         |  28 ++---
 tests/providers/oracle/hooks/test_oracle.py        | 120 +++++++++++++++++++++
 4 files changed, 244 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/oracle/hooks/oracle.py 
b/airflow/providers/oracle/hooks/oracle.py
index a66d12dc4d..a3850f9dd1 100644
--- a/airflow/providers/oracle/hooks/oracle.py
+++ b/airflow/providers/oracle/hooks/oracle.py
@@ -42,12 +42,56 @@ def _map_param(value):
     return value
 
 
+def _get_bool(val):
+    if isinstance(val, bool):
+        return val
+    if isinstance(val, str):
+        val = val.lower().strip()
+        if val == 'true':
+            return True
+        if val == 'false':
+            return False
+    return None
+
+
+def _get_first_bool(*vals):
+    for val in vals:
+        converted = _get_bool(val)
+        if isinstance(converted, bool):
+            return converted
+    return None
+
+
 class OracleHook(DbApiHook):
     """
     Interact with Oracle SQL.
 
     :param oracle_conn_id: The :ref:`Oracle connection id 
<howto/connection:oracle>`
         used for Oracle credentials.
+    :param thick_mode: Specify whether to use python-oracledb in thick mode. 
Defaults to False.
+        If set to True, you must have the Oracle Client libraries installed.
+        See `oracledb 
docs <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html>`__
+        for more info.
+    :param thick_mode_lib_dir: Path to use to find the Oracle Client libraries 
when using thick mode.
+        If not specified, defaults to the standard way of locating the Oracle 
Client library on the OS.
+        See `oracledb docs
+        
<https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#setting-the-oracle-client-library-directory>`__
+        for more info.
+    :param thick_mode_config_dir: Path to use to find the Oracle Client library
+        configuration files when using thick mode.
+        If not specified, defaults to the standard way of locating the Oracle 
Client
+        library configuration files on the OS.
+        See `oracledb docs
+        
<https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#optional-oracle-net-configuration-files>`__
+        for more info.
+    :param fetch_decimals: Specify whether numbers should be fetched as 
``decimal.Decimal`` values.
+        See `defaults.fetch_decimals
+        
<https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html#defaults.fetch_decimals>`__
+        for more info.
+    :param fetch_lobs: Specify whether to fetch strings/bytes for CLOBs or 
BLOBs instead of locators.
+        See `defaults.fetch_lobs
+        
<https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html#defaults.fetch_lobs>`__
+        for more info.
     """
 
     conn_name_attr = 'oracle_conn_id'
@@ -57,6 +101,24 @@ class OracleHook(DbApiHook):
 
     supports_autocommit = True
 
+    def __init__(
+        self,
+        *args,
+        thick_mode: bool | None = None,
+        thick_mode_lib_dir: str | None = None,
+        thick_mode_config_dir: str | None = None,
+        fetch_decimals: bool | None = None,
+        fetch_lobs: bool | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.thick_mode = thick_mode
+        self.thick_mode_lib_dir = thick_mode_lib_dir
+        self.thick_mode_config_dir = thick_mode_config_dir
+        self.fetch_decimals = fetch_decimals
+        self.fetch_lobs = fetch_lobs
+
     def get_conn(self) -> oracledb.Connection:
         """
         Returns a oracle connection object
@@ -90,6 +152,40 @@ class OracleHook(DbApiHook):
         mod = conn.extra_dejson.get('module')
         schema = conn.schema
 
+        # Enable oracledb thick mode if thick_mode is set to True
+        # Parameters take precedence over connection config extra
+        # Defaults to use thin mode if not provided in params or connection 
config extra
+        thick_mode = _get_first_bool(self.thick_mode, 
conn.extra_dejson.get('thick_mode'))
+        if thick_mode is True:
+            if self.thick_mode_lib_dir is None:
+                self.thick_mode_lib_dir = 
conn.extra_dejson.get('thick_mode_lib_dir')
+                if not isinstance(self.thick_mode_lib_dir, (str, type(None))):
+                    raise TypeError(
+                        f'thick_mode_lib_dir expected str or None, '
+                        f'got {type(self.thick_mode_lib_dir).__name__}'
+                    )
+            if self.thick_mode_config_dir is None:
+                self.thick_mode_config_dir = 
conn.extra_dejson.get('thick_mode_config_dir')
+                if not isinstance(self.thick_mode_config_dir, (str, 
type(None))):
+                    raise TypeError(
+                        f'thick_mode_config_dir expected str or None, '
+                        f'got {type(self.thick_mode_config_dir).__name__}'
+                    )
+            oracledb.init_oracle_client(
+                lib_dir=self.thick_mode_lib_dir, 
config_dir=self.thick_mode_config_dir
+            )
+
+        # Set oracledb Defaults Attributes if provided
+        # 
(https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html)
+        fetch_decimals = _get_first_bool(self.fetch_decimals, 
conn.extra_dejson.get('fetch_decimals'))
+        if isinstance(fetch_decimals, bool):
+            oracledb.defaults.fetch_decimals = fetch_decimals
+
+        fetch_lobs = _get_first_bool(self.fetch_lobs, 
conn.extra_dejson.get('fetch_lobs'))
+        if isinstance(fetch_lobs, bool):
+            oracledb.defaults.fetch_lobs = fetch_lobs
+
+        # Set up DSN
         service_name = conn.extra_dejson.get('service_name')
         port = conn.port if conn.port else 1521
         if conn.host and sid and not service_name:
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 9e48b34659..a927b8d828 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -421,6 +421,18 @@ def create_default_connections(session: Session = 
NEW_SESSION):
         ),
         session,
     )
+    merge_conn(
+        Connection(
+            conn_id="oracle_default",
+            conn_type="oracle",
+            host="localhost",
+            login="root",
+            password="password",
+            schema="schema",
+            port=1521,
+        ),
+        session,
+    )
     merge_conn(
         Connection(
             conn_id="oss_default",
diff --git a/docs/apache-airflow-providers-oracle/connections/oracle.rst 
b/docs/apache-airflow-providers-oracle/connections/oracle.rst
index 3934b6648e..6575dce464 100644
--- a/docs/apache-airflow-providers-oracle/connections/oracle.rst
+++ b/docs/apache-airflow-providers-oracle/connections/oracle.rst
@@ -42,15 +42,6 @@ Extra (optional)
     Specify the extra parameters (as json dictionary) that can be used in 
Oracle
     connection. The following parameters are supported:
 
-    * ``encoding`` - The encoding to use for regular database strings. If not 
specified,
-      the environment variable ``NLS_LANG`` is used. If the environment 
variable ``NLS_LANG``
-      is not set, ``ASCII`` is used.
-    * ``nencoding`` - The encoding to use for national character set database 
strings.
-      If not specified, the environment variable ``NLS_NCHAR`` is used. If the 
environment
-      variable ``NLS_NCHAR`` is not used, the environment variable 
``NLS_LANG`` is used instead,
-      and if the environment variable ``NLS_LANG`` is not set, ``ASCII`` is 
used.
-    * ``threaded`` - Whether or not Oracle should wrap accesses to connections 
with a mutex.
-      Default value is False.
     * ``events`` - Whether or not to initialize Oracle in events mode.
     * ``mode`` - one of ``sysdba``, ``sysasm``, ``sysoper``, ``sysbkp``, 
``sysdgd``, ``syskmt`` or ``sysrac``
       which are defined at the module level, Default mode is connecting.
@@ -58,6 +49,22 @@ Extra (optional)
       configuration parameter.
     * ``dsn``. Specify a Data Source Name (and ignore Host).
     * ``sid`` or ``service_name``. Use to form DSN instead of Schema.
+    * ``module`` (str) - This write-only attribute sets the module column in 
the v$session table.
+      The maximum length for this string is 48 and if you exceed this length 
you will get ORA-24960.
+    * ``thick_mode`` (bool) - Specify whether to use python-oracledb in thick 
mode. Defaults to False.
+      If set to True, you must have the Oracle Client libraries installed.
+      See `oracledb 
docs <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html>`__
 for more info.
+    * ``thick_mode_lib_dir`` (str) - Path to use to find the Oracle Client 
libraries when using thick mode.
+      If not specified, defaults to the standard way of locating the Oracle 
Client library on the OS.
+      See `oracledb 
docs <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#setting-the-oracle-client-library-directory>`__
 for more info.
+    * ``thick_mode_config_dir`` (str) - Path to use to find the Oracle Client 
library configuration files when using thick mode.
+      If not specified, defaults to the standard way of locating the Oracle 
Client library configuration files on the OS.
+      See `oracledb 
docs <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#optional-oracle-net-configuration-files>`__
 for more info.
+    * ``fetch_decimals`` (bool) - Specify whether numbers should be fetched as 
``decimal.Decimal`` values.
+      See 
`defaults.fetch_decimals <https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html#defaults.fetch_decimals>`__
 for more info.
+    * ``fetch_lobs`` (bool) - Specify whether to fetch strings/bytes for CLOBs 
or BLOBs instead of locators.
+      See 
`defaults.fetch_lobs <https://python-oracledb.readthedocs.io/en/latest/api_manual/defaults.html#defaults.fetch_lobs>`__
 for more info.
+
 
     Connect using `dsn`, Host and `sid`, Host and `service_name`, or only Host 
`(OracleHook.getconn Documentation) 
<https://airflow.apache.org/docs/apache-airflow-providers-oracle/stable/_modules/airflow/providers/oracle/hooks/oracle.html#OracleHook.get_conn>`_.
 
@@ -93,9 +100,6 @@ Extra (optional)
     .. code-block:: json
 
        {
-          "encoding": "UTF-8",
-          "nencoding": "UTF-8",
-          "threaded": false,
           "events": false,
           "mode": "sysdba",
           "purity": "new"
diff --git a/tests/providers/oracle/hooks/test_oracle.py 
b/tests/providers/oracle/hooks/test_oracle.py
index 1243814509..0e6b6ee96f 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -144,6 +144,126 @@ class TestOracleHookConn(unittest.TestCase):
         self.connection.extra = json.dumps({'service_name': 'service_name'})
         assert self.db_hook.get_conn().current_schema == self.connection.schema
 
+    
@mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.init_oracle_client')
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_set_thick_mode_extra(self, mock_connect, mock_init_client):
+        thick_mode_test = {
+            'thick_mode': True,
+            'thick_mode_lib_dir': '/opt/oracle/instantclient',
+            'thick_mode_config_dir': '/opt/oracle/config',
+        }
+        self.connection.extra = json.dumps(thick_mode_test)
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert mock_init_client.call_count == 1
+        args, kwargs = mock_init_client.call_args
+        assert args == ()
+        assert kwargs['lib_dir'] == thick_mode_test['thick_mode_lib_dir']
+        assert kwargs['config_dir'] == thick_mode_test['thick_mode_config_dir']
+
+    
@mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.init_oracle_client')
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_set_thick_mode_extra_str(self, mock_connect, mock_init_client):
+        thick_mode_test = {'thick_mode': 'True'}
+        self.connection.extra = json.dumps(thick_mode_test)
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert mock_init_client.call_count == 1
+
+    
@mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.init_oracle_client')
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_set_thick_mode_params(self, mock_connect, mock_init_client):
+        # Verify params overrides connection config extra
+        thick_mode_test = {
+            'thick_mode': False,
+            'thick_mode_lib_dir': '/opt/oracle/instantclient',
+            'thick_mode_config_dir': '/opt/oracle/config',
+        }
+        self.connection.extra = json.dumps(thick_mode_test)
+        db_hook = OracleHook(thick_mode=True, thick_mode_lib_dir='/test', 
thick_mode_config_dir='/test_conf')
+        db_hook.get_connection = mock.Mock()
+        db_hook.get_connection.return_value = self.connection
+        db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert mock_init_client.call_count == 1
+        args, kwargs = mock_init_client.call_args
+        assert args == ()
+        assert kwargs['lib_dir'] == '/test'
+        assert kwargs['config_dir'] == '/test_conf'
+
+    
@mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.init_oracle_client')
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_thick_mode_defaults_to_false(self, mock_connect, 
mock_init_client):
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert mock_init_client.call_count == 0
+
+    
@mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.init_oracle_client')
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_thick_mode_dirs_defaults(self, mock_connect, mock_init_client):
+        thick_mode_test = {'thick_mode': True}
+        self.connection.extra = json.dumps(thick_mode_test)
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert mock_init_client.call_count == 1
+        args, kwargs = mock_init_client.call_args
+        assert args == ()
+        assert kwargs['lib_dir'] is None
+        assert kwargs['config_dir'] is None
+
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_oracledb_defaults_attributes_default_values(self, mock_connect):
+        default_fetch_decimals = oracledb.defaults.fetch_decimals
+        default_fetch_lobs = oracledb.defaults.fetch_lobs
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        # Check that OracleHook.get_conn() doesn't try to set defaults if not 
provided
+        assert oracledb.defaults.fetch_decimals == default_fetch_decimals
+        assert oracledb.defaults.fetch_lobs == default_fetch_lobs
+
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_set_oracledb_defaults_attributes_extra(self, mock_connect):
+        defaults_test = {'fetch_decimals': True, 'fetch_lobs': False}
+        self.connection.extra = json.dumps(defaults_test)
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert oracledb.defaults.fetch_decimals == 
defaults_test['fetch_decimals']
+        assert oracledb.defaults.fetch_lobs == defaults_test['fetch_lobs']
+
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_set_oracledb_defaults_attributes_extra_str(self, mock_connect):
+        defaults_test = {'fetch_decimals': 'True', 'fetch_lobs': 'False'}
+        self.connection.extra = json.dumps(defaults_test)
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert oracledb.defaults.fetch_decimals is True
+        assert oracledb.defaults.fetch_lobs is False
+
+    @mock.patch('airflow.providers.oracle.hooks.oracle.oracledb.connect')
+    def test_set_oracledb_defaults_attributes_params(self, mock_connect):
+        # Verify params overrides connection config extra
+        defaults_test = {'fetch_decimals': False, 'fetch_lobs': True}
+        self.connection.extra = json.dumps(defaults_test)
+        db_hook = OracleHook(fetch_decimals=True, fetch_lobs=False)
+        db_hook.get_connection = mock.Mock()
+        db_hook.get_connection.return_value = self.connection
+        db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        assert oracledb.defaults.fetch_decimals is True
+        assert oracledb.defaults.fetch_lobs is False
+
+    def test_type_checking_thick_mode_lib_dir(self):
+        with pytest.raises(TypeError, match=r"thick_mode_lib_dir expected str 
or None, got.*"):
+            thick_mode_lib_dir_test = {'thick_mode': True, 
'thick_mode_lib_dir': 1}
+            self.connection.extra = json.dumps(thick_mode_lib_dir_test)
+            self.db_hook.get_conn()
+
+    def test_type_checking_thick_mode_config_dir(self):
+        with pytest.raises(TypeError, match=r"thick_mode_config_dir expected 
str or None, got.*"):
+            thick_mode_config_dir_test = {'thick_mode': True, 
'thick_mode_config_dir': 1}
+            self.connection.extra = json.dumps(thick_mode_config_dir_test)
+            self.db_hook.get_conn()
+
 
 @unittest.skipIf(oracledb is None, 'oracledb package not present')
 class TestOracleHook(unittest.TestCase):

Reply via email to