This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new bece2ea3e4 chore: Remove unnecessary autoflush from tagging and
key/value workflows (#26009)
bece2ea3e4 is described below
commit bece2ea3e4b9979f3f45f63aa490f499095c7078
Author: John Bodley <[email protected]>
AuthorDate: Thu Nov 16 19:04:04 2023 -0800
chore: Remove unnecessary autoflush from tagging and key/value workflows
(#26009)
---
superset/key_value/commands/delete.py | 8 +----
superset/key_value/commands/get.py | 7 +---
superset/key_value/commands/update.py | 5 +--
superset/key_value/commands/upsert.py | 5 +--
superset/tags/models.py | 38 ++++++----------------
.../key_value/commands/create_test.py | 12 ++-----
.../key_value/commands/update_test.py | 6 ++--
.../key_value/commands/upsert_test.py | 8 ++---
8 files changed, 21 insertions(+), 68 deletions(-)
diff --git a/superset/key_value/commands/delete.py b/superset/key_value/commands/delete.py
index b3cf84be07..8b9095c09c 100644
--- a/superset/key_value/commands/delete.py
+++ b/superset/key_value/commands/delete.py
@@ -57,13 +57,7 @@ class DeleteKeyValueCommand(BaseCommand):
def delete(self) -> bool:
filter_ = get_filter(self.resource, self.key)
- entry = (
- db.session.query(KeyValueEntry)
- .filter_by(**filter_)
- .autoflush(False)
- .first()
- )
- if entry:
+ if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
db.session.delete(entry)
db.session.commit()
return True
diff --git a/superset/key_value/commands/get.py b/superset/key_value/commands/get.py
index 9d659f3bc7..8a7a250f1c 100644
--- a/superset/key_value/commands/get.py
+++ b/superset/key_value/commands/get.py
@@ -66,12 +66,7 @@ class GetKeyValueCommand(BaseCommand):
def get(self) -> Optional[Any]:
filter_ = get_filter(self.resource, self.key)
- entry = (
- db.session.query(KeyValueEntry)
- .filter_by(**filter_)
- .autoflush(False)
- .first()
- )
+ entry = db.session.query(KeyValueEntry).filter_by(**filter_).first()
if entry and (entry.expires_on is None or entry.expires_on > datetime.now()):
return self.codec.decode(entry.value)
return None
diff --git a/superset/key_value/commands/update.py b/superset/key_value/commands/update.py
index becd6d9ca8..4bcd496243 100644
--- a/superset/key_value/commands/update.py
+++ b/superset/key_value/commands/update.py
@@ -77,10 +77,7 @@ class UpdateKeyValueCommand(BaseCommand):
def update(self) -> Optional[Key]:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
- db.session.query(KeyValueEntry)
- .filter_by(**filter_)
- .autoflush(False)
- .first()
+ db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
entry.value = self.codec.encode(self.value)
diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py
index c5668f1161..9a4092c002 100644
--- a/superset/key_value/commands/upsert.py
+++ b/superset/key_value/commands/upsert.py
@@ -81,10 +81,7 @@ class UpsertKeyValueCommand(BaseCommand):
def upsert(self) -> Key:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
- db.session.query(KeyValueEntry)
- .filter_by(**filter_)
- .autoflush(False)
- .first()
+ db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
entry.value = self.codec.encode(self.value)
diff --git a/superset/tags/models.py b/superset/tags/models.py
index a469c7a33d..7a77677a36 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -20,9 +20,9 @@ import enum
from typing import TYPE_CHECKING
from flask_appbuilder import Model
-from sqlalchemy import Column, Enum, ForeignKey, Integer, String, Table, Text
+from sqlalchemy import Column, Enum, ForeignKey, Integer, orm, String, Table, Text
from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship, Session, sessionmaker
+from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.orm.mapper import Mapper
from superset import security_manager
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from superset.models.slice import Slice
from superset.models.sql_lab import Query
-Session = sessionmaker(autoflush=False)
+Session = sessionmaker()
user_favorite_tag_table = Table(
"user_favorite_tag",
@@ -111,7 +111,7 @@ class TaggedObject(Model, AuditMixinNullable):
tag = relationship("Tag", back_populates="objects", overlaps="tags")
-def get_tag(name: str, session: Session, type_: TagType) -> Tag:
+def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag:
tag_name = name.strip()
tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none()
if tag is None:
@@ -148,7 +148,7 @@ class ObjectUpdater:
@classmethod
def _add_owners(
cls,
- session: Session,
+ session: orm.Session,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
for owner_id in cls.get_owners_ids(target):
@@ -166,9 +166,7 @@ class ObjectUpdater:
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
- session = Session(bind=connection)
-
- try:
+ with Session(bind=connection) as session:
# add `owner:` tags
cls._add_owners(session, target)
@@ -179,8 +177,6 @@ class ObjectUpdater:
)
session.add(tagged_object)
session.commit()
- finally:
- session.close()
@classmethod
def after_update(
@@ -189,9 +185,7 @@ class ObjectUpdater:
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
- session = Session(bind=connection)
-
- try:
+ with Session(bind=connection) as session:
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
@@ -210,8 +204,6 @@ class ObjectUpdater:
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
- finally:
- session.close()
@classmethod
def after_delete(
@@ -220,9 +212,7 @@ class ObjectUpdater:
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
- session = Session(bind=connection)
-
- try:
+ with Session(bind=connection) as session:
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
@@ -230,8 +220,6 @@ class ObjectUpdater:
).delete()
session.commit()
- finally:
- session.close()
class ChartUpdater(ObjectUpdater):
@@ -271,8 +259,7 @@ class FavStarUpdater:
def after_insert(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
- session = Session(bind=connection)
- try:
+ with Session(bind=connection) as session:
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagType.favorited_by)
tagged_object = TaggedObject(
@@ -282,15 +269,12 @@ class FavStarUpdater:
)
session.add(tagged_object)
session.commit()
- finally:
- session.close()
@classmethod
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
- session = Session(bind=connection)
- try:
+ with Session(bind=connection) as session:
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
@@ -307,5 +291,3 @@ class FavStarUpdater:
)
session.commit()
- finally:
- session.close()
diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py
index a2ee3d13ae..c7ba076b5f 100644
--- a/tests/integration_tests/key_value/commands/create_test.py
+++ b/tests/integration_tests/key_value/commands/create_test.py
@@ -46,9 +46,7 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None:
value=JSON_VALUE,
codec=JSON_CODEC,
).run()
- entry = (
- db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one()
- )
+ entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
assert json.loads(entry.value) == JSON_VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
@@ -63,9 +61,7 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None:
key = CreateKeyValueCommand(
resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC
).run()
- entry = (
- db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one()
- )
+ entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one()
assert json.loads(entry.value) == JSON_VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
@@ -93,9 +89,7 @@ def test_create_pickle_entry(app_context: AppContext, admin: User) -> None:
value=PICKLE_VALUE,
codec=PICKLE_CODEC,
).run()
- entry = (
- db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one()
- )
+ entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE)
assert entry.created_by_fk == admin.id
db.session.delete(entry)
diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py
index 2c0fc3e31d..816a6f857a 100644
--- a/tests/integration_tests/key_value/commands/update_test.py
+++ b/tests/integration_tests/key_value/commands/update_test.py
@@ -57,7 +57,7 @@ def test_update_id_entry(
).run()
assert key is not None
assert key.id == ID_KEY
- entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one()
+ entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
@@ -79,9 +79,7 @@ def test_update_uuid_entry(
).run()
assert key is not None
assert key.uuid == UUID_KEY
- entry = (
- db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
- )
+ entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py
index c26b66d02e..9b094ef65e 100644
--- a/tests/integration_tests/key_value/commands/upsert_test.py
+++ b/tests/integration_tests/key_value/commands/upsert_test.py
@@ -57,9 +57,7 @@ def test_upsert_id_entry(
).run()
assert key is not None
assert key.id == ID_KEY
- entry = (
- db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one()
- )
+ entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
@@ -81,9 +79,7 @@ def test_upsert_uuid_entry(
).run()
assert key is not None
assert key.uuid == UUID_KEY
- entry = (
- db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
- )
+ entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id