This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new bece2ea3e4 chore: Remove unnecessary autoflush from tagging and key/value workflows (#26009)
bece2ea3e4 is described below

commit bece2ea3e4b9979f3f45f63aa490f499095c7078
Author: John Bodley <[email protected]>
AuthorDate: Thu Nov 16 19:04:04 2023 -0800

    chore: Remove unnecessary autoflush from tagging and key/value workflows (#26009)
---
 superset/key_value/commands/delete.py              |  8 +----
 superset/key_value/commands/get.py                 |  7 +---
 superset/key_value/commands/update.py              |  5 +--
 superset/key_value/commands/upsert.py              |  5 +--
 superset/tags/models.py                            | 38 ++++++----------------
 .../key_value/commands/create_test.py              | 12 ++-----
 .../key_value/commands/update_test.py              |  6 ++--
 .../key_value/commands/upsert_test.py              |  8 ++---
 8 files changed, 21 insertions(+), 68 deletions(-)

diff --git a/superset/key_value/commands/delete.py b/superset/key_value/commands/delete.py
index b3cf84be07..8b9095c09c 100644
--- a/superset/key_value/commands/delete.py
+++ b/superset/key_value/commands/delete.py
@@ -57,13 +57,7 @@ class DeleteKeyValueCommand(BaseCommand):
 
     def delete(self) -> bool:
         filter_ = get_filter(self.resource, self.key)
-        entry = (
-            db.session.query(KeyValueEntry)
-            .filter_by(**filter_)
-            .autoflush(False)
-            .first()
-        )
-        if entry:
+        if entry := 
db.session.query(KeyValueEntry).filter_by(**filter_).first():
             db.session.delete(entry)
             db.session.commit()
             return True
diff --git a/superset/key_value/commands/get.py b/superset/key_value/commands/get.py
index 9d659f3bc7..8a7a250f1c 100644
--- a/superset/key_value/commands/get.py
+++ b/superset/key_value/commands/get.py
@@ -66,12 +66,7 @@ class GetKeyValueCommand(BaseCommand):
 
     def get(self) -> Optional[Any]:
         filter_ = get_filter(self.resource, self.key)
-        entry = (
-            db.session.query(KeyValueEntry)
-            .filter_by(**filter_)
-            .autoflush(False)
-            .first()
-        )
+        entry = db.session.query(KeyValueEntry).filter_by(**filter_).first()
         if entry and (entry.expires_on is None or entry.expires_on > 
datetime.now()):
             return self.codec.decode(entry.value)
         return None
diff --git a/superset/key_value/commands/update.py b/superset/key_value/commands/update.py
index becd6d9ca8..4bcd496243 100644
--- a/superset/key_value/commands/update.py
+++ b/superset/key_value/commands/update.py
@@ -77,10 +77,7 @@ class UpdateKeyValueCommand(BaseCommand):
     def update(self) -> Optional[Key]:
         filter_ = get_filter(self.resource, self.key)
         entry: KeyValueEntry = (
-            db.session.query(KeyValueEntry)
-            .filter_by(**filter_)
-            .autoflush(False)
-            .first()
+            db.session.query(KeyValueEntry).filter_by(**filter_).first()
         )
         if entry:
             entry.value = self.codec.encode(self.value)
diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py
index c5668f1161..9a4092c002 100644
--- a/superset/key_value/commands/upsert.py
+++ b/superset/key_value/commands/upsert.py
@@ -81,10 +81,7 @@ class UpsertKeyValueCommand(BaseCommand):
     def upsert(self) -> Key:
         filter_ = get_filter(self.resource, self.key)
         entry: KeyValueEntry = (
-            db.session.query(KeyValueEntry)
-            .filter_by(**filter_)
-            .autoflush(False)
-            .first()
+            db.session.query(KeyValueEntry).filter_by(**filter_).first()
         )
         if entry:
             entry.value = self.codec.encode(self.value)
diff --git a/superset/tags/models.py b/superset/tags/models.py
index a469c7a33d..7a77677a36 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -20,9 +20,9 @@ import enum
 from typing import TYPE_CHECKING
 
 from flask_appbuilder import Model
-from sqlalchemy import Column, Enum, ForeignKey, Integer, String, Table, Text
+from sqlalchemy import Column, Enum, ForeignKey, Integer, orm, String, Table, 
Text
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship, Session, sessionmaker
+from sqlalchemy.orm import relationship, sessionmaker
 from sqlalchemy.orm.mapper import Mapper
 
 from superset import security_manager
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from superset.models.slice import Slice
     from superset.models.sql_lab import Query
 
-Session = sessionmaker(autoflush=False)
+Session = sessionmaker()
 
 user_favorite_tag_table = Table(
     "user_favorite_tag",
@@ -111,7 +111,7 @@ class TaggedObject(Model, AuditMixinNullable):
     tag = relationship("Tag", back_populates="objects", overlaps="tags")
 
 
-def get_tag(name: str, session: Session, type_: TagType) -> Tag:
+def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag:
     tag_name = name.strip()
     tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none()
     if tag is None:
@@ -148,7 +148,7 @@ class ObjectUpdater:
     @classmethod
     def _add_owners(
         cls,
-        session: Session,
+        session: orm.Session,
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
         for owner_id in cls.get_owners_ids(target):
@@ -166,9 +166,7 @@ class ObjectUpdater:
         connection: Connection,
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
-        session = Session(bind=connection)
-
-        try:
+        with Session(bind=connection) as session:
             # add `owner:` tags
             cls._add_owners(session, target)
 
@@ -179,8 +177,6 @@ class ObjectUpdater:
             )
             session.add(tagged_object)
             session.commit()
-        finally:
-            session.close()
 
     @classmethod
     def after_update(
@@ -189,9 +185,7 @@ class ObjectUpdater:
         connection: Connection,
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
-        session = Session(bind=connection)
-
-        try:
+        with Session(bind=connection) as session:
             # delete current `owner:` tags
             query = (
                 session.query(TaggedObject.id)
@@ -210,8 +204,6 @@ class ObjectUpdater:
             # add `owner:` tags
             cls._add_owners(session, target)
             session.commit()
-        finally:
-            session.close()
 
     @classmethod
     def after_delete(
@@ -220,9 +212,7 @@ class ObjectUpdater:
         connection: Connection,
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
-        session = Session(bind=connection)
-
-        try:
+        with Session(bind=connection) as session:
             # delete row from `tagged_objects`
             session.query(TaggedObject).filter(
                 TaggedObject.object_type == cls.object_type,
@@ -230,8 +220,6 @@ class ObjectUpdater:
             ).delete()
 
             session.commit()
-        finally:
-            session.close()
 
 
 class ChartUpdater(ObjectUpdater):
@@ -271,8 +259,7 @@ class FavStarUpdater:
     def after_insert(
         cls, _mapper: Mapper, connection: Connection, target: FavStar
     ) -> None:
-        session = Session(bind=connection)
-        try:
+        with Session(bind=connection) as session:
             name = f"favorited_by:{target.user_id}"
             tag = get_tag(name, session, TagType.favorited_by)
             tagged_object = TaggedObject(
@@ -282,15 +269,12 @@ class FavStarUpdater:
             )
             session.add(tagged_object)
             session.commit()
-        finally:
-            session.close()
 
     @classmethod
     def after_delete(
         cls, _mapper: Mapper, connection: Connection, target: FavStar
     ) -> None:
-        session = Session(bind=connection)
-        try:
+        with Session(bind=connection) as session:
             name = f"favorited_by:{target.user_id}"
             query = (
                 session.query(TaggedObject.id)
@@ -307,5 +291,3 @@ class FavStarUpdater:
             )
 
             session.commit()
-        finally:
-            session.close()
diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py
index a2ee3d13ae..c7ba076b5f 100644
--- a/tests/integration_tests/key_value/commands/create_test.py
+++ b/tests/integration_tests/key_value/commands/create_test.py
@@ -46,9 +46,7 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None:
             value=JSON_VALUE,
             codec=JSON_CODEC,
         ).run()
-        entry = (
-            
db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one()
-        )
+        entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
         assert json.loads(entry.value) == JSON_VALUE
         assert entry.created_by_fk == admin.id
         db.session.delete(entry)
@@ -63,9 +61,7 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None:
         key = CreateKeyValueCommand(
             resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC
         ).run()
-    entry = (
-        
db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one()
-    )
+    entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one()
     assert json.loads(entry.value) == JSON_VALUE
     assert entry.created_by_fk == admin.id
     db.session.delete(entry)
@@ -93,9 +89,7 @@ def test_create_pickle_entry(app_context: AppContext, admin: User) -> None:
             value=PICKLE_VALUE,
             codec=PICKLE_CODEC,
         ).run()
-        entry = (
-            
db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one()
-        )
+        entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
         assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE)
         assert entry.created_by_fk == admin.id
         db.session.delete(entry)
diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py
index 2c0fc3e31d..816a6f857a 100644
--- a/tests/integration_tests/key_value/commands/update_test.py
+++ b/tests/integration_tests/key_value/commands/update_test.py
@@ -57,7 +57,7 @@ def test_update_id_entry(
         ).run()
     assert key is not None
     assert key.id == ID_KEY
-    entry = 
db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one()
+    entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one()
     assert json.loads(entry.value) == NEW_VALUE
     assert entry.changed_by_fk == admin.id
 
@@ -79,9 +79,7 @@ def test_update_uuid_entry(
         ).run()
     assert key is not None
     assert key.uuid == UUID_KEY
-    entry = (
-        
db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
-    )
+    entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
     assert json.loads(entry.value) == NEW_VALUE
     assert entry.changed_by_fk == admin.id
 
diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py
index c26b66d02e..9b094ef65e 100644
--- a/tests/integration_tests/key_value/commands/upsert_test.py
+++ b/tests/integration_tests/key_value/commands/upsert_test.py
@@ -57,9 +57,7 @@ def test_upsert_id_entry(
         ).run()
     assert key is not None
     assert key.id == ID_KEY
-    entry = (
-        
db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one()
-    )
+    entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one()
     assert json.loads(entry.value) == NEW_VALUE
     assert entry.changed_by_fk == admin.id
 
@@ -81,9 +79,7 @@ def test_upsert_uuid_entry(
         ).run()
     assert key is not None
     assert key.uuid == UUID_KEY
-    entry = (
-        
db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
-    )
+    entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
     assert json.loads(entry.value) == NEW_VALUE
     assert entry.changed_by_fk == admin.id
 

Reply via email to