This is an automated email from the ASF dual-hosted git repository.

craigrueda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 2aaa4d9  chore: Migrating reports to AuthWebdriverProxy (#10567)
2aaa4d9 is described below

commit 2aaa4d92d9626abd4e7590712cb7452f80578823
Author: Craig Rueda <[email protected]>
AuthorDate: Wed Aug 12 13:28:41 2020 -0700

    chore: Migrating reports to AuthWebdriverProxy (#10567)
    
    * Migrating reports to AuthWebdriverProxy
    
    * Extracting out webdriver proxy / Adding thumbnail tests to CI
    
    * Adding license
    
    * Adding license again
    
    * Empty commit
    
    * Adding thumbnail tests to CI
    
    * Switching thumbnail test to Postgres
    
    * Linting
    
    * Adding mypy:ignore / removing thumbnail tests from CI
    
    * Putting ignore statement back
    
    * Updating docs
    
    * First cut at authprovider
    
    * First cut at authprovider mostly working - still needs more tests
    
    * Auth provider tests added
    
    * Linting
    
    * Linting again...
    
    * Linting again...
    
    * Busting CI cache
    
    * Reverting workflow change
    
    * Fixing dataclasses
    
    * Reverting back to master
    
    * linting?
    
    * Reverting installation.rst
    
    * Reverting package-lock.json
    
    * Addressing feedback
    
    * Blacking
    
    * Lazy logging strings
    
    * UPDATING.md note
---
 UPDATING.md                      |   2 +
 scripts/tests/run.sh             |  10 +--
 superset/app.py                  |   5 ++
 superset/config.py               |  18 +++++
 superset/extensions.py           |   2 +
 superset/tasks/schedules.py      |  85 +++++----------------
 superset/tasks/thumbnails.py     |   5 +-
 superset/utils/machine_auth.py   | 113 ++++++++++++++++++++++++++++
 superset/utils/screenshots.py    | 154 ++-------------------------------------
 superset/utils/webdriver.py      | 131 +++++++++++++++++++++++++++++++++
 tests/base_tests.py              |   1 +
 tests/schedules_test.py          |   4 +-
 tests/thumbnails_tests.py        |  14 +---
 tests/util/__init__.py           |  16 ++++
 tests/util/machine_auth_tests.py |  56 ++++++++++++++
 15 files changed, 376 insertions(+), 240 deletions(-)

diff --git a/UPDATING.md b/UPDATING.md
index b65c19d..a18974b 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -23,6 +23,8 @@ assists people when migrating to a new version.
 
 ## Next
 
+* [10567](https://github.com/apache/incubator-superset/pull/10567): The default 
WEBDRIVER_OPTION_ARGS are Chrome-specific. If you're using Firefox, set it to 
`--headless` only.
+
 * [10241](https://github.com/apache/incubator-superset/pull/10241): change on 
Alpha role, users started to have access to "Annotation Layers", "Css 
Templates" and "Import Dashboards".
 
 * [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook 
Prophet has been introduced as an optional dependency to add support for 
timeseries forecasting in the chart data API. To enable this feature, install 
Superset with the optional dependency `prophet` or directly `pip install 
fbprophet`.
diff --git a/scripts/tests/run.sh b/scripts/tests/run.sh
index 98206f4..95c609a 100755
--- a/scripts/tests/run.sh
+++ b/scripts/tests/run.sh
@@ -26,8 +26,8 @@ function reset_db() {
   echo --------------------
   echo Reseting test DB
   echo --------------------
-  docker-compose stop superset-tests-worker
-  RESET_DB_CMD="psql \"postgresql://superset:[email protected]:5432\" <<-EOF
+  docker-compose stop superset-tests-worker superset || true
+  RESET_DB_CMD="psql \"postgresql://${DB_USER}:${DB_PASSWORD}@127.0.0.1:5432\" 
<<-EOF
     DROP DATABASE IF EXISTS ${DB_NAME};
     CREATE DATABASE ${DB_NAME};
     \\c ${DB_NAME}
@@ -53,10 +53,6 @@ function test_init() {
   echo Superset init
   echo --------------------
   superset init
-  echo --------------------
-  echo Load examples
-  echo --------------------
-  pytest -s tests/load_examples_test.py
 }
 
 #
@@ -142,5 +138,5 @@ fi
 
 if [ $RUN_TESTS -eq 1 ]
 then
-  pytest -x -s --ignore=load_examples_test "${TEST_MODULE}"
+  pytest -x -s "${TEST_MODULE}"
 fi
diff --git a/superset/app.py b/superset/app.py
index b64ca69..11cb004 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -36,6 +36,7 @@ from superset.extensions import (
     db,
     feature_flag_manager,
     jinja_context_manager,
+    machine_auth_provider_factory,
     manifest_processor,
     migrate,
     results_backend_manager,
@@ -468,6 +469,7 @@ class SupersetAppInitializer:
         self.configure_fab()
         self.configure_url_map_converters()
         self.configure_data_sources()
+        self.configure_auth_provider()
 
         # Hook that provides administrators a handle on the Flask APP
         # after initialization
@@ -499,6 +501,9 @@ class SupersetAppInitializer:
 
         self.post_init()
 
+    def configure_auth_provider(self) -> None:
+        machine_auth_provider_factory.init_app(self.flask_app)
+
     def setup_event_logger(self) -> None:
         _event_logger["event_logger"] = get_event_logger_from_cfg_value(
             self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
diff --git a/superset/config.py b/superset/config.py
index cfef8c2..ff4796d 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -761,6 +761,11 @@ SLACK_PROXY = None
 # * Emails are sent using dry-run mode (logging only)
 SCHEDULED_EMAIL_DEBUG_MODE = False
 
+# This auth provider is used by background (offline) tasks that need to access
+# protected resources. Can be overridden by end users in order to support
+# custom auth mechanisms
+MACHINE_AUTH_PROVIDER_CLASS = "superset.utils.machine_auth.MachineAuthProvider"
+
 # Email reports - minimum time resolution (in minutes) for the crontab
 EMAIL_REPORTS_CRON_RESOLUTION = 15
 
@@ -795,9 +800,22 @@ EMAIL_REPORTS_WEBDRIVER = "firefox"
 # Window size - this will impact the rendering of the data
 WEBDRIVER_WINDOW = {"dashboard": (1600, 2000), "slice": (3000, 1200)}
 
+# An optional override to the default auth hook used to provide auth to the
+# offline webdriver
+WEBDRIVER_AUTH_FUNC = None
+
 # Any config options to be passed as-is to the webdriver
 WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {}
 
+# Additional args to be passed as arguments to the config object
+# Note: these options are Chrome-specific. For FF, these should
+# only include the "--headless" arg
+WEBDRIVER_OPTION_ARGS = [
+    "--force-device-scale-factor=2.0",
+    "--high-dpi-support=2.0",
+    "--headless",
+]
+
 # The base URL to query for accessing the user interface
 WEBDRIVER_BASEURL = "http://0.0.0.0:8080/";
 # The base URL for the email report hyperlinks.
diff --git a/superset/extensions.py b/superset/extensions.py
index 7cafef6..06d55c8 100644
--- a/superset/extensions.py
+++ b/superset/extensions.py
@@ -34,6 +34,7 @@ from werkzeug.local import LocalProxy
 
 from superset.utils.cache_manager import CacheManager
 from superset.utils.feature_flag_manager import FeatureFlagManager
+from superset.utils.machine_auth import MachineAuthProviderFactory
 
 if TYPE_CHECKING:
     from superset.jinja_context import (  # pylint: disable=unused-import
@@ -139,6 +140,7 @@ _event_logger: Dict[str, Any] = {}
 event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
 feature_flag_manager = FeatureFlagManager()
 jinja_context_manager = JinjaContextManager()
+machine_auth_provider_factory = MachineAuthProviderFactory()
 manifest_processor = UIManifestProcessor(APP_DIR)
 migrate = Migrate()
 results_backend_manager = ResultsBackendManager()
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 9ebdcfe..c38f261 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -28,7 +28,6 @@ from typing import (
     Callable,
     Dict,
     Iterator,
-    List,
     NamedTuple,
     Optional,
     Tuple,
@@ -42,17 +41,16 @@ import pandas as pd
 import simplejson as json
 from celery.app.task import Task
 from dateutil.tz import tzlocal
-from flask import current_app, render_template, Response, session, url_for
+from flask import current_app, render_template, url_for
 from flask_babel import gettext as __
-from flask_login import login_user
 from retry.api import retry_call
 from selenium.common.exceptions import WebDriverException
 from selenium.webdriver import chrome, firefox
+from selenium.webdriver.remote.webdriver import WebDriver
 from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
-from werkzeug.http import parse_cookie
 
 from superset import app, db, security_manager, thumbnail_cache
-from superset.extensions import celery_app
+from superset.extensions import celery_app, machine_auth_provider_factory
 from superset.models.alerts import Alert, AlertLog
 from superset.models.core import Database
 from superset.models.dashboard import Dashboard
@@ -66,7 +64,7 @@ from superset.models.slice import Slice
 from superset.sql_parse import ParsedQuery
 from superset.tasks.slack_util import deliver_slack_msg
 from superset.utils.core import get_email_address_list, send_email_smtp
-from superset.utils.screenshots import ChartScreenshot
+from superset.utils.screenshots import ChartScreenshot, WebDriverProxy
 from superset.utils.urls import get_url_path
 
 # pylint: disable=too-few-public-methods
@@ -74,6 +72,7 @@ from superset.utils.urls import get_url_path
 if TYPE_CHECKING:
     # pylint: disable=unused-import
     from werkzeug.datastructures import TypeConversionDict
+    from flask_appbuilder.security.sqla.models import User
 
 
 # Globals
@@ -191,27 +190,6 @@ def _generate_report_content(
     return ReportContent(body, data, images, slack_message, screenshot)
 
 
-def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]:
-    # Login with the user specified to get the reports
-    with app.test_request_context():
-        user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
-        login_user(user)
-
-        # A mock response object to get the cookie information from
-        response = Response()
-        app.session_interface.save_session(app, session, response)
-
-    cookies = []
-
-    # Set the cookies in the driver
-    for name, value in response.headers:
-        if name.lower() == "set-cookie":
-            cookie = parse_cookie(value)
-            cookies.append(cookie["session"])
-
-    return cookies
-
-
 def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> 
str:
     with app.test_request_context():
         base_url = (
@@ -220,44 +198,14 @@ def _get_url_path(view: str, user_friendly: bool = False, 
**kwargs: Any) -> str:
         return urllib.parse.urljoin(str(base_url), url_for(view, **kwargs))
 
 
-def create_webdriver() -> Union[
-    chrome.webdriver.WebDriver, firefox.webdriver.WebDriver
-]:
-    # Create a webdriver for use in fetching reports
-    if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox":
-        driver_class = firefox.webdriver.WebDriver
-        options = firefox.options.Options()
-    elif config["EMAIL_REPORTS_WEBDRIVER"] == "chrome":
-        driver_class = chrome.webdriver.WebDriver
-        options = chrome.options.Options()
-
-    options.add_argument("--headless")
-
-    # Prepare args for the webdriver init
-    kwargs = dict(options=options)
-    kwargs.update(config["WEBDRIVER_CONFIGURATION"])
-
-    # Initialize the driver
-    driver = driver_class(**kwargs)
-
-    # Some webdrivers need an initial hit to the welcome URL
-    # before we set the cookie
-    welcome_url = _get_url_path("Superset.welcome")
-
-    # Hit the welcome URL and check if we were asked to login
-    driver.get(welcome_url)
-    elements = driver.find_elements_by_id("loginbox")
-
-    # This indicates that we were not prompted for a login box.
-    if not elements:
-        return driver
+def create_webdriver() -> WebDriver:
+    return WebDriverProxy(driver_type=config["EMAIL_REPORTS_WEBDRIVER"]).auth(
+        get_reports_user()
+    )
 
-    # Set the cookies in the driver
-    for cookie in _get_auth_cookies():
-        info = dict(name="session", value=cookie)
-        driver.add_cookie(info)
 
-    return driver
+def get_reports_user() -> "User":
+    return security_manager.find_user(config["EMAIL_REPORTS_USER"])
 
 
 def destroy_webdriver(
@@ -364,12 +312,15 @@ def _get_slice_data(slc: Slice, delivery_type: 
EmailDeliveryType) -> ReportConte
         "Superset.slice", slice_id=slc.id, user_friendly=True
     )
 
-    cookies = {}
-    for cookie in _get_auth_cookies():
-        cookies["session"] = cookie
+    # Login on behalf of the "reports" user in order to get cookies to deal 
with auth
+    auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
+        get_reports_user()
+    )
+    # Build something like 
"session=cool_sess.val;other-cookie=awesome_other_cookie"
+    cookie_str = ";".join([f"{key}={val}" for key, val in 
auth_cookies.items()])
 
     opener = urllib.request.build_opener()
-    opener.addheaders.append(("Cookie", f"session={cookies['session']}"))
+    opener.addheaders.append(("Cookie", cookie_str))
     response = opener.open(slice_url)
     if response.getcode() != 200:
         raise URLError(response.getcode())
diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py
index efa704e..bf7bdc5 100644
--- a/superset/tasks/thumbnails.py
+++ b/superset/tasks/thumbnails.py
@@ -18,18 +18,17 @@
 """Utility functions used across Superset"""
 
 import logging
-from typing import Optional, Tuple
+from typing import Optional
 
 from flask import current_app
 
 from superset import app, security_manager, thumbnail_cache
 from superset.extensions import celery_app
 from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
+from superset.utils.webdriver import WindowSize
 
 logger = logging.getLogger(__name__)
 
-WindowSize = Tuple[int, int]
-
 
 @celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300)
 def cache_chart_thumbnail(
diff --git a/superset/utils/machine_auth.py b/superset/utils/machine_auth.py
new file mode 100644
index 0000000..3bc8afa
--- /dev/null
+++ b/superset/utils/machine_auth.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import importlib
+import logging
+from typing import Callable, Dict, TYPE_CHECKING
+
+from flask import current_app, Flask, request, Response, session
+from flask_login import login_user
+from selenium.webdriver.remote.webdriver import WebDriver
+from werkzeug.http import parse_cookie
+
+from superset.utils.urls import headless_url
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    # pylint: disable=unused-import
+    from flask_appbuilder.security.sqla.models import User
+
+
+class MachineAuthProvider:
+    def __init__(
+        self, auth_webdriver_func_override: Callable[[WebDriver, "User"], 
WebDriver]
+    ):
+        # This is here in order to allow for the authenticate_webdriver func 
to be
+        # overridden via config, as opposed to the entire provider 
implementation
+        self._auth_webdriver_func_override = auth_webdriver_func_override
+
+    def authenticate_webdriver(self, driver: WebDriver, user: "User",) -> 
WebDriver:
+        """
+            Default AuthDriverFuncType type that sets a session cookie 
flask-login style
+            :return: The WebDriver passed in (fluent)
+        """
+        # Short-circuit this method if we have an override configured
+        if self._auth_webdriver_func_override:
+            return self._auth_webdriver_func_override(driver, user)
+
+        # Setting cookies requires doing a request first
+        driver.get(headless_url("/login/"))
+
+        if user:
+            cookies = self.get_auth_cookies(user)
+        elif request.cookies:
+            cookies = request.cookies
+        else:
+            cookies = {}
+
+        for cookie_name, cookie_val in cookies.items():
+            driver.add_cookie(dict(name=cookie_name, value=cookie_val))
+
+        return driver
+
+    @staticmethod
+    def get_auth_cookies(user: "User") -> Dict[str, str]:
+        # Login with the user specified to get the reports
+        with current_app.test_request_context("/login"):
+            login_user(user)
+            # A mock response object to get the cookie information from
+            response = Response()
+            current_app.session_interface.save_session(current_app, session, 
response)
+
+        cookies = {}
+
+        # Grab any "set-cookie" headers from the login response
+        for name, value in response.headers:
+            if name.lower() == "set-cookie":
+                # This yields a MultiDict, which is ordered -- something like
+                # MultiDict([('session', 'value-we-want), ('HttpOnly', ''), 
etc...
+                # Therefore, we just need to grab the first tuple and add it 
to our
+                # final dict
+                cookie = parse_cookie(value)
+                cookie_tuple = list(cookie.items())[0]
+                cookies[cookie_tuple[0]] = cookie_tuple[1]
+
+        return cookies
+
+
+class MachineAuthProviderFactory:
+    def __init__(self) -> None:
+        self._auth_provider = None
+
+    def init_app(self, app: Flask) -> None:
+        auth_provider_fqclass = app.config["MACHINE_AUTH_PROVIDER_CLASS"]
+        auth_provider_classname = auth_provider_fqclass[
+            auth_provider_fqclass.rfind(".") + 1 :
+        ]
+        auth_provider_module_name = auth_provider_fqclass[
+            0 : auth_provider_fqclass.rfind(".")
+        ]
+        auth_provider_class = getattr(
+            importlib.import_module(auth_provider_module_name), 
auth_provider_classname
+        )
+
+        self._auth_provider = 
auth_provider_class(app.config["WEBDRIVER_AUTH_FUNC"])
+
+    @property
+    def instance(self) -> MachineAuthProvider:
+        return self._auth_provider  # type: ignore
diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py
index 68cfd9c..9ac2b80 100644
--- a/superset/utils/screenshots.py
+++ b/superset/utils/screenshots.py
@@ -15,23 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-import time
 from io import BytesIO
-from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, 
Union
+from typing import Optional, TYPE_CHECKING, Union
 
-from flask import current_app, request, Response, session
-from flask_login import login_user
-from retry.api import retry_call
-from selenium.common.exceptions import TimeoutException, WebDriverException
-from selenium.webdriver import chrome, firefox
-from selenium.webdriver.common.by import By
-from selenium.webdriver.remote.webdriver import WebDriver
-from selenium.webdriver.support import expected_conditions as EC
-from selenium.webdriver.support.ui import WebDriverWait
-from werkzeug.http import parse_cookie
+from flask import current_app
 
 from superset.utils.hashing import md5_sha_from_dict
-from superset.utils.urls import headless_url
+from superset.utils.webdriver import WebDriverProxy, WindowSize
 
 logger = logging.getLogger(__name__)
 
@@ -45,140 +35,6 @@ if TYPE_CHECKING:
     from flask_appbuilder.security.sqla.models import User
     from flask_caching import Cache
 
-# Time in seconds, we will wait for the page to load and render
-SELENIUM_CHECK_INTERVAL = 2
-SELENIUM_RETRIES = 5
-SELENIUM_HEADSTART = 3
-
-WindowSize = Tuple[int, int]
-
-
-def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
-    # Login with the user specified to get the reports
-    with current_app.test_request_context("/login"):
-        login_user(user)
-        # A mock response object to get the cookie information from
-        response = Response()
-        current_app.session_interface.save_session(current_app, session, 
response)
-
-    cookies = []
-
-    # Set the cookies in the driver
-    for name, value in response.headers:
-        if name.lower() == "set-cookie":
-            cookie = parse_cookie(value)
-            cookies.append(cookie["session"])
-    return cookies
-
-
-def auth_driver(driver: WebDriver, user: "User") -> WebDriver:
-    """
-        Default AuthDriverFuncType type that sets a session cookie flask-login 
style
-    :return: WebDriver
-    """
-    if user:
-        # Set the cookies in the driver
-        for cookie in get_auth_cookies(user):
-            info = dict(name="session", value=cookie)
-            driver.add_cookie(info)
-    elif request.cookies:
-        cookies = request.cookies
-        for k, v in cookies.items():
-            cookie = dict(name=k, value=v)
-            driver.add_cookie(cookie)
-    return driver
-
-
-class AuthWebDriverProxy:
-    def __init__(
-        self,
-        driver_type: str,
-        window: Optional[WindowSize] = None,
-        auth_func: Optional[
-            Callable[..., Any]
-        ] = None,  # pylint: disable=bad-whitespace
-    ):
-        self._driver_type = driver_type
-        self._window: WindowSize = window or (800, 600)
-        config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", 
auth_driver)
-        self._auth_func = auth_func or config_auth_func
-
-    def create(self) -> WebDriver:
-        if self._driver_type == "firefox":
-            driver_class = firefox.webdriver.WebDriver
-            options = firefox.options.Options()
-        elif self._driver_type == "chrome":
-            driver_class = chrome.webdriver.WebDriver
-            options = chrome.options.Options()
-            arg: str = f"--window-size={self._window[0]},{self._window[1]}"
-            options.add_argument(arg)
-            # TODO: 2 lines attempting retina PPI don't seem to be working
-            options.add_argument("--force-device-scale-factor=2.0")
-            options.add_argument("--high-dpi-support=2.0")
-        else:
-            raise Exception(f"Webdriver name ({self._driver_type}) not 
supported")
-        # Prepare args for the webdriver init
-        options.add_argument("--headless")
-        kwargs: Dict[Any, Any] = dict(options=options)
-        kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
-        logger.info("Init selenium driver")
-        return driver_class(**kwargs)
-
-    def auth(self, user: "User") -> WebDriver:
-        # Setting cookies requires doing a request first
-        driver = self.create()
-        driver.get(headless_url("/login/"))
-        return self._auth_func(driver, user)
-
-    @staticmethod
-    def destroy(driver: WebDriver, tries: int = 2) -> None:
-        """Destroy a driver"""
-        # This is some very flaky code in selenium. Hence the retries
-        # and catch-all exceptions
-        try:
-            retry_call(driver.close, tries=tries)
-        except Exception:  # pylint: disable=broad-except
-            pass
-        try:
-            driver.quit()
-        except Exception:  # pylint: disable=broad-except
-            pass
-
-    def get_screenshot(
-        self,
-        url: str,
-        element_name: str,
-        user: "User",
-        retries: int = SELENIUM_RETRIES,
-    ) -> Optional[bytes]:
-        driver = self.auth(user)
-        driver.set_window_size(*self._window)
-        driver.get(url)
-        img: Optional[bytes] = None
-        logger.debug("Sleeping for %i seconds", SELENIUM_HEADSTART)
-        time.sleep(SELENIUM_HEADSTART)
-        try:
-            logger.debug("Wait for the presence of %s", element_name)
-            element = WebDriverWait(
-                driver, current_app.config["SCREENSHOT_LOCATE_WAIT"]
-            ).until(EC.presence_of_element_located((By.CLASS_NAME, 
element_name)))
-            logger.debug("Wait for .loading to be done")
-            WebDriverWait(driver, 
current_app.config["SCREENSHOT_LOAD_WAIT"]).until_not(
-                EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
-            )
-            logger.info("Taking a PNG screenshot")
-            img = element.screenshot_as_png
-        except TimeoutException:
-            logger.error("Selenium timed out")
-        except WebDriverException as ex:
-            logger.error(ex)
-            # Some webdrivers do not support screenshots for elements.
-            # In such cases, take a screenshot of the entire page.
-            img = driver.screenshot()  # pylint: disable=no-member
-        finally:
-            self.destroy(driver, retries)
-        return img
-
 
 class BaseScreenshot:
     driver_type = current_app.config.get("EMAIL_REPORTS_WEBDRIVER", "chrome")
@@ -192,9 +48,9 @@ class BaseScreenshot:
         self.url = url
         self.screenshot: Optional[bytes] = None
 
-    def driver(self, window_size: Optional[WindowSize] = None) -> 
AuthWebDriverProxy:
+    def driver(self, window_size: Optional[WindowSize] = None) -> 
WebDriverProxy:
         window_size = window_size or self.window_size
-        return AuthWebDriverProxy(self.driver_type, window_size)
+        return WebDriverProxy(self.driver_type, window_size)
 
     def cache_key(
         self,
diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py
new file mode 100644
index 0000000..cb8527c
--- /dev/null
+++ b/superset/utils/webdriver.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+import time
+from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+
+from flask import current_app
+from retry.api import retry_call
+from selenium.common.exceptions import TimeoutException, WebDriverException
+from selenium.webdriver import chrome, firefox
+from selenium.webdriver.common.by import By
+from selenium.webdriver.remote.webdriver import WebDriver
+from selenium.webdriver.support import expected_conditions as EC
+from selenium.webdriver.support.ui import WebDriverWait
+
+from superset.extensions import machine_auth_provider_factory
+
+WindowSize = Tuple[int, int]
+logger = logging.getLogger(__name__)
+
+# Time in seconds, we will wait for the page to load and render
+SELENIUM_CHECK_INTERVAL = 2
+SELENIUM_RETRIES = 5
+SELENIUM_HEADSTART = 3
+
+
+if TYPE_CHECKING:
+    # pylint: disable=unused-import
+    from flask_appbuilder.security.sqla.models import User
+
+
+class WebDriverProxy:
+    def __init__(
+        self, driver_type: str, window: Optional[WindowSize] = None,
+    ):
+        self._driver_type = driver_type
+        self._window: WindowSize = window or (800, 600)
+        self._screenshot_locate_wait = 
current_app.config["SCREENSHOT_LOCATE_WAIT"]
+        self._screenshot_load_wait = current_app.config["SCREENSHOT_LOAD_WAIT"]
+
+    def create(self) -> WebDriver:
+        if self._driver_type == "firefox":
+            driver_class = firefox.webdriver.WebDriver
+            options = firefox.options.Options()
+        elif self._driver_type == "chrome":
+            driver_class = chrome.webdriver.WebDriver
+            options = chrome.options.Options()
+            
options.add_argument(f"--window-size={self._window[0]},{self._window[1]}")
+        else:
+            raise Exception(f"Webdriver name ({self._driver_type}) not 
supported")
+        # Prepare args for the webdriver init
+
+        # Add additional configured options
+        for arg in current_app.config["WEBDRIVER_OPTION_ARGS"]:
+            options.add_argument(arg)
+
+        kwargs: Dict[Any, Any] = dict(options=options)
+        kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
+        logger.info("Init selenium driver")
+
+        return driver_class(**kwargs)
+
+    def auth(self, user: "User") -> WebDriver:
+        driver = self.create()
+        return machine_auth_provider_factory.instance.authenticate_webdriver(
+            driver, user
+        )
+
+    @staticmethod
+    def destroy(driver: WebDriver, tries: int = 2) -> None:
+        """Destroy a driver"""
+        # This is some very flaky code in selenium. Hence the retries
+        # and catch-all exceptions
+        try:
+            retry_call(driver.close, tries=tries)
+        except Exception:  # pylint: disable=broad-except
+            pass
+        try:
+            driver.quit()
+        except Exception:  # pylint: disable=broad-except
+            pass
+
+    def get_screenshot(
+        self,
+        url: str,
+        element_name: str,
+        user: "User",
+        retries: int = SELENIUM_RETRIES,
+    ) -> Optional[bytes]:
+        driver = self.auth(user)
+        driver.set_window_size(*self._window)
+        driver.get(url)
+        img: Optional[bytes] = None
+        logger.debug("Sleeping for %i seconds", SELENIUM_HEADSTART)
+        time.sleep(SELENIUM_HEADSTART)
+        try:
+            logger.debug("Wait for the presence of %s", element_name)
+            element = WebDriverWait(driver, 
self._screenshot_locate_wait).until(
+                EC.presence_of_element_located((By.CLASS_NAME, element_name))
+            )
+            logger.debug("Wait for .loading to be done")
+            WebDriverWait(driver, self._screenshot_load_wait).until_not(
+                EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
+            )
+            logger.info("Taking a PNG screenshot or url %s", url)
+            img = element.screenshot_as_png
+        except TimeoutException:
+            logger.error("Selenium timed out requesting url %s", url)
+        except WebDriverException as ex:
+            logger.error(ex)
+            # Some webdrivers do not support screenshots for elements.
+            # In such cases, take a screenshot of the entire page.
+            img = driver.screenshot()  # pylint: disable=no-member
+        finally:
+            self.destroy(driver, retries)
+        return img
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 8f708a5..8448e08 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -100,6 +100,7 @@ class SupersetTestCase(TestCase):
             assert user_to_create
         user_to_create.roles = [security_manager.find_role(r) for r in roles]
         db.session.commit()
+        return user_to_create
 
     @staticmethod
     def create_user(
diff --git a/tests/schedules_test.py b/tests/schedules_test.py
index 549a0cd..77f7070 100644
--- a/tests/schedules_test.py
+++ b/tests/schedules_test.py
@@ -40,8 +40,7 @@ from superset.tasks.schedules import (
 )
 from superset.models.slice import Slice
 from tests.base_tests import SupersetTestCase
-
-from .utils import read_fixture
+from tests.utils import read_fixture
 
 
 class TestSchedules(SupersetTestCase):
@@ -173,7 +172,6 @@ class TestSchedules(SupersetTestCase):
         mock_driver.find_elements_by_id.side_effect = [True, False]
 
         create_webdriver()
-        create_webdriver()
         mock_driver.add_cookie.assert_called_once()
 
     @patch("superset.tasks.schedules.firefox.webdriver.WebDriver")
diff --git a/tests/thumbnails_tests.py b/tests/thumbnails_tests.py
index 36126e5..fb1fd68 100644
--- a/tests/thumbnails_tests.py
+++ b/tests/thumbnails_tests.py
@@ -16,7 +16,6 @@
 # under the License.
 # from superset import db
 # from superset.models.dashboard import Dashboard
-import subprocess
 import urllib.request
 from unittest import skipUnless
 from unittest.mock import patch
@@ -24,15 +23,11 @@ from unittest.mock import patch
 from flask_testing import LiveServerTestCase
 from sqlalchemy.sql import func
 
-import tests.test_app
 from superset import db, is_feature_enabled, security_manager, thumbnail_cache
+from superset.extensions import machine_auth_provider_factory
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
-from superset.utils.screenshots import (
-    ChartScreenshot,
-    DashboardScreenshot,
-    get_auth_cookies,
-)
+from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
 from superset.utils.urls import get_url_path
 from tests.test_app import app
 
@@ -45,10 +40,7 @@ class TestThumbnailsSeleniumLive(LiveServerTestCase):
 
     def url_open_auth(self, username: str, url: str):
         admin_user = security_manager.find_user(username=username)
-        cookies = {}
-        for cookie in get_auth_cookies(admin_user):
-            cookies["session"] = cookie
-
+        cookies = machine_auth_provider_factory.instance.get_auth_cookies(admin_user)
         opener = urllib.request.build_opener()
         opener.addheaders.append(("Cookie", f"session={cookies['session']}"))
         return opener.open(f"{self.get_server_url()}/{url}")
diff --git a/tests/util/__init__.py b/tests/util/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/util/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/util/machine_auth_tests.py b/tests/util/machine_auth_tests.py
new file mode 100644
index 0000000..1bc08e8
--- /dev/null
+++ b/tests/util/machine_auth_tests.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import call, Mock, patch
+
+from superset.extensions import machine_auth_provider_factory
+from tests.base_tests import SupersetTestCase
+
+
+class MachineAuthProviderTests(SupersetTestCase):
+    def test_get_auth_cookies(self):
+        user = self.get_user("admin")
+        auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user)
+        self.assertIsNotNone(auth_cookies["session"])
+
+    @patch("superset.utils.machine_auth.MachineAuthProvider.get_auth_cookies")
+    def test_auth_driver_user(self, get_auth_cookies):
+        user = self.get_user("admin")
+        driver = Mock()
+        get_auth_cookies.return_value = {
+            "session": "session_val",
+            "other_cookie": "other_val",
+        }
+        machine_auth_provider_factory.instance.authenticate_webdriver(driver, user)
+        driver.add_cookie.assert_has_calls(
+            [
+                call({"name": "session", "value": "session_val"}),
+                call({"name": "other_cookie", "value": "other_val"}),
+            ]
+        )
+
+    @patch("superset.utils.machine_auth.request")
+    def test_auth_driver_request(self, request):
+        driver = Mock()
+        request.cookies = {"session": "session_val", "other_cookie": "other_val"}
+        machine_auth_provider_factory.instance.authenticate_webdriver(driver, None)
+        driver.add_cookie.assert_has_calls(
+            [
+                call({"name": "session", "value": "session_val"}),
+                call({"name": "other_cookie", "value": "other_val"}),
+            ]
+        )

Reply via email to