mistercrunch closed pull request #3530: [Feature] enhanced memoized on get_sqla_engine and other functions
URL: https://github.com/apache/incubator-superset/pull/3530
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not display it otherwise for merged fork PRs:

diff --git a/superset/models/core.py b/superset/models/core.py
index 2c6e8b015f..396db2dc32 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -629,6 +629,8 @@ def get_effective_user(self, url, user_name=None):
                 effective_username = g.user.username
         return effective_username
 
+    @utils.memoized(
+        watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
     def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
         extra = self.get_extra()
         url = make_url(self.sqlalchemy_uri_decrypted)
@@ -662,10 +664,10 @@ def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
         return create_engine(url, **params)
 
     def get_reserved_words(self):
-        return self.get_sqla_engine().dialect.preparer.reserved_words
+        return self.get_dialect().preparer.reserved_words
 
     def get_quoter(self):
-        return self.get_sqla_engine().dialect.identifier_preparer.quote
+        return self.get_dialect().identifier_preparer.quote
 
     def get_df(self, sql, schema):
         sql = sql.strip().strip(';')
@@ -813,6 +815,7 @@ def has_table(self, table):
         return engine.has_table(
             table.table_name, table.schema or None)
 
+    @utils.memoized
     def get_dialect(self):
         sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
         return sqla_url.get_dialect()()
diff --git a/superset/utils.py b/superset/utils.py
index bae330b4af..afe2f419ad 100644
--- a/superset/utils.py
+++ b/superset/utils.py
@@ -91,38 +91,58 @@ def flasher(msg, severity=None):
             logging.info(msg)
 
 
-class memoized(object):  # noqa
+class _memoized(object):  # noqa
     """Decorator that caches a function's return value each time it is called
 
     If called later with the same arguments, the cached value is returned, and
     not re-evaluated.
+
+    Define ``watch`` as a tuple of attribute names if this Decorator
+    should account for instance variable changes.
     """
 
-    def __init__(self, func):
+    def __init__(self, func, watch=()):
         self.func = func
         self.cache = {}
-
-    def __call__(self, *args):
+        self.is_method = False
+        self.watch = watch
+
+    def __call__(self, *args, **kwargs):
+        key = [args, frozenset(kwargs.items())]
+        if self.is_method:
+            key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
+        key = tuple(key)
+        if key in self.cache:
+            return self.cache[key]
         try:
-            return self.cache[args]
-        except KeyError:
-            value = self.func(*args)
-            self.cache[args] = value
+            value = self.func(*args, **kwargs)
+            self.cache[key] = value
             return value
         except TypeError:
             # uncachable -- for instance, passing a list as an argument.
             # Better to not cache than to blow up entirely.
-            return self.func(*args)
+            return self.func(*args, **kwargs)
 
     def __repr__(self):
         """Return the function's docstring."""
         return self.func.__doc__
 
     def __get__(self, obj, objtype):
+        if not self.is_method:
+            self.is_method = True
         """Support instance methods."""
         return functools.partial(self.__call__, obj)
 
 
+def memoized(func=None, watch=None):
+    if func:
+        return _memoized(func)
+    else:
+        def wrapper(f):
+            return _memoized(f, watch)
+        return wrapper
+
+
 def js_string_to_python(item):
     return None if item in ('null', 'undefined') else item
 
diff --git a/tests/utils_tests.py b/tests/utils_tests.py
index f6d1901d12..46d476632b 100644
--- a/tests/utils_tests.py
+++ b/tests/utils_tests.py
@@ -8,7 +8,7 @@
 
 from superset.utils import (
     base_json_conv, datetime_f, json_int_dttm_ser, json_iso_dttm_ser,
-    JSONEncodedDict, merge_extra_filters, parse_human_timedelta,
+    JSONEncodedDict, memoized, merge_extra_filters, parse_human_timedelta,
     SupersetException, validate_json, zlib_compress, zlib_decompress_to_string,
 )
 
@@ -219,3 +219,77 @@ def test_validate_json(self):
         invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
         with self.assertRaises(SupersetException):
             validate_json(invalid)
+
+    def test_memoized_on_functions(self):
+        watcher = {'val': 0}
+
+        @memoized
+        def test_function(a, b, c):
+            watcher['val'] += 1
+            return a * b * c
+        result1 = test_function(1, 2, 3)
+        result2 = test_function(1, 2, 3)
+        self.assertEquals(result1, result2)
+        self.assertEquals(watcher['val'], 1)
+
+    def test_memoized_on_methods(self):
+
+        class test_class:
+            def __init__(self, num):
+                self.num = num
+                self.watcher = 0
+
+            @memoized
+            def test_method(self, a, b, c):
+                self.watcher += 1
+                return a * b * c * self.num
+
+        instance = test_class(5)
+        result1 = instance.test_method(1, 2, 3)
+        result2 = instance.test_method(1, 2, 3)
+        self.assertEquals(result1, result2)
+        self.assertEquals(instance.watcher, 1)
+        instance.num = 10
+        self.assertEquals(result2, instance.test_method(1, 2, 3))
+
+    def test_memoized_on_methods_with_watches(self):
+
+        class test_class:
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+                self.watcher = 0
+
+            @memoized(watch=('x', 'y'))
+            def test_method(self, a, b, c):
+                self.watcher += 1
+                return a * b * c * self.x * self.y
+
+        instance = test_class(3, 12)
+        result1 = instance.test_method(1, 2, 3)
+        result2 = instance.test_method(1, 2, 3)
+        self.assertEquals(result1, result2)
+        self.assertEquals(instance.watcher, 1)
+        result3 = instance.test_method(2, 3, 4)
+        self.assertEquals(instance.watcher, 2)
+        result4 = instance.test_method(2, 3, 4)
+        self.assertEquals(instance.watcher, 2)
+        self.assertEquals(result3, result4)
+        self.assertNotEqual(result3, result1)
+        instance.x = 1
+        result5 = instance.test_method(2, 3, 4)
+        self.assertEqual(instance.watcher, 3)
+        self.assertNotEqual(result5, result4)
+        result6 = instance.test_method(2, 3, 4)
+        self.assertEqual(instance.watcher, 3)
+        self.assertEqual(result6, result5)
+        instance.x = 10
+        instance.y = 10
+        result7 = instance.test_method(2, 3, 4)
+        self.assertEqual(instance.watcher, 4)
+        self.assertNotEqual(result7, result6)
+        instance.x = 3
+        instance.y = 12
+        result8 = instance.test_method(1, 2, 3)
+        self.assertEqual(instance.watcher, 4)
+        self.assertEqual(result1, result8)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to