This is an automated email from the ASF dual-hosted git repository.

msumit pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 335fa10  Ability to test connections from UI or API (#15795)
335fa10 is described below

commit 335fa106e81ba8a446f5cda91d52527e282996be
Author: Sumit Maheshwari <msu...@users.noreply.github.com>
AuthorDate: Mon May 24 20:58:14 2021 +0530

    Ability to test connections from UI or API (#15795)
---
 .../api_connexion/endpoints/connection_endpoint.py | 29 ++++++++
 airflow/api_connexion/openapi/v1.yaml              | 39 +++++++++++
 airflow/api_connexion/schemas/connection_schema.py |  9 +++
 airflow/hooks/dbapi.py                             | 16 +++++
 airflow/models/connection.py                       | 16 +++++
 airflow/www/extensions/init_jinja_globals.py       |  1 +
 airflow/www/static/js/connection_form.js           | 77 ++++++++++++++++++++++
 airflow/www/templates/airflow/conn_create.html     |  5 ++
 airflow/www/templates/airflow/conn_edit.html       |  5 ++
 docs/apache-airflow/howto/connection.rst           | 22 +++++++
 .../endpoints/test_connection_endpoint.py          | 42 ++++++++++++
 .../schemas/test_connection_schema.py              | 18 +++++
 tests/models/test_connection.py                    | 36 ++++++++++
 tests/www/views/test_views.py                      |  5 +-
 14 files changed, 319 insertions(+), 1 deletion(-)

diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py 
b/airflow/api_connexion/endpoints/connection_endpoint.py
index 9009e68..71e4b97 100644
--- a/airflow/api_connexion/endpoints/connection_endpoint.py
+++ b/airflow/api_connexion/endpoints/connection_endpoint.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
+
 from flask import request
 from marshmallow import ValidationError
 from sqlalchemy import func
@@ -27,10 +29,13 @@ from airflow.api_connexion.schemas.connection_schema import 
(
     ConnectionCollection,
     connection_collection_schema,
     connection_schema,
+    connection_test_schema,
 )
 from airflow.models import Connection
+from airflow.secrets.environment_variables import CONN_ENV_PREFIX
 from airflow.security import permissions
 from airflow.utils.session import provide_session
+from airflow.utils.strings import get_random_string
 
 
 @security.requires_access([(permissions.ACTION_CAN_DELETE, 
permissions.RESOURCE_CONNECTION)])
@@ -129,3 +134,27 @@ def post_connection(session):
         session.commit()
         return connection_schema.dump(connection)
     raise AlreadyExists(detail=f"Connection already exist. ID: {conn_id}")
+
+
+@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
+def test_connection():
+    """
+    Test a connection described by the request body without saving it.
+
+    The connection is given a random, transient conn_id and exported as an environment variable
+    (``CONN_ENV_PREFIX`` + upper-cased conn_id), because some hook classes look the connection up
+    by id (e.g. from their ``__init__``) and error out if it cannot be found. The variable is removed
+    again when the request finishes, whether or not the test succeeded.
+
+    :return: the serialized ``{"status": ..., "message": ...}`` result of ``Connection.test_connection``
+    :raises BadRequest: if the request body fails connection schema validation
+    """
+    body = request.json
+    transient_conn_id = get_random_string()
+    conn_env_var = f'{CONN_ENV_PREFIX}{transient_conn_id.upper()}'
+    try:
+        data = connection_schema.load(body)
+        data['conn_id'] = transient_conn_id
+        conn = Connection(**data)
+        os.environ[conn_env_var] = conn.get_uri()
+        status, message = conn.test_connection()
+        return connection_test_schema.dump({"status": status, "message": message})
+    except ValidationError as err:
+        raise BadRequest(detail=str(err.messages))
+    finally:
+        # pop() with a default is a no-op when the variable was never set (e.g. validation failed).
+        os.environ.pop(conn_env_var, None)
diff --git a/airflow/api_connexion/openapi/v1.yaml 
b/airflow/api_connexion/openapi/v1.yaml
index 72c46d6..dbe8b8d 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -369,6 +369,34 @@ paths:
         '404':
           $ref: '#/components/responses/NotFound'
 
+  /connections/test:
+    post:
+      summary: Test a connection
+      x-openapi-router-controller: airflow.api_connexion.endpoints.connection_endpoint
+      operationId: test_connection
+      tags: [Connection]
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/Connection'
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ConnectionTest'
+        '400':
+          $ref: '#/components/responses/BadRequest'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+        '404':
+          $ref: '#/components/responses/NotFound'
+
   /dags:
     get:
       summary: List DAGs
@@ -1739,6 +1767,17 @@ components:
               nullable: true
               description: Other values that cannot be put into another field, 
e.g. RSA keys.
 
+    ConnectionTest:
+      description: Connection test results.
+      type: object
+      properties:
+        status:
+          type: boolean
+          description: The status of the request.
+        message:
+          type: string
+          description: The success or failure message of the request.
+
     DAG:
       type: object
       description: DAG
diff --git a/airflow/api_connexion/schemas/connection_schema.py 
b/airflow/api_connexion/schemas/connection_schema.py
index 44e3224..b1b4b97 100644
--- a/airflow/api_connexion/schemas/connection_schema.py
+++ b/airflow/api_connexion/schemas/connection_schema.py
@@ -33,6 +33,7 @@ class ConnectionCollectionItemSchema(SQLAlchemySchema):
 
     connection_id = auto_field('conn_id', required=True)
     conn_type = auto_field(required=True)
+    description = auto_field()
     host = auto_field()
     login = auto_field()
     schema = auto_field()
@@ -60,6 +61,14 @@ class ConnectionCollectionSchema(Schema):
     total_entries = fields.Int()
 
 
+class ConnectionTestSchema(Schema):
+    """Result of testing a connection: a success flag plus a human-readable message."""
+
+    status = fields.Bool(required=True)
+    message = fields.Str(required=True)
+
+
 connection_schema = ConnectionSchema()
 connection_collection_item_schema = ConnectionCollectionItemSchema()
 connection_collection_schema = ConnectionCollectionSchema()
+connection_test_schema = ConnectionTestSchema()
diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py
index 6c00320..031c221 100644
--- a/airflow/hooks/dbapi.py
+++ b/airflow/hooks/dbapi.py
@@ -350,3 +350,19 @@ class DbApiHook(BaseHook):
         :type tmp_file: str
         """
         raise NotImplementedError()
+
+    def test_connection(self):
+        """
+        Tests the connection by executing a ``select 1`` query.
+
+        :return: ``(status, message)``; ``status`` is True only if the query returned a row,
+            otherwise ``message`` explains the failure (the exception text if one was raised).
+        """
+        status, message = False, 'Connection test query returned no rows'
+        try:
+            with closing(self.get_conn()) as conn:
+                with closing(conn.cursor()) as cur:
+                    cur.execute("select 1")
+                    if cur.fetchone():
+                        status = True
+                        message = 'Connection successfully tested'
+        except Exception as e:  # noqa pylint: disable=broad-except
+            # Also reached if closing the cursor/connection fails after a successful query.
+            status = False
+            message = str(e)
+
+        return status, message
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 9021edb..da0a6ce 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -348,6 +348,22 @@ class Connection(Base, LoggingMixin):  # pylint: 
disable=too-many-instance-attri
             self.extra_dejson,
         )
 
+    def test_connection(self):
+        """
+        Test this connection by calling ``test_connection`` on the hook returned by ``get_hook()``.
+
+        :return: ``(status, message)``. ``status`` stays False, with an explanatory ``message``, when
+            ``get_hook()`` raises, the hook has no callable ``test_connection`` method, or the hook's
+            ``test_connection`` raises.
+        """
+        status, message = False, ''
+        try:
+            hook = self.get_hook()
+            if callable(getattr(hook, 'test_connection', None)):
+                status, message = hook.test_connection()
+            else:
+                message = (
+                    f"Hook {hook.__class__.__name__} doesn't implement or inherit test_connection method"
+                )
+        except Exception as e:  # noqa pylint: disable=broad-except
+            # e.g. an unknown conn_type makes get_hook() raise; report it instead of failing the request.
+            message = str(e)
+
+        return status, message
+
     @property
     def extra_dejson(self) -> Dict:
         """Returns the extra property by deserializing json."""
diff --git a/airflow/www/extensions/init_jinja_globals.py 
b/airflow/www/extensions/init_jinja_globals.py
index 4f1c93e..39cf096 100644
--- a/airflow/www/extensions/init_jinja_globals.py
+++ b/airflow/www/extensions/init_jinja_globals.py
@@ -66,6 +66,7 @@ def init_jinja_globals(app):
             'airflow_version': airflow_version,
             'git_version': git_version,
             'k8s_or_k8scelery_executor': IS_K8S_OR_K8SCELERY_EXECUTOR,
+            'rest_api_enabled': conf.get('api', 'auth_backend') != 
'airflow.api.auth.backend.deny_all',
         }
 
         if 'analytics_tool' in conf.getsection('webserver'):
diff --git a/airflow/www/static/js/connection_form.js 
b/airflow/www/static/js/connection_form.js
index 1fd8f62..fd5116c 100644
--- a/airflow/www/static/js/connection_form.js
+++ b/airflow/www/static/js/connection_form.js
@@ -20,6 +20,10 @@
  * Created by janomar on 23/07/15.
  */
 
+import getMetaValue from './meta_value';
+
+// conn_create.html / conn_edit.html render a Python bool into this meta tag, so it reads 'True' or 'False'.
+const restApiEnabled = getMetaValue('rest_api_enabled') === 'True';
+
 function decode(str) {
   return new DOMParser().parseFromString(str, 
"text/html").documentElement.textContent
 }
@@ -68,6 +72,18 @@ $(document).ready(function () {
   const controlsContainer = getControlsContainer();
   const connTypesToControlsMap = getConnTypesToControlsMap();
 
+  // Build the "Test" button and place it directly after the form's Save (submit) button.
+  const testConnBtn = $(
+    '<button id="test-connection" type="button" class="btn btn-sm btn-primary" '
+    + 'style="margin-left: 3px; pointer-events: all">Test\n <i class="fa fa-rocket"></i></button>',
+  );
+
+  if (!restApiEnabled) {
+    // Mark the button disabled and explain why in its tooltip.
+    testConnBtn
+      .addClass('disabled')
+      .attr('title', 'Airflow REST APIs have been disabled. '
+        + 'See api->auth_backend section of the Airflow configuration.');
+  }
+
+  testConnBtn.insertAfter($('form#model_form div.well.well-sm button:submit'));
+
   /**
    * Changes the connection type.
    * @param {string} connType The connection type to change to.
@@ -140,6 +156,67 @@ $(document).ready(function () {
     }
   }
 
+  /**
+   * Serializes an HTML form into a JSON string for the connections REST API.
+   *
+   * The conn_id field is always sent, renamed to connection_id; every other field
+   * is sent only when non-empty, and csrf_token is never sent.
+   *
+   * @param {string} selector jQuery selector for the form element.
+   * @returns {string} Form data as a JSON string
+   */
+  function getSerializedFormData(selector) {
+    const payload = {};
+    $(selector).serializeArray().forEach(({ name, value }) => {
+      if (name === 'conn_id') {
+        payload.connection_id = value;
+        return;
+      }
+      if (value !== '' && name !== 'csrf_token') {
+        payload[name] = value;
+      }
+    });
+    return JSON.stringify(payload);
+  }
+
+  /**
+   * Displays a Flask-style alert box at the top of the page.
+   *
+   * The message is always inserted as a text node, never parsed as HTML, because it
+   * can contain arbitrary exception text returned by the server. An existing alert is
+   * reused and keeps its dismiss button.
+   *
+   * @param {boolean} status - true for success, false for error
+   * @param {string} message - The text message to show in alert box
+   */
+  function displayAlert(status, message) {
+    const alertClass = status ? 'alert-success' : 'alert-error';
+    const existing = $('.container .row .alert');
+    const alertBox = existing.length ? existing : $('<div class="alert"></div>');
+    alertBox
+      .removeClass('alert-success alert-error')
+      .addClass(alertClass)
+      .empty()
+      .append('<button type="button" class="close" data-dismiss="alert">×</button>')
+      .append(document.createTextNode(message || ''))
+      .show();
+    if (!existing.length) {
+      $('.container .row').prepend(alertBox);
+    }
+  }
+
+  // Bind click event to Test Connection button & perform an AJAX call via REST API
+  $('#test-connection').on('click', (e) => {
+    e.preventDefault();
+    // The button keeps an inline `pointer-events: all` style, so clicks still reach this handler
+    // when it is marked disabled; don't call an API that is known to be switched off.
+    if ($(e.currentTarget).hasClass('disabled')) {
+      return;
+    }
+    $.ajax({
+      url: '/api/v1/connections/test',
+      type: 'post',
+      contentType: 'application/json',
+      dataType: 'json',
+      data: getSerializedFormData('form#model_form'),
+      success(data) {
+        displayAlert(data.status, data.message);
+      },
+      error(jq, err, msg) {
+        // Prefer the API's problem "detail" (e.g. validation errors) over the bare HTTP status text.
+        const detail = jq.responseJSON && jq.responseJSON.detail;
+        displayAlert(false, detail || msg);
+      },
+    });
+  });
+
   const connTypeElem = document.getElementById('conn_type');
   $(connTypeElem).on('change', (e) => {
     connType = e.target.value;
diff --git a/airflow/www/templates/airflow/conn_create.html 
b/airflow/www/templates/airflow/conn_create.html
index b335dd8..c172938 100644
--- a/airflow/www/templates/airflow/conn_create.html
+++ b/airflow/www/templates/airflow/conn_create.html
@@ -19,6 +19,11 @@
 
 {% extends 'appbuilder/general/model/add.html' %}
 
+{% block head_css %}
+  {{ super() }}
+  <meta name="rest_api_enabled" content="{{ rest_api_enabled }}">
+{% endblock %}
+
 {% block tail %}
   {{ super() }}
   <script src="{{ url_for_asset('connectionForm.js') }}"></script>
diff --git a/airflow/www/templates/airflow/conn_edit.html 
b/airflow/www/templates/airflow/conn_edit.html
index c4e9ba1..be0613d 100644
--- a/airflow/www/templates/airflow/conn_edit.html
+++ b/airflow/www/templates/airflow/conn_edit.html
@@ -19,6 +19,11 @@
 
 {% extends 'appbuilder/general/model/edit.html' %}
 
+{% block head_css %}
+  {{ super() }}
+  <meta name="rest_api_enabled" content="{{ rest_api_enabled }}">
+{% endblock %}
+
 {% block tail %}
   {{ super() }}
   <script src="{{ url_for_asset(filename='connectionForm.js') }}"></script>
diff --git a/docs/apache-airflow/howto/connection.rst 
b/docs/apache-airflow/howto/connection.rst
index 2424d3f..09aa294 100644
--- a/docs/apache-airflow/howto/connection.rst
+++ b/docs/apache-airflow/howto/connection.rst
@@ -31,6 +31,8 @@ variables.
 See the :doc:`Connections Concepts </concepts/connections>` documentation for
 more information.
 
+.. _creating_connection_ui:
+
 Creating a Connection with the UI
 ---------------------------------
 
@@ -48,6 +50,8 @@ to create a new connection.
    belonging to the different connection types.
 4. Click the ``Save`` button to create the connection.
 
+.. _editing_connection_ui:
+
 Editing a Connection with the UI
 --------------------------------
 
@@ -360,6 +364,24 @@ In addition to retrieving connections from environment 
variables or the metastor
 an secrets backend to retrieve connections. For more details see 
:doc:`/security/secrets/secrets-backend/index`.
 
 
+Test Connections
+----------------
+
+The Airflow web UI and REST API allow you to test connections. The test connection feature can be used from
+the :ref:`create <creating_connection_ui>` or :ref:`edit <editing_connection_ui>` connection page, or by calling the
+:doc:`Connections REST API </stable-rest-api-ref/>`.
+
+To test a connection, Airflow calls the ``test_connection`` method of the associated hook class and reports the
+result. If the connection type has no associated hook, or the hook does not implement the
+``test_connection`` method, the test fails and an error message explaining the reason is shown.
+
+Note that connections are tested from the webserver only, so this feature is
+subject to the network egress rules set up for your webserver. Also, if the webserver and worker machines have different
+libraries or provider packages installed, the test results might differ.
+
+A final caveat: this feature is not available for connections that come from secrets backends.
+
+
 Custom connection types
 -----------------------
 
diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py 
b/tests/api_connexion/endpoints/test_connection_endpoint.py
index 013efc8..760e9ab 100644
--- a/tests/api_connexion/endpoints/test_connection_endpoint.py
+++ b/tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -110,6 +110,7 @@ class TestGetConnection(TestConnectionEndpoint):
         connection_model = Connection(
             conn_id='test-connection-id',
             conn_type='mysql',
+            description='test description',
             host='mysql',
             login='login',
             schema='testschema',
@@ -127,6 +128,7 @@ class TestGetConnection(TestConnectionEndpoint):
         assert response.json == {
             "connection_id": "test-connection-id",
             "conn_type": 'mysql',
+            "description": "test description",
             "host": 'mysql',
             "login": 'login',
             'schema': 'testschema',
@@ -168,6 +170,7 @@ class TestGetConnections(TestConnectionEndpoint):
                 {
                     "connection_id": "test-connection-id-1",
                     "conn_type": 'test_type',
+                    "description": None,
                     "host": None,
                     "login": None,
                     'schema': None,
@@ -176,6 +179,7 @@ class TestGetConnections(TestConnectionEndpoint):
                 {
                     "connection_id": "test-connection-id-2",
                     "conn_type": 'test_type',
+                    "description": None,
                     "host": None,
                     "login": None,
                     'schema': None,
@@ -203,6 +207,7 @@ class TestGetConnections(TestConnectionEndpoint):
                 {
                     "connection_id": "test-connection-id-2",
                     "conn_type": 'test_type',
+                    "description": None,
                     "host": None,
                     "login": None,
                     'schema': None,
@@ -211,6 +216,7 @@ class TestGetConnections(TestConnectionEndpoint):
                 {
                     "connection_id": "test-connection-id-1",
                     "conn_type": 'test_type',
+                    "description": None,
                     "host": None,
                     "login": None,
                     'schema': None,
@@ -365,6 +371,7 @@ class TestPatchConnection(TestConnectionEndpoint):
         assert response.json == {
             "connection_id": test_connection,  # not updated
             "conn_type": 'test_type',  # Not updated
+            "description": None,  # Not updated
             "extra": None,  # Not updated
             'login': "login",  # updated
             "port": 80,  # updated
@@ -543,3 +550,38 @@ class TestPostConnection(TestConnectionEndpoint):
         )
 
         assert_401(response)
+
+
+class TestConnection(TestConnectionEndpoint):
+    """Tests for ``POST /api/v1/connections/test``."""
+
+    def test_should_respond_200(self):
+        payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"}
+        expected = {"status": True, "message": "Connection successfully tested"}
+        response = self.client.post(
+            "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"}
+        )
+        assert response.status_code == 200
+        assert response.json == expected
+
+    def test_post_should_respond_400_for_invalid_payload(self):
+        payload = {"connection_id": "test-connection-id"}  # conn_type missing
+        expected = {
+            "detail": "{'conn_type': ['Missing data for required field.']}",
+            "status": 400,
+            "title": "Bad Request",
+            "type": EXCEPTIONS_LINK_MAP[400],
+        }
+        response = self.client.post(
+            "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"}
+        )
+        assert response.status_code == 400
+        assert response.json == expected
+
+    def test_should_raises_401_unauthenticated(self):
+        payload = {"connection_id": "test-connection-id", "conn_type": "test_type"}
+        response = self.client.post("/api/v1/connections/test", json=payload)
+
+        assert_401(response)
diff --git a/tests/api_connexion/schemas/test_connection_schema.py 
b/tests/api_connexion/schemas/test_connection_schema.py
index 983a735..bab2c51 100644
--- a/tests/api_connexion/schemas/test_connection_schema.py
+++ b/tests/api_connexion/schemas/test_connection_schema.py
@@ -25,6 +25,7 @@ from airflow.api_connexion.schemas.connection_schema import (
     connection_collection_item_schema,
     connection_collection_schema,
     connection_schema,
+    connection_test_schema,
 )
 from airflow.models import Connection
 from airflow.utils.session import create_session, provide_session
@@ -56,6 +57,7 @@ class TestConnectionCollectionItemSchema(unittest.TestCase):
         assert deserialized_connection == {
             'connection_id': "mysql_default",
             'conn_type': 'mysql',
+            'description': None,
             'host': 'mysql',
             'login': 'login',
             'schema': 'testschema',
@@ -124,6 +126,7 @@ class TestConnectionCollectionSchema(unittest.TestCase):
                 {
                     "connection_id": "mysql_default_1",
                     "conn_type": "test-type",
+                    "description": None,
                     "host": None,
                     "login": None,
                     'schema': None,
@@ -132,6 +135,7 @@ class TestConnectionCollectionSchema(unittest.TestCase):
                 {
                     "connection_id": "mysql_default_2",
                     "conn_type": "test-type2",
+                    "description": None,
                     "host": None,
                     "login": None,
                     'schema': None,
@@ -169,6 +173,7 @@ class TestConnectionSchema(unittest.TestCase):
         assert deserialized_connection == {
             'connection_id': "mysql_default",
             'conn_type': 'mysql',
+            'description': None,
             'host': 'mysql',
             'login': 'login',
             'schema': 'testschema',
@@ -196,3 +201,16 @@ class TestConnectionSchema(unittest.TestCase):
             'port': 80,
             'extra': "{'key':'string'}",
         }
+
+
+class TestConnectionTestSchema(unittest.TestCase):
+    def test_response(self):
+        data = {
+            'status': True,
+            'message': 'Connection tested successful',
+        }
+        # The endpoint serializes its result with ``dump``; ``load`` exercises the required fields.
+        assert connection_test_schema.dump(data) == data
+        assert connection_test_schema.load(data) == data
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
index 29eed1c..835b4c3 100644
--- a/tests/models/test_connection.py
+++ b/tests/models/test_connection.py
@@ -649,3 +649,39 @@ class TestConnection(unittest.TestCase):
             ]
         finally:
             session.rollback()
+
+    @mock.patch.dict('os.environ', {'AIRFLOW_CONN_TEST_URI': 'sqlite://'})
+    def test_connection_test_success(self):
+        conn = Connection(conn_id='test_uri', conn_type='sqlite')
+        status, message = conn.test_connection()
+        assert status is True
+        assert message == 'Connection successfully tested'
+
+    @mock.patch.dict('os.environ', {'AIRFLOW_CONN_TEST_URI_NO_HOOK': 'fs://'})
+    def test_connection_test_no_hook(self):
+        conn = Connection(conn_id='test_uri_no_hook', conn_type='fs')
+        status, message = conn.test_connection()
+        assert status is False
+        assert message == 'Unknown hook type "fs"'
+
+    @mock.patch.dict('os.environ', {'AIRFLOW_CONN_TEST_URI_HOOK_METHOD_MISSING': 'ftp://'})
+    def test_connection_test_hook_method_missing(self):
+        # conn_id must match the patched variable name (AIRFLOW_CONN_ + upper-cased conn_id).
+        conn = Connection(conn_id='test_uri_hook_method_missing', conn_type='ftp')
+        status, message = conn.test_connection()
+        assert status is False
+        assert message == "Hook FTPHook doesn't implement or inherit test_connection method"
diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index 738bf26..98623da 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -57,7 +57,10 @@ def test_redoc_should_render_template(capture_templates, 
admin_client):
 
     assert len(templates) == 1
     assert templates[0].name == 'airflow/redoc.html'
-    assert templates[0].local_context == {'openapi_spec_url': 
'/api/v1/openapi.yaml'}
+    assert templates[0].local_context == {
+        'openapi_spec_url': '/api/v1/openapi.yaml',
+        'rest_api_enabled': True,
+    }
 
 
 def test_plugin_should_list_on_page_with_details(admin_client):

Reply via email to