details:   https://code.tryton.org/python-sql/commit/c6cb639096fb
branch:    default
user:      Cédric Krier <[email protected]>
date:      Mon Nov 24 21:35:33 2025 +0100
description:
        Do not use parameter for EXTRACT field

        Closes #97
diffstat:

 CHANGELOG                   |   1 +
 sql/functions.py            |  66 ++++++++++++++++++++++++++++++++++++++++++--
 sql/tests/test_functions.py |  34 ++++++++++++++++++++++-
 3 files changed, 97 insertions(+), 4 deletions(-)

diffs (149 lines):

diff -r 577584c0025c -r c6cb639096fb CHANGELOG
--- a/CHANGELOG Mon Nov 24 17:34:12 2025 +0100
+++ b/CHANGELOG Mon Nov 24 21:35:33 2025 +0100
@@ -1,3 +1,4 @@
+* Do not use parameter for EXTRACT field
 * Remove support for Python older than 3.6
 
 Version 1.7.0 - 2025-11-24
diff -r 577584c0025c -r c6cb639096fb sql/functions.py
--- a/sql/functions.py  Mon Nov 24 17:34:12 2025 +0100
+++ b/sql/functions.py  Mon Nov 24 21:35:33 2025 +0100
@@ -1,5 +1,7 @@
 # This file is part of python-sql.  The COPYRIGHT file at the top level of
 # this repository contains the full copyright notices and license terms.
+
+from enum import Enum, auto
 from itertools import chain
 
 from sql import CombiningQuery, Expression, Flavor, FromItem, Select, Window
@@ -85,7 +87,7 @@
         return (self._function + '('
             + ' '.join(chain(*zip(
                         self._keywords,
-                        map(self._format, self.args))))[1:]
+                        map(self._format, self.args)))).strip()
             + ')')
 
 
@@ -383,9 +385,67 @@
 
 
 class Extract(FunctionKeyword):
-    __slots__ = ()
+    __slots__ = ('_field',)
     _function = 'EXTRACT'
-    _keywords = ('', 'FROM')
+
+    class Fields(str, Enum):
+        def _generate_next_value_(name, start, count, last_values):
+            return name.upper()
+
+        CENTURY = auto()
+        DAY = auto()
+        DECADE = auto()
+        DOW = auto()
+        DOY = auto()
+        EPOCH = auto()
+        HOUR = auto()
+        ISODOW = auto()
+        ISOYEAR = auto()
+        JULIAN = auto()
+        MICROSECONDS = auto()
+        MILLENNIUM = auto()
+        MILLISECONDS = auto()
+        MINUTE = auto()
+        MONTH = auto()
+        QUARTER = auto()
+        SECOND = auto()
+        TIMEZONE = auto()
+        TIMEZONE_HOUR = auto()
+        TIMEZONE_MINUTE = auto()
+        WEEK = auto()
+        YEAR = auto()
+
+    def __init__(self, field, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.field = field
+
+    @property
+    def field(self):
+        return self._field
+
+    @field.setter
+    def field(self, value):
+        value = value.upper()
+        if not hasattr(self.Fields, value):
+            raise ValueError("invalid field: %r" % value)
+        self._field = value
+
+    @property
+    def _keywords(self):
+        return ('%s FROM' % self.field,)
+
+    def __str__(self):
+        Mapping = Flavor.get().function_mapping.get(self.__class__)
+        if Mapping:
+            return str(Mapping(self.field, *self.args))
+        return super().__str__()
+
+    @property
+    def params(self):
+        Mapping = Flavor.get().function_mapping.get(self.__class__)
+        if Mapping:
+            return Mapping(self.field, *self.args).params
+        return super().params
 
 
 class Isfinite(Function):
diff -r 577584c0025c -r c6cb639096fb sql/tests/test_functions.py
--- a/sql/tests/test_functions.py       Mon Nov 24 17:34:12 2025 +0100
+++ b/sql/tests/test_functions.py       Mon Nov 24 21:35:33 2025 +0100
@@ -4,7 +4,7 @@
 
 from sql import AliasManager, Flavor, Table, Window
 from sql.functions import (
-    Abs, AtTimeZone, CurrentTime, Div, Function, FunctionKeyword,
+    Abs, AtTimeZone, CurrentTime, Div, Extract, Function, FunctionKeyword,
     FunctionNotCallable, Overlay, Rank, Trim, WindowFunction)
 
 
@@ -139,6 +139,38 @@
         self.assertEqual(str(current_time), 'CURRENT_TIME')
         self.assertEqual(current_time.params, ())
 
+    def test_extract(self):
+        extract = Extract(Extract.Fields.DAY, self.table.c)
+        self.assertEqual(str(extract), 'EXTRACT(DAY FROM "c")')
+        self.assertEqual(extract.params, ())
+
+        extract = Extract('day', self.table.c)
+        self.assertEqual(str(extract), 'EXTRACT(DAY FROM "c")')
+        self.assertEqual(extract.params, ())
+
+        extract = Extract(Extract.Fields.DAY, '2000-01-01')
+        self.assertEqual(str(extract), 'EXTRACT(DAY FROM %s)')
+        self.assertEqual(extract.params, ('2000-01-01',))
+
+    def test_extract_mapping(self):
+        class MyExtract(Function):
+            _function = 'MY_EXTRACT'
+
+        extract = Extract(Extract.Fields.DAY, '2000-01-01')
+        flavor = Flavor(function_mapping={
+                Extract: MyExtract,
+                })
+        Flavor.set(flavor)
+        try:
+            self.assertEqual(str(extract), 'MY_EXTRACT(%s, %s)')
+            self.assertEqual(extract.params, ('DAY', '2000-01-01'))
+        finally:
+            Flavor.set(Flavor())
+
+    def test_extract_invalid_field(self):
+        with self.assertRaises(ValueError):
+            Extract('foo', self.table.c)
+
 
 class TestWindowFunction(unittest.TestCase):
 

Reply via email to