details:   https://code.tryton.org/tryton/commit/78632c9aa00f
branch:    default
user:      Nicolas Évrard <[email protected]>
date:      Fri Oct 24 18:16:51 2025 +0200
description:
        Merge security.check_session and security.check

        That way we won't use the pool when the database is not already
        initialized.

        Closes #14316
diffstat:

 trytond/trytond/security.py            |  102 +++++++++++++++++---------------
 trytond/trytond/tests/test_security.py |   40 ++++++++----
 trytond/trytond/tests/test_wsgi.py     |   90 ++++++++++++++++++++++++++++-
 trytond/trytond/wsgi.py                |    9 ++-
 4 files changed, 178 insertions(+), 63 deletions(-)

diffs (317 lines):

diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/security.py
--- a/trytond/trytond/security.py       Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/security.py       Fri Oct 24 18:16:51 2025 +0200
@@ -125,30 +125,69 @@
 
 
 def check(dbname, user, session, context=None):
-    for count in range(config.getint('database', 'retry'), -1, -1):
-        with Transaction().start(dbname, user, context=context) as transaction:
-            pool = _get_pool(dbname)
-            Session = pool.get('ir.session')
-            try:
-                find = Session.check(user, session)
-                break
-            except backend.DatabaseOperationalError:
-                if count:
-                    continue
-                raise
-            finally:
-                transaction.commit()
+    remote_addr = _get_remote_addr(context)
+
+    database_list = Pool.database_list()
+    if dbname in database_list:
+        for count in range(config.getint('database', 'retry'), -1, -1):
+            with Transaction().start(dbname, user, context=context) \
+                    as transaction:
+                pool = Pool(dbname)
+                Session = pool.get('ir.session')
+                try:
+                    find = Session.check(user, session)
+                    break
+                except backend.DatabaseOperationalError:
+                    if count:
+                        continue
+                    raise
+                finally:
+                    transaction.commit()
+    else:
+        if remote_addr:
+            ip_addr = str(ipaddress.ip_address(remote_addr))
+        else:
+            ip_addr = None
+        now = dt.datetime.now()
+        timeout = dt.timedelta(seconds=config.getint('session', 'max_age'))
+        database = backend.Database(dbname)
+        conn = database.get_connection(readonly=True)
+        try:
+            ir_session = Table('ir_session')
+            cursor = conn.cursor()
+            session_query = ir_session.select(
+                Coalesce(
+                    ir_session.write_date, ir_session.create_date).as_('date'),
+                ir_session.key,
+                where=((ir_session.create_uid == user)
+                    & (ir_session.ip_address == ip_addr)))
+            if backend.name == 'sqlite':
+                sqlite_apply_types(session_query, ['DATETIME', None])
+            cursor.execute(*session_query)
+            bad_session = False
+            for session_date, session_key in cursor:
+                if abs(session_date - now) < timeout:
+                    if compare_digest(session_key, session):
+                        find = session
+                        break
+                    else:
+                        bad_session = True
+            else:
+                find = None if bad_session else ''
+        finally:
+            database.put_connection(conn)
+
     if find is None:
         logger.error("session failed for '%s' from '%s' on database '%s'",
-            user, _get_remote_addr(context), dbname)
+            user, remote_addr, dbname)
         return
     elif not find:
         logger.info("session expired for '%s' from '%s' on database '%s'",
-            user, _get_remote_addr(context), dbname)
+            user, remote_addr, dbname)
         return
     else:
         logger.debug("session valid for '%s' from '%s' on database '%s'",
-            user, _get_remote_addr(context), dbname)
+            user, remote_addr, dbname)
         return user
 
 
@@ -172,37 +211,6 @@
     return valid
 
 
-def check_session(dbname, user, session, remote_addr=None):
-    "Check the session without using the pool"
-    database = backend.Database(dbname)
-    now = dt.datetime.now()
-    timeout = dt.timedelta(config.getint('session', 'max_age'))
-    conn = database.get_connection(readonly=True)
-    if remote_addr:
-        ip_addr = str(ipaddress.ip_address(remote_addr))
-    else:
-        ip_addr = None
-    try:
-        ir_session = Table('ir_session')
-        cursor = conn.cursor()
-        session_query = ir_session.select(
-            Coalesce(
-                ir_session.write_date, ir_session.create_date).as_('date'),
-            ir_session.key,
-            where=((ir_session.create_uid == user)
-                & (ir_session.ip_address == ip_addr)))
-        if backend.name == 'sqlite':
-            sqlite_apply_types(session_query, ['DATETIME', None])
-        cursor.execute(*session_query)
-        for session_date, session_key in cursor:
-            if abs(session_date - now) < timeout:
-                if compare_digest(session_key, session):
-                    return True
-        return False
-    finally:
-        database.put_connection(conn)
-
-
 def reset(dbname, session, context):
     try:
         with Transaction().start(dbname, 0, context=context, autocommit=True):
diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/tests/test_security.py
--- a/trytond/trytond/tests/test_security.py    Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/tests/test_security.py    Fri Oct 24 18:16:51 2025 +0200
@@ -29,30 +29,44 @@
         Transaction().commit()
 
     @with_transaction()
-    def test_check_session(self):
-        "Testing check_session"
+    def _get_auth(self):
         pool = Pool()
         User = pool.get('res.user')
         Session = pool.get('ir.session')
 
-        db_name = Transaction().database.name
         user, = User.search([('login', '=', 'user')])
         with Transaction().set_user(user.id):
             key = Session.new()
 
         Transaction().commit()
+        return user.id, key
 
-        user_id = security.check_session(db_name, user.id, key)
-        self.assertEqual(user_id, user.id)
+    def test_security_check(self):
+        "Test security.check accepts a valid session key"
+        user_id, key = self._get_auth()
+        authenticated_user_id = security.check(self.db_name, user_id, key)
+        self.assertEqual(authenticated_user_id, user_id)
+
+    def test_security_check_invalid(self):
+        "Test security.check rejects an invalid session key"
+        user_id, _ = self._get_auth()
+        user_id = security.check(self.db_name, user_id, "invalid key")
+        self.assertIsNone(user_id)
 
-    @with_transaction()
-    def test_check_session_invalid(self):
-        "Testing check_session with an invalid session"
-        pool = Pool()
-        User = pool.get('res.user')
+    def test_security_check_no_pool(self):
+        "Test security.check validates a session without loading the pool"
+        user_id, key = self._get_auth()
+        Pool.stop(self.db_name)
+
+        authenticated_user_id = security.check(self.db_name, user_id, key)
+        self.assertNotIn(self.db_name, Pool._pools)
+        self.assertEqual(authenticated_user_id, user_id)
 
-        db_name = Transaction().database.name
-        user, = User.search([('login', '=', 'user')])
+    def test_security_check_no_pool_invalid(self):
+        "Test security.check rejects an invalid key without loading the pool"
+        user_id, _ = self._get_auth()
+        Pool.stop(self.db_name)
 
-        user_id = security.check_session(db_name, user.id, "invalid key")
+        user_id = security.check(self.db_name, user_id, "invalid key")
+        self.assertNotIn(self.db_name, Pool._pools)
         self.assertIsNone(user_id)
diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/tests/test_wsgi.py
--- a/trytond/trytond/tests/test_wsgi.py        Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/tests/test_wsgi.py        Fri Oct 24 18:16:51 2025 +0200
@@ -1,13 +1,17 @@
 # This file is part of Tryton.  The COPYRIGHT file at the top level of
 # this repository contains the full copyright notices and license terms.
 
+import base64
+from http import HTTPStatus
 from unittest.mock import Mock, sentinel
 
 from werkzeug.routing import Map, Rule
 
+from trytond import security
 from trytond.exceptions import TrytonException
+from trytond.pool import Pool
 from trytond.protocols.wrappers import Response
-from trytond.tests.test_tryton import Client, TestCase
+from trytond.tests.test_tryton import Client, RouteTestCase, TestCase
 from trytond.wsgi import Base64Converter, TrytondWSGI
 
 
@@ -133,3 +137,87 @@
 
         self.assertEqual(next(response.response), b'baz')
         self.assertEqual(response.status, "418 I'M A TEAPOT")
+
+
+class TrytonWSGITestCase(RouteTestCase):
+    module = 'res'
+
+    @classmethod
+    def setUpDatabase(cls):
+        pool = Pool()
+        User = pool.get('res.user')
+        User.create([{
+                    'name': 'user',
+                    'login': 'user',
+                    'password': '12345678',
+                    }])
+
+    def test_session_valid_good_auth(self):
+        "Test that session_valid correctly authenticates"
+        app = TrytondWSGI()
+
+        @app.route('/<database_name>/session_required')
+        @app.session_valid
+        def _route(request, database_name):
+            return Response(b'')
+
+        user_id, key = security.login(
+            self.db_name, 'user', {'password': '12345678'})
+        client = Client(app, Response)
+        session_hdr = 'Session ' + base64.b64encode(
+            f'user:{user_id}:{key}'.encode('utf8')).decode('utf8')
+        response = client.get(
+            f'/{self.db_name}/session_required',
+            headers=[('Authorization', session_hdr)])
+        self.assertEqual(response.status_code, HTTPStatus.OK)
+
+    def test_session_valid_no_pool(self):
+        "Test that session_valid authenticates without loading the pool"
+        app = TrytondWSGI()
+
+        @app.route('/<database_name>/session_required')
+        @app.session_valid
+        def _route(request, database_name):
+            return Response(b'')
+
+        user_id, key = security.login(
+            self.db_name, 'user', {'password': '12345678'})
+        Pool.stop(self.db_name)
+
+        client = Client(app, Response)
+        session_hdr = 'Session ' + base64.b64encode(
+            f'user:{user_id}:{key}'.encode('utf8')).decode('utf8')
+        response = client.get(f'/{self.db_name}/session_required',
+            headers=[('Authorization', session_hdr)])
+        self.assertEqual(response.status_code, HTTPStatus.OK)
+        self.assertNotIn(self.db_name, Pool._pools)
+
+    def test_session_valid_bad_auth(self):
+        "Test that session_valid refuses wrong Authentication headers"
+
+        app = TrytondWSGI()
+
+        @app.route('/<database_name>/session_required')
+        @app.session_valid
+        def _route(request, database_name):
+            return Response(b'')
+
+        client = Client(app, Response)
+        response = client.get(
+            f'/{self.db_name}/session_required',
+            headers=[('Authorization', 'Session bad token')])
+        self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED)
+
+    def test_session_valid_no_auth(self):
+        "Test that session_valid refuses unauthenticated requests"
+
+        app = TrytondWSGI()
+
+        @app.route('/<database_name>/session_required')
+        @app.session_valid
+        def _route(request, database_name):
+            return Response(b'')
+
+        client = Client(app, Response)
+        response = client.get(f'/{self.db_name}/session_required')
+        self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED)
diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/wsgi.py
--- a/trytond/trytond/wsgi.py   Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/wsgi.py   Fri Oct 24 18:16:51 2025 +0200
@@ -102,8 +102,13 @@
             session = request.authorization.get('session')
             dbname = request.view_args.get('database_name')
 
-            if not security.check_session(
-                    dbname, userid, session, request.remote_addr):
+            request_context = {
+                '_request': {
+                    'remote_addr': request.remote_addr,
+                    },
+                }
+            if security.check(
+                    dbname, userid, session, request_context) is None:
                 _do_basic_auth(request)
 
             return func(request, *args, **kwargs)

Reply via email to