This is an automated email from the ASF dual-hosted git repository. villebro pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push: new 550e78f feat: Add multiple table filters for Row Level Security (#9751) 550e78f is described below commit 550e78ff7c02491717960d39ef0bb861aba0f977 Author: Aliaksei Kushniarevich <axel...@gmail.com> AuthorDate: Mon Jun 22 12:51:08 2020 +0300 feat: Add multiple table filters for Row Level Security (#9751) * Add multiple table filters for Row Level Security * Set ENABLE_ROW_LEVEL_SECURITY back to False (default) * Merge DB migrations * Drop table_id column and foreign key on PostgreSQL, MySQL, SQLite * Support db records migration also * Support downgrading from the new-fashioned formatted records * Straighten up migrations * Update migration's down_revision to comply master branch --- superset/connectors/sqla/models.py | 13 ++- superset/connectors/sqla/views.py | 12 +- ...57699a813e_add_tables_relation_to_row_level_.py | 124 +++++++++++++++++++++ superset/security/manager.py | 8 +- tests/security_tests.py | 39 +++++-- 5 files changed, 178 insertions(+), 18 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 1f1fe29..3e95d63 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1332,6 +1332,14 @@ RLSFilterRoles = Table( Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")), ) +RLSFilterTables = Table( + "rls_filter_tables", + metadata, + Column("id", Integer, primary_key=True), + Column("table_id", Integer, ForeignKey("tables.id")), + Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")), +) + class RowLevelSecurityFilter(Model, AuditMixinNullable): """ @@ -1345,7 +1353,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): secondary=RLSFilterRoles, backref="row_level_security_filters", ) + tables = relationship( + SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters" + ) - table_id = Column(Integer, ForeignKey("tables.id"), nullable=False) - table = relationship(SqlaTable, backref="row_level_security_filters") clause = Column(Text, nullable=False) diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 2a612a9..39a8984 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -236,15 +236,15 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin): add_title = _("Add Row level security filter") edit_title = _("Edit Row level security filter") - list_columns = ["table.table_name", "roles", "clause", "creator", "modified"] - order_columns = ["table.table_name", "clause", "modified"] - edit_columns = ["table", "roles", "clause"] + list_columns = ["tables", "roles", "clause", "creator", "modified"] + order_columns = ["tables", "clause", "modified"] + edit_columns = ["tables", "roles", "clause"] show_columns = edit_columns - search_columns = ("table", "roles", "clause") + search_columns = ("tables", "roles", "clause") add_columns = edit_columns base_order = ("changed_on", "desc") description_columns = { - "table": _("This is the table this filter will be applied to."), + "tables": _("These are the tables this filter will be applied to."), "roles": _("These are the roles this filter will be applied to."), "clause": _( "This is the condition that will be added to the WHERE clause. " @@ -252,7 +252,7 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin): ), } label_columns = { - "table": _("Table"), + "tables": _("Tables"), "roles": _("Roles"), "clause": _("Clause"), "creator": _("Creator"), diff --git a/superset/migrations/versions/e557699a813e_add_tables_relation_to_row_level_.py b/superset/migrations/versions/e557699a813e_add_tables_relation_to_row_level_.py new file mode 100644 index 0000000..1ed3337 --- /dev/null +++ b/superset/migrations/versions/e557699a813e_add_tables_relation_to_row_level_.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add_tables_relation_to_row_level_security + +Revision ID: e557699a813e +Revises: 743a117f0d98 +Create Date: 2020-04-24 10:46:24.119363 + +""" + +# revision identifiers, used by Alembic. +revision = "e557699a813e" +down_revision = "743a117f0d98" + +import sqlalchemy as sa +from alembic import op + +from superset.utils.core import generic_find_fk_constraint_name + + +def upgrade(): + bind = op.get_bind() + metadata = sa.MetaData(bind=bind) + insp = sa.engine.reflection.Inspector.from_engine(bind) + + rls_filter_tables = op.create_table( + "rls_filter_tables", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("table_id", sa.Integer(), nullable=True), + sa.Column("rls_filter_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["rls_filter_id"], ["row_level_security_filters.id"]), + sa.ForeignKeyConstraint(["table_id"], ["tables.id"]), + sa.PrimaryKeyConstraint("id"), + ) + + rlsf = sa.Table("row_level_security_filters", metadata, autoload=True) + filter_ids = sa.select([rlsf.c.id, rlsf.c.table_id]) + + for row in bind.execute(filter_ids): + move_table_id = rls_filter_tables.insert().values( + rls_filter_id=row["id"], table_id=row["table_id"] + ) + bind.execute(move_table_id) + + with op.batch_alter_table("row_level_security_filters") as batch_op: + fk_constraint_name = generic_find_fk_constraint_name( + "row_level_security_filters", {"id"}, "tables", insp + ) + if fk_constraint_name: + batch_op.drop_constraint(fk_constraint_name, type_="foreignkey") + batch_op.drop_column("table_id") + + +def downgrade(): + bind = op.get_bind() + metadata = sa.MetaData(bind=bind) + + op.add_column( + "row_level_security_filters", + sa.Column( + "table_id", + sa.INTEGER(), + sa.ForeignKey("tables.id"), + autoincrement=False, + nullable=True, + ), + ) + + rlsf = sa.Table("row_level_security_filters", metadata, autoload=True) + rls_filter_tables = sa.Table("rls_filter_tables", metadata, autoload=True) + rls_filter_roles = sa.Table("rls_filter_roles", metadata, autoload=True) + + filter_tables = sa.select([rls_filter_tables.c.rls_filter_id]).group_by( + rls_filter_tables.c.rls_filter_id + ) + + for row in bind.execute(filter_tables): + filters_copy_ids = [] + filter_query = rlsf.select().where(rlsf.c.id == row["rls_filter_id"]) + filter_params = dict(bind.execute(filter_query).fetchone()) + origin_id = filter_params.pop("id", None) + table_ids = bind.execute( + sa.select([rls_filter_tables.c.table_id]).where( + rls_filter_tables.c.rls_filter_id == row["rls_filter_id"] + ) + ).fetchall() + filter_params["table_id"] = table_ids.pop(0)[0] + move_table_id = ( + rlsf.update().where(rlsf.c.id == origin_id).values(filter_params) + ) + bind.execute(move_table_id) + for table_id in table_ids: + filter_params["table_id"] = table_id[0] + copy_filter = rlsf.insert().values(filter_params) + copy_id = bind.execute(copy_filter).inserted_primary_key[0] + filters_copy_ids.append(copy_id) + + roles_query = rls_filter_roles.select().where( + rls_filter_roles.c.rls_filter_id == origin_id + ) + for role in bind.execute(roles_query): + for copy_id in filters_copy_ids: + role_filter = rls_filter_roles.insert().values( + role_id=role["role_id"], rls_filter_id=copy_id + ) + bind.execute(role_filter) + filters_copy_ids.clear() + + op.alter_column("row_level_security_filters", "table_id", nullable=False) + op.drop_table("rls_filter_tables") diff --git a/superset/security/manager.py b/superset/security/manager.py index 0107115..88d7493 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -904,6 +904,7 @@ class SupersetSecurityManager(SecurityManager): from superset import db from superset.connectors.sqla.models import ( RLSFilterRoles, + RLSFilterTables, RowLevelSecurityFilter, ) @@ -917,11 +918,16 @@ class SupersetSecurityManager(SecurityManager): .filter(RLSFilterRoles.c.role_id.in_(user_roles)) .subquery() ) + filter_tables = ( + db.session.query(RLSFilterTables.c.rls_filter_id) + .filter(RLSFilterTables.c.table_id == table.id) + .subquery() + ) query = ( db.session.query( RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause ) - .filter(RowLevelSecurityFilter.table_id == table.id) + .filter(RowLevelSecurityFilter.id.in_(filter_tables)) .filter(RowLevelSecurityFilter.id.in_(filter_roles)) ) return query.all() diff --git a/tests/security_tests.py b/tests/security_tests.py index 92fccab..9426036 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -830,10 +830,12 @@ class RowLevelSecurityTests(SupersetTestCase): # Create the RowLevelSecurityFilter self.rls_entry = RowLevelSecurityFilter() - self.rls_entry.table = ( - session.query(SqlaTable).filter_by(table_name="birth_names").first() + self.rls_entry.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) + .all() ) - self.rls_entry.clause = "gender = 'boy'" + self.rls_entry.clause = "value > 1" self.rls_entry.roles.append( security_manager.find_role("Gamma") ) # db.session.query(Role).filter_by(name="Gamma").first()) @@ -852,36 +854,55 @@ class RowLevelSecurityTests(SupersetTestCase): g.user = self.get_user( username="alpha" ) # self.login() doesn't actually set the user - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table_by_name("energy_usage") query_obj = dict( groupby=[], metrics=[], filter=[], is_timeseries=False, - columns=["name"], + columns=["value"], granularity=None, from_dttm=None, to_dttm=None, extras={}, ) sql = tbl.get_query_str(query_obj) - self.assertIn("gender = 'boy'", sql) + self.assertIn("value > 1", sql) def test_rls_filter_doesnt_alter_query(self): g.user = self.get_user( username="admin" ) # self.login() doesn't actually set the user - tbl = self.get_table_by_name("birth_names") + tbl = self.get_table_by_name("energy_usage") + query_obj = dict( + groupby=[], + metrics=[], + filter=[], + is_timeseries=False, + columns=["value"], + granularity=None, + from_dttm=None, + to_dttm=None, + extras={}, + ) + sql = tbl.get_query_str(query_obj) + self.assertNotIn("value > 1", sql) + + def test_multiple_table_filter_alters_another_tables_query(self): + g.user = self.get_user( + username="alpha" + ) # self.login() doesn't actually set the user + tbl = self.get_table_by_name("unicode_test") query_obj = dict( groupby=[], metrics=[], filter=[], is_timeseries=False, - columns=["name"], + columns=["value"], granularity=None, from_dttm=None, to_dttm=None, extras={}, ) sql = tbl.get_query_str(query_obj) - self.assertNotIn("gender = 'boy'", sql) + self.assertIn("value > 1", sql)