betodealmeida commented on a change in pull request #13561:
URL: https://github.com/apache/superset/pull/13561#discussion_r613634931
##########
File path: superset/utils/data.py
##########
@@ -161,5 +227,145 @@ def generate_data(columns: List[ColumnInfo], num_rows:
int) -> List[Dict[str, An
def generate_column_data(column: ColumnInfo, num_rows: int) -> List[Any]:
    """
    Generate ``num_rows`` values for a single column.

    The generator is looked up once from the column's ``"type"`` entry and
    then called once per row.
    """
    make_value = get_type_generator(column["type"])
    values: List[Any] = []
    for _ in range(num_rows):
        values.append(make_value())
    return values
+
+
def add_to_model(session: Session, model: Type[Model], count: int) -> List[Model]:
    """
    Add entities of a given model.

    Primary keys continue from the current maximum found for each key
    column, foreign keys are copied from existing rows (or looked up in the
    referenced table when no sample row is available), ``datasource_type``
    is forced to ``"table"`` and every other column gets a generated value.

    :param Session session: session used to query existing rows; the new
        entities are added to it with ``add_all`` (no commit is issued here)
    :param Model model: a Superset/FAB model
    :param int count: how many entities to generate and insert
    :returns: the newly created entities, in creation order
    """
    inspector = inspect(model)
    columns = list(inspector.columns.values())

    # select samples to copy relationship values
    relationships = inspector.relationships.items()
    samples = session.query(model).limit(count).all() if relationships else []

    # last value handed out for each primary key column; one counter per
    # column keeps the parts of a composite key from sharing a sequence that
    # was seeded from only the first of them
    last_primary_key: Dict[str, int] = {}

    entities: List[Model] = []
    for i in range(count):
        sample = samples[i % len(samples)] if samples else None
        kwargs: Dict[str, Any] = {}
        for column in columns:
            # for primary keys, keep incrementing
            if column.primary_key:
                if column.name not in last_primary_key:
                    # an empty table yields None, so start counting from 0
                    last_primary_key[column.name] = (
                        session.query(func.max(getattr(model, column.name))).scalar()
                        or 0
                    )
                last_primary_key[column.name] += 1
                kwargs[column.name] = last_primary_key[column.name]

            # if the column has a foreign key, copy the value from an existing
            # entity
            elif column.foreign_keys:
                if sample:
                    kwargs[column.name] = getattr(sample, column.name)
                else:
                    kwargs[column.name] = get_valid_foreign_key(column)

            # should be an enum but it's not
            elif column.name == "datasource_type":
                kwargs[column.name] = "table"

            # otherwise, generate a random value based on the type
            else:
                kwargs[column.name] = generate_value(column)

        entities.append(model(**kwargs))

    session.add_all(entities)
    return entities
+
+
def get_valid_foreign_key(column: Column) -> Any:
    """
    Fetch an existing value from the column that a foreign key points to.

    Only the first foreign key of ``column`` is considered.

    :param Column column: a column with at least one foreign key
    :returns: the scalar result of selecting one value of the referenced
        column from the referenced table
    """
    foreign_key = list(column.foreign_keys)[0]
    # split on the last dot only, so that a schema-qualified target such as
    # "schema.table.column" keeps "schema.table" as the table part
    table_name, column_name = foreign_key.target_fullname.rsplit(".", 1)
    # NOTE(review): the identifiers are interpolated straight into the SQL;
    # they come from the foreign key definition, not from user input
    return db.engine.execute(f"SELECT {column_name} FROM {table_name} LIMIT 1").scalar()
+
+
def generate_value(column: Column) -> Any:
    """
    Generate a random value suitable for ``column``.

    Types exposing ``enums`` get one of their members.  A ``Text`` column
    whose name contains "json" (case-insensitive) gets a value produced by
    the JSON generator, serialized to a string.  Everything else gets a value
    from the generator registered for the column's own type.
    """
    if hasattr(column.type, "enums"):
        return random.choice(column.type.enums)

    stores_json_as_text = "json" in column.name.lower() and isinstance(
        column.type, sqlalchemy.sql.sqltypes.Text
    )
    if not stores_json_as_text:
        return get_type_generator(column.type)()

    # the column is plain text, so the generated JSON must be stored encoded
    payload = get_type_generator(sqlalchemy.sql.sqltypes.JSON())()
    return json.dumps(payload)
+
+
def find_models(module: ModuleType) -> List[Type[Model]]:
    """
    Find all models in a migration script.

    Table names are gathered from ``extract_modified_tables`` and from every
    object in the module namespace (including objects nested inside lists
    and dicts) that has a ``__tablename__``.  Each registered model whose
    table is among those names is returned.

    :param ModuleType module: an imported migration script
    :returns: the matching models, sorted so that a model comes after the
        models its foreign keys point to
    """
    models: List[Type[Model]] = []
    tables = extract_modified_tables(module)

    # add models defined explicitly in the migration script
    queue = list(module.__dict__.values())
    while queue:
        obj = queue.pop()
        if hasattr(obj, "__tablename__"):
            tables.add(obj.__tablename__)
        elif isinstance(obj, list):
            queue.extend(obj)
        elif isinstance(obj, dict):
            queue.extend(obj.values())

    # add implicit models
    # pylint: disable=no-member, protected-access
    for obj in Model._decl_class_registry.values():
        if hasattr(obj, "__table__") and obj.__table__.fullname in tables:
            models.append(obj)

    # sort topologically so we can create entities in order and
    # maintain relationships (eg, create a database before creating
    # a slice)
    sorter = TopologicalSorter()
    for model in models:
        inspector = inspect(model)
        dependent_tables: List[str] = []
        for column in inspector.columns.values():
            for foreign_key in column.foreign_keys:
                # the target is "[schema.]table.column", so the table name is
                # always the second-to-last component
                target_table = foreign_key.target_fullname.split(".")[-2]
                # a key pointing back at the model's own table is not an
                # ordering constraint, and registering it would make the
                # sorter report a cycle
                if target_table != model.__tablename__:
                    dependent_tables.append(target_table)
        sorter.add(model.__tablename__, *dependent_tables)

    # look positions up in a dict instead of scanning a list for every model
    position = {name: index for index, name in enumerate(sorter.static_order())}
    models.sort(key=lambda model: position[model.__tablename__])

    return models
+
+
def import_migration_script(filepath: Path) -> ModuleType:
    """
    Import migration script as if it were a module.

    The module is named after the file stem and is executed, but it is not
    registered in ``sys.modules``.

    :param Path filepath: location of the migration script
    :returns: the executed module
    :raises ImportError: if no import spec or loader can be built for
        ``filepath`` (for example, when it is not a Python source file)
    """
    spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
    # spec_from_file_location returns None when no loader matches the file,
    # so fail with a clear message instead of an AttributeError further down
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to import migration script: {filepath}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
+
+
+def extract_modified_tables(module: ModuleType) -> Set[str]:
+ """
+ Extract the tables being modified by a migration script.
+
+ This function uses a simple approach of looking at the source code of
+ the migration script looking for patterns. It could be improved by
+ actually traversing the AST.
+ """
+
+ tables: Set[str] = set()
+ for function in {"upgrade", "downgrade"}:
+ source = getsource(getattr(module, function))
+ tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source,
re.DOTALL))
+ tables.update(re.findall(r'add_column\(\s*"(\w+?)"\s*,', source,
re.DOTALL))
+ tables.update(re.findall(r'drop_column\(\s*"(\w+?)"\s*,', source,
re.DOTALL))
Review comment:
These will have to be implemented on a case-by-case basis; my
expectation is that people will have to tweak this script when they use it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]