changeset 47c55a8f9db6 in trytond:default
details: https://hg.tryton.org/trytond?cmd=changeset&node=47c55a8f9db6
description:
        Use UNION for 'Or'-ed domain with subqueries

        issue10658
        review365781002
diffstat:

 CHANGELOG                      |    1 +
 trytond/model/modelsql.py      |  156 ++++++++++++++++++++++++++++------------
 trytond/tests/test_modelsql.py |   92 ++++++++++++++++++++++++
 3 files changed, 203 insertions(+), 46 deletions(-)

diffs (319 lines):

diff -r 7a211140f570 -r 47c55a8f9db6 CHANGELOG
--- a/CHANGELOG Tue Sep 14 21:49:26 2021 +0200
+++ b/CHANGELOG Wed Sep 15 15:16:40 2021 +0200
@@ -1,3 +1,4 @@
+* Use UNION for 'Or'-ed domain with subqueries
 * Add remove_forbidden_chars in tools
 * Manage errors during non-interactive operations
 * Add estimation count to ModelStorage
diff -r 7a211140f570 -r 47c55a8f9db6 trytond/model/modelsql.py
--- a/trytond/model/modelsql.py Tue Sep 14 21:49:26 2021 +0200
+++ b/trytond/model/modelsql.py Wed Sep 15 15:16:40 2021 +0200
@@ -6,7 +6,7 @@
 from functools import wraps
 
 from sql import (Table, Column, Literal, Desc, Asc, Expression, Null,
-    NullsFirst, NullsLast, For)
+    NullsFirst, NullsLast, For, Union)
 from sql.functions import CurrentTimestamp, Extract
 from sql.conditionals import Coalesce
 from sql.operators import Or, And, Operator, Equal
@@ -1257,20 +1257,70 @@
             raise AccessError(msg)
 
     @classmethod
-    def search(cls, domain, offset=0, limit=None, order=None, count=False,
-            query=False):
+    def __search_query(cls, domain, count, query):
         pool = Pool()
         Rule = pool.get('ir.rule')
-        transaction = Transaction()
-        cursor = transaction.connection.cursor()
+
+        rule_domain = Rule.domain_get(cls.__name__, mode='read')
+        if domain and domain[0] == 'OR':
+            local_domains, subquery_domains = split_subquery_domain(domain)
+        else:
+            local_domains, subquery_domains = None, None
 
-        super(ModelSQL, cls).search(
-            domain, offset=offset, limit=limit, order=order, count=count)
+        # With subqueries, a UNION of queries is more efficient than JOINs
+        # because databases can use indexes (a bare ['OR'] would match all)
+        if subquery_domains:
+            union_tables = []
+            local_part = [['OR'] + local_domains] if local_domains else []
+            for sub_domain in local_part + subquery_domains:
+                tables, expression = cls.search_domain(sub_domain)
+                if rule_domain:
+                    tables, domain_exp = cls.search_domain(
+                        rule_domain, active_test=False, tables=tables)
+                    expression &= domain_exp
+                main_table, _ = tables[None]
+                table = convert_from(None, tables)
+                columns = cls.__searched_columns(
+                    main_table, not count and not query)
+                union_tables.append(table.select(
+                        *columns, where=expression))
+            expression = None
+            tables = {
+                None: (Union(*union_tables, all_=False), None),
+                }
+        else:
+            tables, expression = cls.search_domain(domain)
+            if rule_domain:
+                tables, domain_exp = cls.search_domain(
+                    rule_domain, active_test=False, tables=tables)
+                expression &= domain_exp
 
-        # Get domain clauses
-        tables, expression = cls.search_domain(domain)
+        return tables, expression
 
-        # Get order by
+    @classmethod
+    def __searched_columns(cls, table, eager=False, history=False):
+        columns = [table.id.as_('id')]
+        if (cls._history and Transaction().context.get('_datetime')
+                and (eager or history)):
+            columns.append(
+                Coalesce(table.write_date, table.create_date).as_('_datetime'))
+            columns.append(Column(table, '__id').as_('__id'))
+        if eager:
+            columns += [f.sql_column(table).as_(n)
+                for n, f in cls._fields.items()
+                if not hasattr(f, 'get')
+                and n != 'id'
+                and not getattr(f, 'translate', False)
+                and f.loading == 'eager']
+            if not callable(cls.table_query):
+                sql_type = fields.Char('timestamp').sql_type().base
+                columns += [Extract('EPOCH',
+                        Coalesce(table.write_date, table.create_date)
+                        ).cast(sql_type).as_('_timestamp')]
+        return columns
+
+    @classmethod
+    def __search_order(cls, order, tables):
         order_by = []
         order_types = {
             'DESC': Desc,
@@ -1299,42 +1349,34 @@
             forder = field.convert_order(oexpr, tables, cls)
             order_by.extend((NullOrdering(Order(o)) for o in forder))
 
-        # construct a clause for the rules :
-        domain = Rule.domain_get(cls.__name__, mode='read')
-        if domain:
-            tables, dom_exp = cls.search_domain(
-                domain, active_test=False, tables=tables)
-            expression &= dom_exp
+        return order_by
+
+    @classmethod
+    def search(cls, domain, offset=0, limit=None, order=None, count=False,
+            query=False):
+        transaction = Transaction()
+        cursor = transaction.connection.cursor()
+
+        super(ModelSQL, cls).search(
+            domain, offset=offset, limit=limit, order=order, count=count)
+
+        tables, expression = cls.__search_query(domain, count, query)
 
         main_table, _ = tables[None]
-        table = convert_from(None, tables)
-
         if count:
+            table = convert_from(None, tables)
             cursor.execute(*table.select(Count(Literal('*')),
                     where=expression, limit=limit, offset=offset))
             return cursor.fetchone()[0]
-        # execute the "main" query to fetch the ids we were searching for
-        columns = [main_table.id.as_('id')]
-        if (cls._history and transaction.context.get('_datetime')
-                and not query):
-            columns.append(Coalesce(
-                    main_table.write_date,
-                    main_table.create_date).as_('_datetime'))
-            columns.append(Column(main_table, '__id').as_('__id'))
-        if not query:
-            columns += [f.sql_column(main_table).as_(n)
-                for n, f in cls._fields.items()
-                if not hasattr(f, 'get')
-                and n != 'id'
-                and not getattr(f, 'translate', False)
-                and f.loading == 'eager']
-            if not callable(cls.table_query):
-                sql_type = fields.Char('timestamp').sql_type().base
-                columns += [Extract('EPOCH',
-                        Coalesce(main_table.write_date, main_table.create_date)
-                        ).cast(sql_type).as_('_timestamp')]
-        select = table.select(*columns,
-            where=expression, order_by=order_by, limit=limit, offset=offset)
+
+        order_by = cls.__search_order(order, tables)
+        # compute it here because __search_order might modify tables
+        table = convert_from(None, tables)
+        columns = cls.__searched_columns(main_table, not query)
+        select = table.select(
+            *columns, where=expression, limit=limit, offset=offset,
+            order_by=order_by)
+
         if query:
             return select
         cursor.execute(*select)
@@ -1407,12 +1449,7 @@
                 cache[cls.__name__][data['id']]._update(data)
 
         if len(rows) >= transaction.database.IN_MAX:
-            if (cls._history
-                    and transaction.context.get('_datetime')
-                    and not query):
-                columns = columns[:3]
-            else:
-                columns = columns[:1]
+            columns = cls.__searched_columns(main_table, history=True)
             cursor.execute(*table.select(*columns,
                     where=expression, order_by=order_by,
                     limit=limit, offset=offset))
@@ -1643,3 +1680,30 @@
             continue
         table = convert_from(table, sub_tables)
     return table
+
+
+def split_subquery_domain(domain):
+    """
+    Split a domain in two lists of sub-domains:
+        - the first one contains the sub-domains using only local fields
+        - the second one contains the sub-domains using a dotted related field
+    The main operator and empty sub-domains are stripped from the results.
+    """
+    local_domains, subquery_domains = [], []
+    for sub_domain in domain:
+        if is_leaf(sub_domain):
+            if '.' in sub_domain[0]:
+                subquery_domains.append(sub_domain)
+            else:
+                local_domains.append(sub_domain)
+        elif (not sub_domain or list(sub_domain) in [['OR'], ['AND']]
+                or sub_domain in ['OR', 'AND']):
+            continue
+        else:
+            # nested domain is related if any part uses a related field
+            _, nested_subquery_domains = split_subquery_domain(sub_domain)
+            if nested_subquery_domains:
+                subquery_domains.append(sub_domain)
+            else:
+                local_domains.append(sub_domain)
+    return local_domains, subquery_domains
diff -r 7a211140f570 -r 47c55a8f9db6 trytond/tests/test_modelsql.py
--- a/trytond/tests/test_modelsql.py    Tue Sep 14 21:49:26 2021 +0200
+++ b/trytond/tests/test_modelsql.py    Wed Sep 15 15:16:40 2021 +0200
@@ -11,6 +11,7 @@
 from trytond.exceptions import ConcurrencyException
 from trytond.model.exceptions import (
     RequiredValidationError, SQLConstraintError)
+from trytond.model.modelsql import split_subquery_domain
 from trytond.transaction import Transaction
 from trytond.pool import Pool
 from trytond.tests.test_tryton import activate_module, with_transaction
@@ -892,6 +893,97 @@
         self.assertEqual(cache[record.id]['name'], "Foo")
         self.assertNotIn('_timestamp', cache[record.id])
 
+    @with_transaction()
+    def test_search_or_to_union(self):
+        """
+        Test searching for 'OR'-ed domain
+        """
+        pool = Pool()
+        Model = pool.get('test.modelsql.read')
+
+        Model.create([{
+                    'name': 'A',
+                    }, {
+                    'name': 'B',
+                    }, {
+                    'name': 'C',
+                    'targets': [('create', [{
+                                    'name': 'C.A',
+                                    }]),
+                        ],
+                    }])
+
+        domain = ['OR',
+            ('name', 'ilike', '%A%'),
+            ('targets.name', 'ilike', '%A'),
+            ]
+        with patch('trytond.model.modelsql.split_subquery_domain') as no_split:
+            # Mocking in order not to trigger the split
+            no_split.side_effect = lambda d: (d, [])
+            result_without_split = Model.search(domain)
+        self.assertEqual(
+            Model.search(domain),
+            result_without_split)
+
+    def test_split_subquery_domain_empty(self):
+        """
+        Test the split of domains in local and relation parts (empty domain)
+        """
+        local, related = split_subquery_domain([])
+        self.assertEqual(local, [])
+        self.assertEqual(related, [])
+
+    def test_split_subquery_domain_simple(self):
+        """
+        Test the split of domains in local and relation parts (simple domain)
+        """
+        local, related = split_subquery_domain([('a', '=', 1)])
+        self.assertEqual(local, [('a', '=', 1)])
+        self.assertEqual(related, [])
+
+    def test_split_subquery_domain_dotted(self):
+        """
+        Test the split of domains in local and relation parts (dotted domain)
+        """
+        local, related = split_subquery_domain([('a.b', '=', 1)])
+        self.assertEqual(local, [])
+        self.assertEqual(related, [('a.b', '=', 1)])
+
+    def test_split_subquery_domain_mixed(self):
+        """
+        Test the split of domains in local and relation parts (mixed domains)
+        """
+        local, related = split_subquery_domain(
+            [('a', '=', 1), ('b.c', '=', 2)])
+        self.assertEqual(local, [('a', '=', 1)])
+        self.assertEqual(related, [('b.c', '=', 2)])
+
+    def test_split_subquery_domain_operator(self):
+        """
+        Test the split of domains in local and relation parts (with operator)
+        """
+        local, related = split_subquery_domain(
+            ['OR', ('a', '=', 1), ('b.c', '=', 2)])
+        self.assertEqual(local, [('a', '=', 1)])
+        self.assertEqual(related, [('b.c', '=', 2)])
+
+    def test_split_subquery_domain_nested(self):
+        """
+        Test the split of domains in local and relation parts (nested domains)
+        """
+        local, related = split_subquery_domain(
+            [
+                ['AND', ('a', '=', 1), ('b', '=', 2)],
+                ['AND',
+                    ('b', '=', 2),
+                    ['OR', ('c', '=', 3), ('d.e', '=', 4)]]])
+        self.assertEqual(local, [['AND', ('a', '=', 1), ('b', '=', 2)]])
+        self.assertEqual(related, [
+                ['AND',
+                    ('b', '=', 2),
+                    ['OR', ('c', '=', 3), ('d.e', '=', 4)]]
+                ])
+
 
 def suite():
     suite_ = unittest.TestSuite()

Reply via email to