This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 550e78f  feat: Add multiple table filters for Row Level Security (#9751)
550e78f is described below

commit 550e78ff7c02491717960d39ef0bb861aba0f977
Author: Aliaksei Kushniarevich <axel...@gmail.com>
AuthorDate: Mon Jun 22 12:51:08 2020 +0300

    feat: Add multiple table filters for Row Level Security (#9751)
    
    * Add multiple table filters for Row Level Security
    
    * Set ENABLE_ROW_LEVEL_SECURITY back to False (default)
    
    * Merge DB migrations
    
    * Drop table_id column and foreign key on PostgreSQL, MySQL, SQLite
    
    * Support db records migration also
    
    * Support downgrading from the new-fashioned formatted records
    
    * Straighten up migrations
    
    * Update migration's down_revision to comply master branch
---
 superset/connectors/sqla/models.py                 |  13 ++-
 superset/connectors/sqla/views.py                  |  12 +-
 ...57699a813e_add_tables_relation_to_row_level_.py | 124 +++++++++++++++++++++
 superset/security/manager.py                       |   8 +-
 tests/security_tests.py                            |  39 +++++--
 5 files changed, 178 insertions(+), 18 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 1f1fe29..3e95d63 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1332,6 +1332,14 @@ RLSFilterRoles = Table(
     Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
 )
 
+RLSFilterTables = Table(
+    "rls_filter_tables",
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column("table_id", Integer, ForeignKey("tables.id")),
+    Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
+)
+
 
 class RowLevelSecurityFilter(Model, AuditMixinNullable):
     """
@@ -1345,7 +1353,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
         secondary=RLSFilterRoles,
         backref="row_level_security_filters",
     )
+    tables = relationship(
+        SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
+    )
 
-    table_id = Column(Integer, ForeignKey("tables.id"), nullable=False)
-    table = relationship(SqlaTable, backref="row_level_security_filters")
     clause = Column(Text, nullable=False)
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py
index 2a612a9..39a8984 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -236,15 +236,15 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
     add_title = _("Add Row level security filter")
     edit_title = _("Edit Row level security filter")
 
-    list_columns = ["table.table_name", "roles", "clause", "creator", "modified"]
-    order_columns = ["table.table_name", "clause", "modified"]
-    edit_columns = ["table", "roles", "clause"]
+    list_columns = ["tables", "roles", "clause", "creator", "modified"]
+    order_columns = ["tables", "clause", "modified"]
+    edit_columns = ["tables", "roles", "clause"]
     show_columns = edit_columns
-    search_columns = ("table", "roles", "clause")
+    search_columns = ("tables", "roles", "clause")
     add_columns = edit_columns
     base_order = ("changed_on", "desc")
     description_columns = {
-        "table": _("This is the table this filter will be applied to."),
+        "tables": _("These are the tables this filter will be applied to."),
         "roles": _("These are the roles this filter will be applied to."),
         "clause": _(
             "This is the condition that will be added to the WHERE clause. "
@@ -252,7 +252,7 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
         ),
     }
     label_columns = {
-        "table": _("Table"),
+        "tables": _("Tables"),
         "roles": _("Roles"),
         "clause": _("Clause"),
         "creator": _("Creator"),
diff --git a/superset/migrations/versions/e557699a813e_add_tables_relation_to_row_level_.py b/superset/migrations/versions/e557699a813e_add_tables_relation_to_row_level_.py
new file mode 100644
index 0000000..1ed3337
--- /dev/null
+++ b/superset/migrations/versions/e557699a813e_add_tables_relation_to_row_level_.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""add_tables_relation_to_row_level_security
+
+Revision ID: e557699a813e
+Revises: 743a117f0d98
+Create Date: 2020-04-24 10:46:24.119363
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = "e557699a813e"
+down_revision = "743a117f0d98"
+
+import sqlalchemy as sa
+from alembic import op
+
+from superset.utils.core import generic_find_fk_constraint_name
+
+
+def upgrade():
+    bind = op.get_bind()
+    metadata = sa.MetaData(bind=bind)
+    insp = sa.engine.reflection.Inspector.from_engine(bind)
+
+    rls_filter_tables = op.create_table(
+        "rls_filter_tables",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("table_id", sa.Integer(), nullable=True),
+        sa.Column("rls_filter_id", sa.Integer(), nullable=True),
+        sa.ForeignKeyConstraint(["rls_filter_id"], ["row_level_security_filters.id"]),
+        sa.ForeignKeyConstraint(["table_id"], ["tables.id"]),
+        sa.PrimaryKeyConstraint("id"),
+    )
+
+    rlsf = sa.Table("row_level_security_filters", metadata, autoload=True)
+    filter_ids = sa.select([rlsf.c.id, rlsf.c.table_id])
+
+    for row in bind.execute(filter_ids):
+        move_table_id = rls_filter_tables.insert().values(
+            rls_filter_id=row["id"], table_id=row["table_id"]
+        )
+        bind.execute(move_table_id)
+
+    with op.batch_alter_table("row_level_security_filters") as batch_op:
+        fk_constraint_name = generic_find_fk_constraint_name(
+            "row_level_security_filters", {"id"}, "tables", insp
+        )
+        if fk_constraint_name:
+            batch_op.drop_constraint(fk_constraint_name, type_="foreignkey")
+        batch_op.drop_column("table_id")
+
+
+def downgrade():
+    bind = op.get_bind()
+    metadata = sa.MetaData(bind=bind)
+
+    op.add_column(
+        "row_level_security_filters",
+        sa.Column(
+            "table_id",
+            sa.INTEGER(),
+            sa.ForeignKey("tables.id"),
+            autoincrement=False,
+            nullable=True,
+        ),
+    )
+
+    rlsf = sa.Table("row_level_security_filters", metadata, autoload=True)
+    rls_filter_tables = sa.Table("rls_filter_tables", metadata, autoload=True)
+    rls_filter_roles = sa.Table("rls_filter_roles", metadata, autoload=True)
+
+    filter_tables = sa.select([rls_filter_tables.c.rls_filter_id]).group_by(
+        rls_filter_tables.c.rls_filter_id
+    )
+
+    for row in bind.execute(filter_tables):
+        filters_copy_ids = []
+        filter_query = rlsf.select().where(rlsf.c.id == row["rls_filter_id"])
+        filter_params = dict(bind.execute(filter_query).fetchone())
+        origin_id = filter_params.pop("id", None)
+        table_ids = bind.execute(
+            sa.select([rls_filter_tables.c.table_id]).where(
+                rls_filter_tables.c.rls_filter_id == row["rls_filter_id"]
+            )
+        ).fetchall()
+        filter_params["table_id"] = table_ids.pop(0)[0]
+        move_table_id = (
+            rlsf.update().where(rlsf.c.id == origin_id).values(filter_params)
+        )
+        bind.execute(move_table_id)
+        for table_id in table_ids:
+            filter_params["table_id"] = table_id[0]
+            copy_filter = rlsf.insert().values(filter_params)
+            copy_id = bind.execute(copy_filter).inserted_primary_key[0]
+            filters_copy_ids.append(copy_id)
+
+        roles_query = rls_filter_roles.select().where(
+            rls_filter_roles.c.rls_filter_id == origin_id
+        )
+        for role in bind.execute(roles_query):
+            for copy_id in filters_copy_ids:
+                role_filter = rls_filter_roles.insert().values(
+                    role_id=role["role_id"], rls_filter_id=copy_id
+                )
+                bind.execute(role_filter)
+        filters_copy_ids.clear()
+
+    op.alter_column("row_level_security_filters", "table_id", nullable=False)
+    op.drop_table("rls_filter_tables")
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 0107115..88d7493 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -904,6 +904,7 @@ class SupersetSecurityManager(SecurityManager):
             from superset import db
             from superset.connectors.sqla.models import (
                 RLSFilterRoles,
+                RLSFilterTables,
                 RowLevelSecurityFilter,
             )
 
@@ -917,11 +918,16 @@ class SupersetSecurityManager(SecurityManager):
                 .filter(RLSFilterRoles.c.role_id.in_(user_roles))
                 .subquery()
             )
+            filter_tables = (
+                db.session.query(RLSFilterTables.c.rls_filter_id)
+                .filter(RLSFilterTables.c.table_id == table.id)
+                .subquery()
+            )
             query = (
                 db.session.query(
                     RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause
                 )
-                .filter(RowLevelSecurityFilter.table_id == table.id)
+                .filter(RowLevelSecurityFilter.id.in_(filter_tables))
                 .filter(RowLevelSecurityFilter.id.in_(filter_roles))
             )
             return query.all()
diff --git a/tests/security_tests.py b/tests/security_tests.py
index 92fccab..9426036 100644
--- a/tests/security_tests.py
+++ b/tests/security_tests.py
@@ -830,10 +830,12 @@ class RowLevelSecurityTests(SupersetTestCase):
 
         # Create the RowLevelSecurityFilter
         self.rls_entry = RowLevelSecurityFilter()
-        self.rls_entry.table = (
-            session.query(SqlaTable).filter_by(table_name="birth_names").first()
+        self.rls_entry.tables.extend(
+            session.query(SqlaTable)
+            .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
+            .all()
         )
-        self.rls_entry.clause = "gender = 'boy'"
+        self.rls_entry.clause = "value > 1"
         self.rls_entry.roles.append(
             security_manager.find_role("Gamma")
         )  # db.session.query(Role).filter_by(name="Gamma").first())
@@ -852,36 +854,55 @@ class RowLevelSecurityTests(SupersetTestCase):
         g.user = self.get_user(
             username="alpha"
         )  # self.login() doesn't actually set the user
-        tbl = self.get_table_by_name("birth_names")
+        tbl = self.get_table_by_name("energy_usage")
         query_obj = dict(
             groupby=[],
             metrics=[],
             filter=[],
             is_timeseries=False,
-            columns=["name"],
+            columns=["value"],
             granularity=None,
             from_dttm=None,
             to_dttm=None,
             extras={},
         )
         sql = tbl.get_query_str(query_obj)
-        self.assertIn("gender = 'boy'", sql)
+        self.assertIn("value > 1", sql)
 
     def test_rls_filter_doesnt_alter_query(self):
         g.user = self.get_user(
             username="admin"
         )  # self.login() doesn't actually set the user
-        tbl = self.get_table_by_name("birth_names")
+        tbl = self.get_table_by_name("energy_usage")
+        query_obj = dict(
+            groupby=[],
+            metrics=[],
+            filter=[],
+            is_timeseries=False,
+            columns=["value"],
+            granularity=None,
+            from_dttm=None,
+            to_dttm=None,
+            extras={},
+        )
+        sql = tbl.get_query_str(query_obj)
+        self.assertNotIn("value > 1", sql)
+
+    def test_multiple_table_filter_alters_another_tables_query(self):
+        g.user = self.get_user(
+            username="alpha"
+        )  # self.login() doesn't actually set the user
+        tbl = self.get_table_by_name("unicode_test")
         query_obj = dict(
             groupby=[],
             metrics=[],
             filter=[],
             is_timeseries=False,
-            columns=["name"],
+            columns=["value"],
             granularity=None,
             from_dttm=None,
             to_dttm=None,
             extras={},
         )
         sql = tbl.get_query_str(query_obj)
-        self.assertNotIn("gender = 'boy'", sql)
+        self.assertIn("value > 1", sql)

Reply via email to