details: https://code.tryton.org/tryton/commit/78632c9aa00f
branch: default
user: Nicolas Évrard <[email protected]>
date: Fri Oct 24 18:16:51 2025 +0200
description:
Merge security.check_session into security.check
This way the pool is not used when the database is not yet
initialized.
Closes #14316
diffstat:
trytond/trytond/security.py | 102 +++++++++++++++++---------------
trytond/trytond/tests/test_security.py | 40 ++++++++----
trytond/trytond/tests/test_wsgi.py | 90 ++++++++++++++++++++++++++++-
trytond/trytond/wsgi.py | 9 ++-
4 files changed, 178 insertions(+), 63 deletions(-)
diffs (317 lines):
diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/security.py
--- a/trytond/trytond/security.py Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/security.py Fri Oct 24 18:16:51 2025 +0200
@@ -125,30 +125,69 @@
def check(dbname, user, session, context=None):
- for count in range(config.getint('database', 'retry'), -1, -1):
- with Transaction().start(dbname, user, context=context) as transaction:
- pool = _get_pool(dbname)
- Session = pool.get('ir.session')
- try:
- find = Session.check(user, session)
- break
- except backend.DatabaseOperationalError:
- if count:
- continue
- raise
- finally:
- transaction.commit()
+ remote_addr = _get_remote_addr(context)
+
+ database_list = Pool.database_list()
+ if dbname in database_list:
+ for count in range(config.getint('database', 'retry'), -1, -1):
+ with Transaction().start(dbname, user, context=context) \
+ as transaction:
+ pool = Pool(dbname)
+ Session = pool.get('ir.session')
+ try:
+ find = Session.check(user, session)
+ break
+ except backend.DatabaseOperationalError:
+ if count:
+ continue
+ raise
+ finally:
+ transaction.commit()
+ else:
+ if remote_addr:
+ ip_addr = str(ipaddress.ip_address(remote_addr))
+ else:
+ ip_addr = None
+ now = dt.datetime.now()
+ # max_age is a number of seconds; timedelta's first positional
+ # argument is days, so it must be passed as seconds=.
+ timeout = dt.timedelta(seconds=config.getint('session', 'max_age'))
+ database = backend.Database(dbname)
+ conn = database.get_connection(readonly=True)
+ try:
+ ir_session = Table('ir_session')
+ cursor = conn.cursor()
+ session_query = ir_session.select(
+ Coalesce(
+ ir_session.write_date, ir_session.create_date).as_('date'),
+ ir_session.key,
+ where=((ir_session.create_uid == user)
+ & (ir_session.ip_address == ip_addr)))
+ if backend.name == 'sqlite':
+ sqlite_apply_types(session_query, ['DATETIME', None])
+ cursor.execute(*session_query)
+ # Outcome feeds the logging below: a fresh matching key is
+ # valid, a key matching only stale rows is "expired" (''),
+ # a key matching no row at all is "failed" (None).
+ expired = False
+ for session_date, session_key in cursor:
+ if not compare_digest(session_key, session):
+ continue
+ if abs(session_date - now) < timeout:
+ find = session
+ break
+ expired = True
+ else:
+ find = '' if expired else None
+ finally:
+ database.put_connection(conn)
+
if find is None:
logger.error("session failed for '%s' from '%s' on database '%s'",
- user, _get_remote_addr(context), dbname)
+ user, remote_addr, dbname)
return
elif not find:
logger.info("session expired for '%s' from '%s' on database '%s'",
- user, _get_remote_addr(context), dbname)
+ user, remote_addr, dbname)
return
else:
logger.debug("session valid for '%s' from '%s' on database '%s'",
- user, _get_remote_addr(context), dbname)
+ user, remote_addr, dbname)
return user
@@ -172,37 +211,6 @@
return valid
-def check_session(dbname, user, session, remote_addr=None):
- "Check the session without using the pool"
- database = backend.Database(dbname)
- now = dt.datetime.now()
- timeout = dt.timedelta(config.getint('session', 'max_age'))
- conn = database.get_connection(readonly=True)
- if remote_addr:
- ip_addr = str(ipaddress.ip_address(remote_addr))
- else:
- ip_addr = None
- try:
- ir_session = Table('ir_session')
- cursor = conn.cursor()
- session_query = ir_session.select(
- Coalesce(
- ir_session.write_date, ir_session.create_date).as_('date'),
- ir_session.key,
- where=((ir_session.create_uid == user)
- & (ir_session.ip_address == ip_addr)))
- if backend.name == 'sqlite':
- sqlite_apply_types(session_query, ['DATETIME', None])
- cursor.execute(*session_query)
- for session_date, session_key in cursor:
- if abs(session_date - now) < timeout:
- if compare_digest(session_key, session):
- return True
- return False
- finally:
- database.put_connection(conn)
-
-
def reset(dbname, session, context):
try:
with Transaction().start(dbname, 0, context=context, autocommit=True):
diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/tests/test_security.py
--- a/trytond/trytond/tests/test_security.py Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/tests/test_security.py Fri Oct 24 18:16:51 2025 +0200
@@ -29,30 +29,44 @@
Transaction().commit()
@with_transaction()
- def test_check_session(self):
- "Testing check_session"
+ def _get_auth(self):
+ "Create and commit a session for the 'user' login; return (user id, key)"
pool = Pool()
User = pool.get('res.user')
Session = pool.get('ir.session')
- db_name = Transaction().database.name
user, = User.search([('login', '=', 'user')])
with Transaction().set_user(user.id):
key = Session.new()
Transaction().commit()
+ return user.id, key
- user_id = security.check_session(db_name, user.id, key)
- self.assertEqual(user_id, user.id)
+ def test_security_check(self):
+ "Test security.check"
+ user_id, key = self._get_auth()
+ authenticated_user_id = security.check(self.db_name, user_id, key)
+ self.assertEqual(authenticated_user_id, user_id)
+
+ def test_security_check_invalid(self):
+ "Test security.check with an invalid session"
+ user_id, _ = self._get_auth()
+ user_id = security.check(self.db_name, user_id, "invalid key")
+ self.assertIsNone(user_id)
- @with_transaction()
- def test_check_session_invalid(self):
- "Testing check_session with an invalid session"
- pool = Pool()
- User = pool.get('res.user')
+ def test_security_check_no_pool(self):
+ "Test security.check without the pool"
+ user_id, key = self._get_auth()
+ Pool.stop(self.db_name)
+
+ authenticated_user_id = security.check(self.db_name, user_id, key)
+ self.assertNotIn(self.db_name, Pool._pools)
+ self.assertEqual(authenticated_user_id, user_id)
- db_name = Transaction().database.name
- user, = User.search([('login', '=', 'user')])
+ def test_security_check_no_pool_invalid(self):
+ "Test security.check without the pool on an invalid session"
+ user_id, _ = self._get_auth()
+ Pool.stop(self.db_name)
- user_id = security.check_session(db_name, user.id, "invalid key")
+ user_id = security.check(self.db_name, user_id, "invalid key")
+ self.assertNotIn(self.db_name, Pool._pools)
self.assertIsNone(user_id)
diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/tests/test_wsgi.py
--- a/trytond/trytond/tests/test_wsgi.py Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/tests/test_wsgi.py Fri Oct 24 18:16:51 2025 +0200
@@ -1,13 +1,17 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
+import base64
+from http import HTTPStatus
from unittest.mock import Mock, sentinel
from werkzeug.routing import Map, Rule
+from trytond import security
from trytond.exceptions import TrytonException
+from trytond.pool import Pool
from trytond.protocols.wrappers import Response
-from trytond.tests.test_tryton import Client, TestCase
+from trytond.tests.test_tryton import Client, RouteTestCase, TestCase
from trytond.wsgi import Base64Converter, TrytondWSGI
@@ -133,3 +137,87 @@
self.assertEqual(next(response.response), b'baz')
self.assertEqual(response.status, "418 I'M A TEAPOT")
+
+
+class TrytonWSGITestCase(RouteTestCase):
+ module = 'res'
+
+ @classmethod
+ def setUpDatabase(cls):
+ pool = Pool()
+ User = pool.get('res.user')
+ User.create([{
+ 'name': 'user',
+ 'login': 'user',
+ 'password': '12345678',
+ }])
+
+ def test_session_valid_good_auth(self):
+ "Test that session_valid correctly authenticates"
+ app = TrytondWSGI()
+
+ @app.route('/<database_name>/session_required')
+ @app.session_valid
+ def _route(request, database_name):
+ return Response(b'')
+
+ user_id, key = security.login(
+ self.db_name, 'user', {'password': '12345678'})
+ client = Client(app, Response)
+ session_hdr = 'Session ' + base64.b64encode(
+ f'user:{user_id}:{key}'.encode('utf8')).decode('utf8')
+ response = client.get(
+ f'/{self.db_name}/session_required',
+ headers=[('Authorization', session_hdr)])
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+
+ def test_session_valid_no_pool(self):
+ "Test that session_valid does not use the pool"
+ app = TrytondWSGI()
+
+ @app.route('/<database_name>/session_required')
+ @app.session_valid
+ def _route(request, database_name):
+ return Response(b'')
+
+ user_id, key = security.login(
+ self.db_name, 'user', {'password': '12345678'})
+ Pool.stop(self.db_name)
+
+ client = Client(app, Response)
+ session_hdr = 'Session ' + base64.b64encode(
+ f'user:{user_id}:{key}'.encode('utf8')).decode('utf8')
+ client.get(
+ f'/{self.db_name}/session_required',
+ headers=[('Authorization', session_hdr)])
+ self.assertNotIn(self.db_name, Pool._pools)
+
+ def test_session_valid_bad_auth(self):
+ "Test that session_valid refuses wrong Authorization headers"
+
+ app = TrytondWSGI()
+
+ @app.route('/<database_name>/session_required')
+ @app.session_valid
+ def _route(request, database_name):
+ return Response(b'')
+
+ client = Client(app, Response)
+ response = client.get(
+ f'/{self.db_name}/session_required',
+ headers=[('Authorization', 'Session bad token')])
+ self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED)
+
+ def test_session_valid_no_auth(self):
+ "Test that session_valid refuses unauthenticated requests"
+
+ app = TrytondWSGI()
+
+ @app.route('/<database_name>/session_required')
+ @app.session_valid
+ def _route(request, database_name):
+ return Response(b'')
+
+ client = Client(app, Response)
+ response = client.get(f'/{self.db_name}/session_required')
+ self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED)
diff -r c2839b6f4a73 -r 78632c9aa00f trytond/trytond/wsgi.py
--- a/trytond/trytond/wsgi.py Fri Oct 24 18:13:01 2025 +0200
+++ b/trytond/trytond/wsgi.py Fri Oct 24 18:16:51 2025 +0200
@@ -102,8 +102,13 @@
session = request.authorization.get('session')
dbname = request.view_args.get('database_name')
- if not security.check_session(
- dbname, userid, session, request.remote_addr):
+ # security.check returns the user id on success and None when the
+ # session failed or expired; the client address travels in the
+ # '_request' context entry (presumably read by _get_remote_addr).
+ session_check = security.check(
+ dbname, userid, session, {
+ '_request': {
+ 'remote_addr': request.remote_addr,
+ },
+ })
+ if session_check is None:
_do_basic_auth(request)
return func(request, *args, **kwargs)