ashb commented on a change in pull request #15042:
URL: https://github.com/apache/airflow/pull/15042#discussion_r616731349



##########
File path: airflow/models/auth.py
##########
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pendulum import from_timestamp
+from sqlalchemy import Column, DateTime, Integer, String
+
+from airflow.models.base import Base
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import provide_session
+
+
+class TokenBlockList(Base, LoggingMixin):
+    """
+    A model for recording blocked token
+
+    :param jti: The token jti(JWT ID)
+    :type jti: str
+    :param expiry_date: When the token would expire
+    :type expiry_date: DateTime
+    """
+
+    __tablename__ = 'token_blocklist'
+    id = Column(Integer, primary_key=True)
+    jti = Column(String(50), unique=True, nullable=False)
+    expiry_date = Column(DateTime(), nullable=False, index=True)
+
+    def __init__(self, jti: str, expiry_date: DateTime):
+        super().__init__()
+        self.jti = jti
+        self.expiry_date = expiry_date
+
+    @classmethod
+    @provide_session
+    def get_token(cls, token, session=None):
+        """Get a token"""
+        tkn = session.query(cls).filter(cls.jti == token).first()
+        return tkn

Review comment:
       ```suggestion
           return session.query(cls).filter(cls.jti == token).one_or_none()
   ```

##########
File path: airflow/www/security.py
##########
@@ -18,16 +18,23 @@
 #
 
 import warnings
+from datetime import timedelta
 from typing import Dict, Optional, Sequence, Set, Tuple
 
-from flask import current_app, g
+import jwt
+from flask import current_app, g, request
+from flask_appbuilder.const import AUTH_DB, AUTH_LDAP, AUTH_OAUTH, 
AUTH_REMOTE_USER
 from flask_appbuilder.security.sqla import models as sqla_models
 from flask_appbuilder.security.sqla.manager import SecurityManager
 from flask_appbuilder.security.sqla.models import PermissionView, Role, User
+from flask_jwt_extended import JWTManager, create_access_token, 
create_refresh_token
+from flask_login import login_user
 from sqlalchemy import or_
 from sqlalchemy.orm import joinedload
 
 from airflow import models
+from airflow.api_connexion.exceptions import Unauthenticated
+from airflow.api_connexion.schemas.auth_schema import auth_schema

Review comment:
       Hmmmm, having this import code from the api_connexion module is a bit of a 
"code smell".
   
   @jedcunningham @jhtimmins Do you have any thoughts here? I know you were 
discussing somewhere or other whether we needed our own Airflow security layer 
separate from FAB.

##########
File path: airflow/www/security.py
##########
@@ -728,3 +735,102 @@ def check_authorization(
                 return False
 
         return True
+
+    # TODO: Whether to create APISecurityManager and move api related code to 
it?
+    def is_user_logged_in(self):
+        """Raise if user already logged in"""
+        if g.user is not None and g.user.is_authenticated:
+            raise Unauthenticated(detail="Client already authenticated")  # 
For security
+
+    def login_with_user_pass(self, username, password):
+        """Convenience method for user login through the API"""
+        self.is_user_logged_in()
+        if self.auth_type not in (AUTH_DB, AUTH_LDAP):
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = None
+        if self.auth_type == AUTH_DB:
+            user = self.auth_user_db(username, password)
+        elif self.auth_type == AUTH_LDAP:
+            user = self.auth_user_ldap(username, password)
+        return user
+
+    def oauth_authorization_url(self, app, provider, redirect_url):
+        """Get authorization url for oauth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_OAUTH:
+            raise Unauthenticated(detail="Authentication type do not match")
+        state = jwt.encode(
+            request.args.to_dict(flat=False),
+            app.config["SECRET_KEY"],
+            algorithm="HS256",
+        )
+        auth_provider = self.oauth_remotes[provider]
+        try:
+
+            if provider == "twitter":
+                redirect_uri = redirect_url + f"&state={state}"
+                auth_data = 
auth_provider.create_authorization_url(redirect_uri=redirect_uri)
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_uri, **auth_data)
+                return dict(auth_url=auth_data['url'])
+            else:
+                state = state.decode("ascii") if isinstance(state, bytes) else 
state
+                auth_data = auth_provider.create_authorization_url(
+                    redirect_uri=redirect_url,
+                    state=state,
+                )
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_url, **auth_data)
+                return dict(auth_url=auth_data['url'])
+        except Exception as err:  # pylint: disable=broad-except
+            raise Unauthenticated(detail=str(err))
+
+    def oauth_login_user(self, app, provider, state):
+        """Oauth login"""
+        resp = self.oauth_remotes[provider].authorize_access_token()
+        if resp is None:
+            raise Unauthenticated(detail="You denied the request to sign in")
+        # Verify state
+        try:
+            jwt.decode(
+                state,
+                app.config["SECRET_KEY"],
+                algorithms=["HS256"],
+            )
+        except jwt.InvalidTokenError:
+            raise Unauthenticated(detail="State signature is not valid!")
+        # Retrieves specific user info from the provider
+        try:
+            userinfo = self.oauth_user_info(provider, resp)
+        except Exception:  # pylint: disable=broad-except
+            user = None
+        else:
+            user = self.auth_user_oauth(userinfo)
+        if user is None:
+            raise Unauthenticated(detail="Invalid login")
+        login_user(user)
+        return user
+
+    def login_remote_user(self, username):
+        """Login user using remote auth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_REMOTE_USER:
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = self.auth_user_remote_user(username)
+        if user is None:
+            raise Unauthenticated(detail="Invalid login")
+        login_user(user)
+        return user
+
+    def create_jwt_manager(self, app) -> JWTManager:
+        """JWT Manager"""
+        jwt_manager = JWTManager()
+        app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
+        app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)

Review comment:
       These shouldn't be hard-coded and should come from config values.

##########
File path: airflow/www/security.py
##########
@@ -728,3 +735,102 @@ def check_authorization(
                 return False
 
         return True
+
+    # TODO: Whether to create APISecurityManager and move api related code to 
it?
+    def is_user_logged_in(self):
+        """Raise if user already logged in"""
+        if g.user is not None and g.user.is_authenticated:
+            raise Unauthenticated(detail="Client already authenticated")  # 
For security
+
+    def login_with_user_pass(self, username, password):
+        """Convenience method for user login through the API"""
+        self.is_user_logged_in()
+        if self.auth_type not in (AUTH_DB, AUTH_LDAP):
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = None
+        if self.auth_type == AUTH_DB:
+            user = self.auth_user_db(username, password)
+        elif self.auth_type == AUTH_LDAP:
+            user = self.auth_user_ldap(username, password)
+        return user
+
+    def oauth_authorization_url(self, app, provider, redirect_url):
+        """Get authorization url for oauth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_OAUTH:
+            raise Unauthenticated(detail="Authentication type do not match")
+        state = jwt.encode(
+            request.args.to_dict(flat=False),
+            app.config["SECRET_KEY"],
+            algorithm="HS256",
+        )
+        auth_provider = self.oauth_remotes[provider]
+        try:
+
+            if provider == "twitter":
+                redirect_uri = redirect_url + f"&state={state}"
+                auth_data = 
auth_provider.create_authorization_url(redirect_uri=redirect_uri)
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_uri, **auth_data)
+                return dict(auth_url=auth_data['url'])
+            else:
+                state = state.decode("ascii") if isinstance(state, bytes) else 
state
+                auth_data = auth_provider.create_authorization_url(
+                    redirect_uri=redirect_url,
+                    state=state,
+                )
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_url, **auth_data)

Review comment:
       I'm surprised to see this much code here -- how does FAB deal with 
it when configured to use OAuth for login?

##########
File path: airflow/www/security.py
##########
@@ -728,3 +735,102 @@ def check_authorization(
                 return False
 
         return True
+
+    # TODO: Whether to create APISecurityManager and move api related code to 
it?
+    def is_user_logged_in(self):
+        """Raise if user already logged in"""
+        if g.user is not None and g.user.is_authenticated:
+            raise Unauthenticated(detail="Client already authenticated")  # 
For security
+
+    def login_with_user_pass(self, username, password):
+        """Convenience method for user login through the API"""
+        self.is_user_logged_in()
+        if self.auth_type not in (AUTH_DB, AUTH_LDAP):
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = None
+        if self.auth_type == AUTH_DB:
+            user = self.auth_user_db(username, password)
+        elif self.auth_type == AUTH_LDAP:
+            user = self.auth_user_ldap(username, password)
+        return user
+
+    def oauth_authorization_url(self, app, provider, redirect_url):
+        """Get authorization url for oauth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_OAUTH:
+            raise Unauthenticated(detail="Authentication type do not match")
+        state = jwt.encode(
+            request.args.to_dict(flat=False),
+            app.config["SECRET_KEY"],
+            algorithm="HS256",
+        )
+        auth_provider = self.oauth_remotes[provider]
+        try:
+
+            if provider == "twitter":
+                redirect_uri = redirect_url + f"&state={state}"
+                auth_data = 
auth_provider.create_authorization_url(redirect_uri=redirect_uri)
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_uri, **auth_data)
+                return dict(auth_url=auth_data['url'])
+            else:
+                state = state.decode("ascii") if isinstance(state, bytes) else 
state
+                auth_data = auth_provider.create_authorization_url(
+                    redirect_uri=redirect_url,
+                    state=state,
+                )
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_url, **auth_data)
+                return dict(auth_url=auth_data['url'])

Review comment:
       I think this should be a normal browser redirect handled by the server, 
and not a location sent back to the front end app.
   
   This means the _server_ is the only thing that sees the state parameter, 
which avoids some possible security snafus.

##########
File path: airflow/www/security.py
##########
@@ -728,3 +735,102 @@ def check_authorization(
                 return False
 
         return True
+
+    # TODO: Whether to create APISecurityManager and move api related code to 
it?
+    def is_user_logged_in(self):
+        """Raise if user already logged in"""
+        if g.user is not None and g.user.is_authenticated:
+            raise Unauthenticated(detail="Client already authenticated")  # 
For security
+
+    def login_with_user_pass(self, username, password):
+        """Convenience method for user login through the API"""
+        self.is_user_logged_in()
+        if self.auth_type not in (AUTH_DB, AUTH_LDAP):
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = None
+        if self.auth_type == AUTH_DB:
+            user = self.auth_user_db(username, password)
+        elif self.auth_type == AUTH_LDAP:
+            user = self.auth_user_ldap(username, password)
+        return user
+
+    def oauth_authorization_url(self, app, provider, redirect_url):
+        """Get authorization url for oauth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_OAUTH:
+            raise Unauthenticated(detail="Authentication type do not match")
+        state = jwt.encode(
+            request.args.to_dict(flat=False),
+            app.config["SECRET_KEY"],
+            algorithm="HS256",
+        )
+        auth_provider = self.oauth_remotes[provider]
+        try:
+
+            if provider == "twitter":
+                redirect_uri = redirect_url + f"&state={state}"
+                auth_data = 
auth_provider.create_authorization_url(redirect_uri=redirect_uri)
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_uri, **auth_data)
+                return dict(auth_url=auth_data['url'])
+            else:
+                state = state.decode("ascii") if isinstance(state, bytes) else 
state
+                auth_data = auth_provider.create_authorization_url(
+                    redirect_uri=redirect_url,
+                    state=state,
+                )
+                auth_provider.save_authorize_data(request, 
redirect_uri=redirect_url, **auth_data)
+                return dict(auth_url=auth_data['url'])
+        except Exception as err:  # pylint: disable=broad-except
+            raise Unauthenticated(detail=str(err))
+
+    def oauth_login_user(self, app, provider, state):
+        """Oauth login"""
+        resp = self.oauth_remotes[provider].authorize_access_token()
+        if resp is None:
+            raise Unauthenticated(detail="You denied the request to sign in")
+        # Verify state
+        try:
+            jwt.decode(
+                state,
+                app.config["SECRET_KEY"],
+                algorithms=["HS256"],
+            )
+        except jwt.InvalidTokenError:
+            raise Unauthenticated(detail="State signature is not valid!")
+        # Retrieves specific user info from the provider
+        try:
+            userinfo = self.oauth_user_info(provider, resp)
+        except Exception:  # pylint: disable=broad-except
+            user = None
+        else:
+            user = self.auth_user_oauth(userinfo)
+        if user is None:
+            raise Unauthenticated(detail="Invalid login")
+        login_user(user)
+        return user
+
+    def login_remote_user(self, username):
+        """Login user using remote auth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_REMOTE_USER:
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = self.auth_user_remote_user(username)
+        if user is None:
+            raise Unauthenticated(detail="Invalid login")
+        login_user(user)
+        return user
+
+    def create_jwt_manager(self, app) -> JWTManager:
+        """JWT Manager"""

Review comment:
       ```suggestion
           """Called by FAB for us when it wants a configured JWT manager"""
   ```

##########
File path: airflow/www/security.py
##########
@@ -728,3 +735,102 @@ def check_authorization(
                 return False
 
         return True
+
+    # TODO: Whether to create APISecurityManager and move api related code to 
it?
+    def is_user_logged_in(self):
+        """Raise if user already logged in"""
+        if g.user is not None and g.user.is_authenticated:
+            raise Unauthenticated(detail="Client already authenticated")  # 
For security
+
+    def login_with_user_pass(self, username, password):
+        """Convenience method for user login through the API"""
+        self.is_user_logged_in()
+        if self.auth_type not in (AUTH_DB, AUTH_LDAP):
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = None
+        if self.auth_type == AUTH_DB:
+            user = self.auth_user_db(username, password)
+        elif self.auth_type == AUTH_LDAP:
+            user = self.auth_user_ldap(username, password)
+        return user
+
+    def oauth_authorization_url(self, app, provider, redirect_url):
+        """Get authorization url for oauth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_OAUTH:
+            raise Unauthenticated(detail="Authentication type do not match")

Review comment:
       ```suggestion
               raise Unauthenticated(detail="Authentication type does not 
match")
   ```

##########
File path: airflow/models/auth.py
##########
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pendulum import from_timestamp
+from sqlalchemy import Column, DateTime, Integer, String
+
+from airflow.models.base import Base
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import provide_session
+
+
+class TokenBlockList(Base, LoggingMixin):
+    """
+    A model for recording blocked token
+
+    :param jti: The token jti(JWT ID)
+    :type jti: str
+    :param expiry_date: When the token would expire
+    :type expiry_date: DateTime
+    """
+
+    __tablename__ = 'token_blocklist'
+    id = Column(Integer, primary_key=True)
+    jti = Column(String(50), unique=True, nullable=False)
+    expiry_date = Column(DateTime(), nullable=False, index=True)
+
+    def __init__(self, jti: str, expiry_date: DateTime):
+        super().__init__()
+        self.jti = jti
+        self.expiry_date = expiry_date
+

Review comment:
       ```suggestion
   ```
   
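   This isn't needed -- SQLAlchemy's declarative base already provides a default constructor that sets mapped attributes from keyword arguments (note that it accepts keyword arguments only, so callers must use `TokenBlockList(jti=..., expiry_date=...)`).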

##########
File path: airflow/migrations/versions/22ab4efd5674_add_token_blocklist.py
##########
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""add-token-blocklist
+
+Revision ID: 22ab4efd5674
+Revises: a13f7613ad25
+Create Date: 2021-04-17 18:16:31.019394
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = '22ab4efd5674'
+down_revision = 'a13f7613ad25'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """Apply Add token blocklist table"""
+    op.create_table(
+        "token_blocklist",
+        sa.Column("id", sa.Integer(), primary_key=True),
+        sa.Column("jti", sa.String(50), nullable=False, unique=True),

Review comment:
       ```suggestion
           sa.Column("jti", sa.String(50), nullable=False, primary_key=True),
   ```
   
   If we use the `jti` as the primary key then the queries get simpler -- 
there's no need for an auto-increment/integer primary key for the DB, and the 
jti is unique, so let's just use that.
   
   Doing this makes the get query easier too:
   
   ```python
       session.query(TokenBlockList).get(jti)
   ``` 

##########
File path: airflow/models/auth.py
##########
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pendulum import from_timestamp
+from sqlalchemy import Column, DateTime, Integer, String
+
+from airflow.models.base import Base
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import provide_session
+
+
+class TokenBlockList(Base, LoggingMixin):
+    """
+    A model for recording blocked token
+
+    :param jti: The token jti(JWT ID)
+    :type jti: str
+    :param expiry_date: When the token would expire
+    :type expiry_date: DateTime
+    """
+
+    __tablename__ = 'token_blocklist'
+    id = Column(Integer, primary_key=True)
+    jti = Column(String(50), unique=True, nullable=False)
+    expiry_date = Column(DateTime(), nullable=False, index=True)
+
+    def __init__(self, jti: str, expiry_date: DateTime):
+        super().__init__()
+        self.jti = jti
+        self.expiry_date = expiry_date
+
+    @classmethod
+    @provide_session
+    def get_token(cls, token, session=None):
+        """Get a token"""
+        tkn = session.query(cls).filter(cls.jti == token).first()
+        return tkn
+
+    @classmethod
+    @provide_session
+    def delete_token(cls, token, session=None):
+        """Delete a token"""
+        tkn = session.query(cls).filter(cls.jti == token).first()
+        if tkn:
+            session.delete(tkn)
+            session.commit()

Review comment:
       ```suggestion
           session.query(cls).filter(cls.jti == token).delete()
   ```

##########
File path: airflow/api_connexion/endpoints/auth_endpoint.py
##########
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+
+from flask import current_app, jsonify, request, session as c_session
+from flask_appbuilder.const import AUTH_DB, AUTH_LDAP, AUTH_OAUTH, AUTH_OID, 
AUTH_REMOTE_USER
+from flask_jwt_extended import (
+    create_access_token,
+    decode_token,
+    get_jwt_identity,
+    get_raw_jwt,
+    jwt_refresh_token_required,
+    jwt_required,
+)
+from flask_login import login_user
+from marshmallow import ValidationError
+
+from airflow.api_connexion.exceptions import BadRequest, Unauthenticated
+from airflow.api_connexion.schemas.auth_schema import info_schema, 
login_form_schema, token_schema
+from airflow.models.auth import TokenBlockList
+
+log = logging.getLogger(__name__)
+
+
+def get_auth_info():
+    """Get site authentication info"""
+    security_manager = current_app.appbuilder.sm
+    config = current_app.config
+    auth_type = security_manager.auth_type
+    type_mapping = {
+        AUTH_DB: "auth_db",
+        AUTH_LDAP: "auth_ldap",
+        AUTH_OID: "auth_oid",
+        AUTH_OAUTH: "auth_oauth",
+        AUTH_REMOTE_USER: "auth_remote_user",
+    }
+    oauth_providers = config.get("OAUTH_PROVIDERS", None)
+    openid_providers = config.get("OPENID_PROVIDERS", None)
+    return info_schema.dump(
+        {
+            "auth_type": type_mapping[auth_type],
+            "oauth_providers": oauth_providers,
+            "openid_providers": openid_providers,
+        }
+    )
+
+
+def auth_login():
+    """Handle DB login"""
+    body = request.json
+    try:
+        data = login_form_schema.load(body)
+    except ValidationError as err:
+        raise Unauthenticated(detail=str(err.messages))
+    security_manager = current_app.appbuilder.sm
+    user = security_manager.login_with_user_pass(data['username'], 
data['password'])
+    if not user:
+        raise Unauthenticated(detail="Invalid login")
+    login_user(user, remember=False)
+    return security_manager.create_tokens_and_dump(user)
+
+
+def auth_oauthlogin(provider, register=None, redirect_url=None):
+    """Returns OAUTH authorization url"""
+    appbuilder = current_app.appbuilder
+    if register:
+        c_session["register"] = True
+    return appbuilder.sm.oauth_authorization_url(appbuilder.app, provider, 
redirect_url)
+
+
+def authorize_oauth(provider, state):
+    """Callback to authorize Oauth."""
+    appbuilder = current_app.appbuilder
+    user = appbuilder.sm.oauth_login_user(appbuilder.app, provider, state)
+    return appbuilder.sm.create_tokens_and_dump(user)
+
+
+def auth_remoteuser():
+    """Handle remote user auth"""
+    appbuilder = current_app.appbuilder
+    username = request.environ.get("REMOTE_USER")
+    if username:
+        user = appbuilder.sm.login_remote_user(username)
+    else:
+        raise Unauthenticated(detail="Invalid login")
+    return appbuilder.sm.create_tokens_and_dump(user)
+
+
+@jwt_refresh_token_required
+def refresh_token():
+    """Refresh token"""
+    user = get_jwt_identity()
+    access_token = create_access_token(identity=user)
+    ret = {'access_token': access_token}
+    return jsonify(ret), 200
+
+
+@jwt_required
+def revoke_token():
+    """
+    An endpoint for revoking both access and refresh token.
+
+    This is intended for a case where a logged in user want to revoke
+    another user's tokens
+    """
+    resp = jsonify({"revoked": True})
+    body = request.json
+    try:
+        data = token_schema.load(body)
+    except ValidationError as err:
+        raise BadRequest(detail=str(err.messages))
+    token = decode_token(data['token'])
+    tkn = TokenBlockList.get_token(token['jti'])
+    if not tkn:
+        TokenBlockList.add_token(jti=token['jti'], expiry_delta=token['exp'])

Review comment:
       I would change how you handle this:
   
   Instead of checking, and only adding the row if it doesn't exist, 
create the row "blindly" and then catch and ignore the 
[sqlalchemy.exc.IntegrityError](https://docs.sqlalchemy.org/en/13/core/exceptions.html?highlight=integrity#sqlalchemy.exc.IntegrityError)
 error.
   
   The reason for doing it this way is that otherwise it is prone to race 
conditions -- say a user double-clicks log out, and this fires off two almost 
parallel log-out requests:
   
   1: logout request starts
   2: logout request starts
   1: check if token exists - it doesn't
   2: check if token exists - it still doesn't
   1: add token
   2: add token too :BOOM: IntegrityError.
   
   Since it is possible anyway, we'll have to handle it, so we should avoid the 
extra query to check if the row exists.

##########
File path: airflow/www/security.py
##########
@@ -728,3 +735,102 @@ def check_authorization(
                 return False
 
         return True
+
+    # TODO: Whether to create APISecurityManager and move api related code to 
it?
+    def is_user_logged_in(self):
+        """Raise if user already logged in"""
+        if g.user is not None and g.user.is_authenticated:
+            raise Unauthenticated(detail="Client already authenticated")  # 
For security
+
+    def login_with_user_pass(self, username, password):
+        """Convenience method for user login through the API"""
+        self.is_user_logged_in()
+        if self.auth_type not in (AUTH_DB, AUTH_LDAP):
+            raise Unauthenticated(detail="Authentication type do not match")
+        user = None
+        if self.auth_type == AUTH_DB:
+            user = self.auth_user_db(username, password)
+        elif self.auth_type == AUTH_LDAP:
+            user = self.auth_user_ldap(username, password)
+        return user
+
+    def oauth_authorization_url(self, app, provider, redirect_url):
+        """Get authorization url for oauth"""
+        self.is_user_logged_in()
+        if self.auth_type != AUTH_OAUTH:
+            raise Unauthenticated(detail="Authentication type do not match")
+        state = jwt.encode(
+            request.args.to_dict(flat=False),

Review comment:
       Does this just include all request args from the user in the encoded 
state output?
   
   If so that's probably not a great idea and might be open to malicious 
activity in some way. What do we _need_ here?

##########
File path: airflow/www/security.py
##########
@@ -728,3 +735,102 @@ def check_authorization(
                 return False
 
         return True
+
+    # TODO: Whether to create APISecurityManager and move api related code to 
it?
+    def is_user_logged_in(self):
+        """Raise if user already logged in"""
+        if g.user is not None and g.user.is_authenticated:
+            raise Unauthenticated(detail="Client already authenticated")  # 
For security

Review comment:
       This should be a "400 Bad Request".




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to