This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a0f350028 fix: add new config to allow for specific import data urls (#22942)
7a0f350028 is described below

commit 7a0f350028817e9980abcc1afcf5672d04af3e8b
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Mon Feb 6 08:17:08 2023 -0800

    fix: add new config to allow for specific import data urls (#22942)
---
 superset/config.py                                 |   7 ++
 superset/datasets/commands/exceptions.py           |   4 +
 superset/datasets/commands/importers/v1/utils.py   |  32 +++++-
 .../datasets/commands/importers/v1/import_test.py  | 128 ++++++++++++++++++++-
 4 files changed, 165 insertions(+), 6 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index 5a64c99f77..ab23da0c29 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1406,6 +1406,13 @@ PREVENT_UNSAFE_DB_CONNECTIONS = True
 # Prevents unsafe default endpoints from being registered on datasets.
 PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = True
 
+# Define a list of allowed URLs for dataset data imports (v1).
+# A simple example that only allows URLs belonging to certain domains:
+# DATASET_IMPORT_ALLOWED_DATA_URLS = [
+#     r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"
+# ]
+DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]
+
 # Path used to store SSL certificates that are generated when using custom certs.
 # Defaults to temporary directory.
 # Example: SSL_CERT_PATH = "/certs"
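
As an illustrative sketch (not part of this commit), a deployment could restrict data imports to a single trusted host via its superset_config.py; the domain below is a placeholder:

    # superset_config.py -- hypothetical example, substitute your own domain
    DATASET_IMPORT_ALLOWED_DATA_URLS = [
        r"^https://data\.example\.com/.*",
    ]

The default value [r".*"] keeps the previous permissive behavior; a data URI that matches none of the configured patterns fails the import with DatasetForbiddenDataURI (see below).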
diff --git a/superset/datasets/commands/exceptions.py b/superset/datasets/commands/exceptions.py
index c76b7b3ad5..96ed7823b0 100644
--- a/superset/datasets/commands/exceptions.py
+++ b/superset/datasets/commands/exceptions.py
@@ -191,3 +191,7 @@ class DatasetAccessDeniedError(ForbiddenError):
 
 class DatasetDuplicateFailedError(CreateFailedError):
     message = _("Dataset could not be duplicated.")
+
+
+class DatasetForbiddenDataURI(ForbiddenError):
+    message = _("Data URI is not allowed.")
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index d04763c7a8..b3fb2a8041 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -29,6 +29,7 @@ from sqlalchemy.orm.exc import MultipleResultsFound
 from sqlalchemy.sql.visitors import VisitableType
 
 from superset.connectors.sqla.models import SqlaTable
+from superset.datasets.commands.exceptions import DatasetForbiddenDataURI
 from superset.models.core import Database
 
 logger = logging.getLogger(__name__)
@@ -75,6 +76,28 @@ def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
     }
 
 
+def validate_data_uri(data_uri: str) -> None:
+    """
+    Validate that the data URI is configured on DATASET_IMPORT_ALLOWED_URLS
+    has a valid URL.
+
+    :param data_uri:
+    :return:
+    """
+    allowed_urls = current_app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"]
+    for allowed_url in allowed_urls:
+        try:
+            match = re.match(allowed_url, data_uri)
+        except re.error:
+            logger.exception(
+                "Invalid regular expression on DATASET_IMPORT_ALLOWED_URLS"
+            )
+            raise
+        if match:
+            return
+    raise DatasetForbiddenDataURI()
+
+
 def import_dataset(
     session: Session,
     config: Dict[str, Any],
@@ -139,7 +162,6 @@ def import_dataset(
         table_exists = True
 
     if data_uri and (not table_exists or force_data):
-        logger.info("Downloading data from %s", data_uri)
         load_data(data_uri, dataset, dataset.database, session)
 
     if hasattr(g, "user") and g.user:
@@ -151,6 +173,14 @@ def import_dataset(
 def load_data(
     data_uri: str, dataset: SqlaTable, database: Database, session: Session
 ) -> None:
+    """
+    Load data from a data URI into a dataset.
+
+    :raises DatasetForbiddenDataURI: if a dataset is trying
+    to load data from a URI that is not allowed.
+    """
+    validate_data_uri(data_uri)
+    logger.info("Downloading data from %s", data_uri)
     data = request.urlopen(data_uri)  # pylint: disable=consider-using-with
     if data_uri.endswith(".gz"):
         data = gzip.open(data)
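
For reference, a minimal sketch of how the new validation behaves (illustrative only; it assumes a Flask app context, and the URLs are placeholders):

    from flask import Flask
    from superset.datasets.commands.importers.v1.utils import validate_data_uri

    app = Flask(__name__)
    app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"] = [r"^https://data\.example\.com/.*"]
    with app.app_context():
        validate_data_uri("https://data.example.com/sales.csv")  # allowed, returns None
        validate_data_uri("https://other.example.org/x.csv")     # raises DatasetForbiddenDataURI
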
diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
index 934712b8c9..5b52ac7f1d 100644
--- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
@@ -18,19 +18,25 @@
 
 import copy
 import json
+import re
 import uuid
 from typing import Any, Dict
+from unittest.mock import Mock, patch
 
+import pytest
+from flask import current_app
 from sqlalchemy.orm.session import Session
 
+from superset.datasets.commands.exceptions import DatasetForbiddenDataURI
+from superset.datasets.commands.importers.v1.utils import validate_data_uri
+
 
 def test_import_dataset(session: Session) -> None:
     """
     Test importing a dataset.
     """
-    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+    from superset.connectors.sqla.models import SqlaTable
     from superset.datasets.commands.importers.v1.utils import import_dataset
-    from superset.datasets.schemas import ImportV1DatasetSchema
     from superset.models.core import Database
 
     engine = session.get_bind()
@@ -340,13 +346,85 @@ def test_import_column_extra_is_string(session: Session) -> None:
     assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
 
 
+@patch("superset.datasets.commands.importers.v1.utils.request")
+def test_import_column_allowed_data_url(request: Mock, session: Session) -> None:
+    """
+    Test importing a dataset when using the data key to fetch data from a URL.
+    """
+    import io
+
+    from superset.connectors.sqla.models import SqlaTable
+    from superset.datasets.commands.importers.v1.utils import import_dataset
+    from superset.datasets.schemas import ImportV1DatasetSchema
+    from superset.models.core import Database
+
+    request.urlopen.return_value = io.StringIO("col1\nvalue1\nvalue2\n")
+
+    engine = session.get_bind()
+    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
+
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    session.add(database)
+    session.flush()
+
+    dataset_uuid = uuid.uuid4()
+    yaml_config: Dict[str, Any] = {
+        "version": "1.0.0",
+        "table_name": "my_table",
+        "main_dttm_col": "ds",
+        "description": "This is the description",
+        "default_endpoint": None,
+        "offset": -8,
+        "cache_timeout": 3600,
+        "schema": None,
+        "sql": None,
+        "params": {
+            "remote_id": 64,
+            "database_name": "examples",
+            "import_time": 1606677834,
+        },
+        "template_params": None,
+        "filter_select_enabled": True,
+        "fetch_values_predicate": None,
+        "extra": None,
+        "uuid": dataset_uuid,
+        "metrics": [],
+        "columns": [
+            {
+                "column_name": "col1",
+                "verbose_name": None,
+                "is_dttm": False,
+                "is_active": True,
+                "type": "TEXT",
+                "groupby": False,
+                "filterable": False,
+                "expression": None,
+                "description": None,
+                "python_date_format": None,
+                "extra": None,
+            }
+        ],
+        "database_uuid": database.uuid,
+        "data": "https://some-external-url.com/data.csv";,
+    }
+
+    # the Marshmallow schema should convert strings to objects
+    schema = ImportV1DatasetSchema()
+    dataset_config = schema.load(yaml_config)
+    dataset_config["database_id"] = database.id
+    _ = import_dataset(session, dataset_config, force_data=True)
+    session.connection()
+    assert [("value1",), ("value2",)] == session.execute(
+        "SELECT * FROM my_table"
+    ).fetchall()
+
+
 def test_import_dataset_managed_externally(session: Session) -> None:
     """
     Test importing a dataset that is managed externally.
     """
-    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+    from superset.connectors.sqla.models import SqlaTable
     from superset.datasets.commands.importers.v1.utils import import_dataset
-    from superset.datasets.schemas import ImportV1DatasetSchema
     from superset.models.core import Database
     from tests.integration_tests.fixtures.importexport import dataset_config
 
@@ -357,7 +435,6 @@ def test_import_dataset_managed_externally(session: Session) -> None:
     session.add(database)
     session.flush()
 
-    dataset_uuid = uuid.uuid4()
     config = copy.deepcopy(dataset_config)
     config["is_managed_externally"] = True
     config["external_url"] = "https://example.org/my_table";
@@ -366,3 +443,44 @@ def test_import_dataset_managed_externally(session: Session) -> None:
     sqla_table = import_dataset(session, config)
     assert sqla_table.is_managed_externally is True
     assert sqla_table.external_url == "https://example.org/my_table"
+
+
+@pytest.mark.parametrize(
+    "allowed_urls, data_uri, expected, exception_class",
+    [
+        ([r".*"], "https://some-url/data.csv";, True, None),
+        (
+            [r"^https://.+\.domain1\.com\/?.*";, 
r"^https://.+\.domain2\.com\/?.*";],
+            "https://host1.domain1.com/data.csv";,
+            True,
+            None,
+        ),
+        (
+            [r"^https://.+\.domain1\.com\/?.*";, 
r"^https://.+\.domain2\.com\/?.*";],
+            "https://host2.domain1.com/data.csv";,
+            True,
+            None,
+        ),
+        (
+            [r"^https://.+\.domain1\.com\/?.*";, 
r"^https://.+\.domain2\.com\/?.*";],
+            "https://host1.domain2.com/data.csv";,
+            True,
+            None,
+        ),
+        (
+            [r"^https://.+\.domain1\.com\/?.*";, 
r"^https://.+\.domain2\.com\/?.*";],
+            "https://host1.domain3.com/data.csv";,
+            False,
+            DatasetForbiddenDataURI,
+        ),
+        ([], "https://host1.domain3.com/data.csv", False, DatasetForbiddenDataURI),
+        (["*"], "https://host1.domain3.com/data.csv";, False, re.error),
+    ],
+)
+def test_validate_data_uri(allowed_urls, data_uri, expected, exception_class):
+    current_app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"] = allowed_urls
+    if expected:
+        validate_data_uri(data_uri)
+    else:
+        with pytest.raises(exception_class):
+            validate_data_uri(data_uri)
