[ https://issues.apache.org/jira/browse/AIRFLOW-3458?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16725802#comment-16725802 ]
ASF GitHub Bot commented on AIRFLOW-3458: ----------------------------------------- Fokko closed pull request #4335: [AIRFLOW-3458] Move models.Connection into separate file URL: https://github.com/apache/incubator-airflow/pull/4335 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index cd414d2821..143e2b34aa 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -33,6 +33,8 @@ import argparse from builtins import input from collections import namedtuple + +from airflow.models.connection import Connection from airflow.utils.timezone import parse as parsedate import json from tabulate import tabulate @@ -55,8 +57,7 @@ from airflow.exceptions import AirflowException, AirflowWebServerTimeout from airflow.executors import GetDefaultExecutor from airflow.models import (DagModel, DagBag, TaskInstance, - DagPickle, DagRun, Variable, DagStat, - Connection, DAG) + DagPickle, DagRun, Variable, DagStat, DAG) from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS) from airflow.utils import cli as cli_utils diff --git a/airflow/contrib/executors/mesos_executor.py b/airflow/contrib/executors/mesos_executor.py index 0609d71cf2..7aae91e6d4 100644 --- a/airflow/contrib/executors/mesos_executor.py +++ b/airflow/contrib/executors/mesos_executor.py @@ -80,7 +80,7 @@ def registered(self, driver, frameworkId, masterInfo): if configuration.conf.getboolean('mesos', 'CHECKPOINT') and \ configuration.conf.get('mesos', 'FAILOVER_TIMEOUT'): # Import here to work around a circular import error - from airflow.models import Connection + from airflow.models.connection import Connection # Update the Framework ID in the database. 
session = Session() @@ -253,7 +253,7 @@ def start(self): if configuration.conf.get('mesos', 'FAILOVER_TIMEOUT'): # Import here to work around a circular import error - from airflow.models import Connection + from airflow.models.connection import Connection # Query the database to get the ID of the Mesos Framework, if available. conn_id = FRAMEWORK_CONNID_PREFIX + framework.name diff --git a/airflow/contrib/hooks/gcp_sql_hook.py b/airflow/contrib/hooks/gcp_sql_hook.py index 1581637e0d..9872746b7b 100644 --- a/airflow/contrib/hooks/gcp_sql_hook.py +++ b/airflow/contrib/hooks/gcp_sql_hook.py @@ -34,7 +34,7 @@ import requests from googleapiclient.discovery import build -from airflow import AirflowException, LoggingMixin, models +from airflow import AirflowException, LoggingMixin from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook # Number of retries - used by googleapiclient method calls to perform retries @@ -42,7 +42,7 @@ from airflow.hooks.base_hook import BaseHook from airflow.hooks.mysql_hook import MySqlHook from airflow.hooks.postgres_hook import PostgresHook -from airflow.models import Connection +from airflow.models.connection import Connection from airflow.utils.db import provide_session NUM_RETRIES = 5 @@ -457,8 +457,8 @@ def _download_sql_proxy_if_needed(self): @provide_session def _get_credential_parameters(self, session): - connection = session.query(models.Connection). \ - filter(models.Connection.conn_id == self.gcp_conn_id).first() + connection = session.query(Connection). \ + filter(Connection.conn_id == self.gcp_conn_id).first() session.expunge_all() if GCP_CREDENTIALS_KEY_PATH in connection.extra_dejson: credential_params = [ @@ -851,8 +851,8 @@ def delete_connection(self, session=None): decorator). 
""" self.log.info("Deleting connection {}".format(self.db_conn_id)) - connection = session.query(models.Connection).filter( - models.Connection.conn_id == self.db_conn_id)[0] + connection = session.query(Connection).filter( + Connection.conn_id == self.db_conn_id)[0] session.delete(connection) session.commit() diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py index ef44f6469d..c1283e3fb4 100644 --- a/airflow/hooks/base_hook.py +++ b/airflow/hooks/base_hook.py @@ -25,7 +25,7 @@ import os import random -from airflow.models import Connection +from airflow.models.connection import Connection from airflow.exceptions import AirflowException from airflow.utils.db import provide_session from airflow.utils.log.logging_mixin import LoggingMixin diff --git a/airflow/models.py b/airflow/models/__init__.py similarity index 95% rename from airflow/models.py rename to airflow/models/__init__.py index 1089970b65..aa93f5cb88 100755 --- a/airflow/models.py +++ b/airflow/models/__init__.py @@ -60,7 +60,7 @@ import uuid from datetime import datetime -from urllib.parse import urlparse, quote, parse_qsl, unquote +from urllib.parse import quote from sqlalchemy import ( Boolean, Column, DateTime, Float, ForeignKey, ForeignKeyConstraint, Index, @@ -613,245 +613,6 @@ def is_superuser(self): return self.superuser -class Connection(Base, LoggingMixin): - """ - Placeholder to store information about different database instances - connection information. The idea here is that scripts use references to - database instances (conn_id) instead of hard coding hostname, logins and - passwords when using operators or hooks. 
- """ - __tablename__ = "connection" - - id = Column(Integer(), primary_key=True) - conn_id = Column(String(ID_LEN)) - conn_type = Column(String(500)) - host = Column(String(500)) - schema = Column(String(500)) - login = Column(String(500)) - _password = Column('password', String(5000)) - port = Column(Integer()) - is_encrypted = Column(Boolean, unique=False, default=False) - is_extra_encrypted = Column(Boolean, unique=False, default=False) - _extra = Column('extra', String(5000)) - - _types = [ - ('docker', 'Docker Registry',), - ('fs', 'File (path)'), - ('ftp', 'FTP',), - ('google_cloud_platform', 'Google Cloud Platform'), - ('hdfs', 'HDFS',), - ('http', 'HTTP',), - ('hive_cli', 'Hive Client Wrapper',), - ('hive_metastore', 'Hive Metastore Thrift',), - ('hiveserver2', 'Hive Server 2 Thrift',), - ('jdbc', 'Jdbc Connection',), - ('jenkins', 'Jenkins'), - ('mysql', 'MySQL',), - ('postgres', 'Postgres',), - ('oracle', 'Oracle',), - ('vertica', 'Vertica',), - ('presto', 'Presto',), - ('s3', 'S3',), - ('samba', 'Samba',), - ('sqlite', 'Sqlite',), - ('ssh', 'SSH',), - ('cloudant', 'IBM Cloudant',), - ('mssql', 'Microsoft SQL Server'), - ('mesos_framework-id', 'Mesos Framework ID'), - ('jira', 'JIRA',), - ('redis', 'Redis',), - ('wasb', 'Azure Blob Storage'), - ('databricks', 'Databricks',), - ('aws', 'Amazon Web Services',), - ('emr', 'Elastic MapReduce',), - ('snowflake', 'Snowflake',), - ('segment', 'Segment',), - ('azure_data_lake', 'Azure Data Lake'), - ('azure_cosmos', 'Azure CosmosDB'), - ('cassandra', 'Cassandra',), - ('qubole', 'Qubole'), - ('mongo', 'MongoDB'), - ('gcpcloudsql', 'Google Cloud SQL'), - ] - - def __init__( - self, conn_id=None, conn_type=None, - host=None, login=None, password=None, - schema=None, port=None, extra=None, - uri=None): - self.conn_id = conn_id - if uri: - self.parse_from_uri(uri) - else: - self.conn_type = conn_type - self.host = host - self.login = login - self.password = password - self.schema = schema - self.port = port - 
self.extra = extra - - def parse_from_uri(self, uri): - temp_uri = urlparse(uri) - hostname = temp_uri.hostname or '' - conn_type = temp_uri.scheme - if conn_type == 'postgresql': - conn_type = 'postgres' - self.conn_type = conn_type - self.host = unquote(hostname) if hostname else hostname - quoted_schema = temp_uri.path[1:] - self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema - self.login = unquote(temp_uri.username) \ - if temp_uri.username else temp_uri.username - self.password = unquote(temp_uri.password) \ - if temp_uri.password else temp_uri.password - self.port = temp_uri.port - if temp_uri.query: - self.extra = json.dumps(dict(parse_qsl(temp_uri.query))) - - def get_password(self): - if self._password and self.is_encrypted: - fernet = get_fernet() - if not fernet.is_encrypted: - raise AirflowException( - "Can't decrypt encrypted password for login={}, \ - FERNET_KEY configuration is missing".format(self.login)) - return fernet.decrypt(bytes(self._password, 'utf-8')).decode() - else: - return self._password - - def set_password(self, value): - if value: - fernet = get_fernet() - self._password = fernet.encrypt(bytes(value, 'utf-8')).decode() - self.is_encrypted = fernet.is_encrypted - - @declared_attr - def password(cls): - return synonym('_password', - descriptor=property(cls.get_password, cls.set_password)) - - def get_extra(self): - if self._extra and self.is_extra_encrypted: - fernet = get_fernet() - if not fernet.is_encrypted: - raise AirflowException( - "Can't decrypt `extra` params for login={},\ - FERNET_KEY configuration is missing".format(self.login)) - return fernet.decrypt(bytes(self._extra, 'utf-8')).decode() - else: - return self._extra - - def set_extra(self, value): - if value: - fernet = get_fernet() - self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode() - self.is_extra_encrypted = fernet.is_encrypted - else: - self._extra = value - self.is_extra_encrypted = False - - @declared_attr - def extra(cls): - return 
synonym('_extra', - descriptor=property(cls.get_extra, cls.set_extra)) - - def get_hook(self): - try: - if self.conn_type == 'mysql': - from airflow.hooks.mysql_hook import MySqlHook - return MySqlHook(mysql_conn_id=self.conn_id) - elif self.conn_type == 'google_cloud_platform': - from airflow.contrib.hooks.bigquery_hook import BigQueryHook - return BigQueryHook(bigquery_conn_id=self.conn_id) - elif self.conn_type == 'postgres': - from airflow.hooks.postgres_hook import PostgresHook - return PostgresHook(postgres_conn_id=self.conn_id) - elif self.conn_type == 'hive_cli': - from airflow.hooks.hive_hooks import HiveCliHook - return HiveCliHook(hive_cli_conn_id=self.conn_id) - elif self.conn_type == 'presto': - from airflow.hooks.presto_hook import PrestoHook - return PrestoHook(presto_conn_id=self.conn_id) - elif self.conn_type == 'hiveserver2': - from airflow.hooks.hive_hooks import HiveServer2Hook - return HiveServer2Hook(hiveserver2_conn_id=self.conn_id) - elif self.conn_type == 'sqlite': - from airflow.hooks.sqlite_hook import SqliteHook - return SqliteHook(sqlite_conn_id=self.conn_id) - elif self.conn_type == 'jdbc': - from airflow.hooks.jdbc_hook import JdbcHook - return JdbcHook(jdbc_conn_id=self.conn_id) - elif self.conn_type == 'mssql': - from airflow.hooks.mssql_hook import MsSqlHook - return MsSqlHook(mssql_conn_id=self.conn_id) - elif self.conn_type == 'oracle': - from airflow.hooks.oracle_hook import OracleHook - return OracleHook(oracle_conn_id=self.conn_id) - elif self.conn_type == 'vertica': - from airflow.contrib.hooks.vertica_hook import VerticaHook - return VerticaHook(vertica_conn_id=self.conn_id) - elif self.conn_type == 'cloudant': - from airflow.contrib.hooks.cloudant_hook import CloudantHook - return CloudantHook(cloudant_conn_id=self.conn_id) - elif self.conn_type == 'jira': - from airflow.contrib.hooks.jira_hook import JiraHook - return JiraHook(jira_conn_id=self.conn_id) - elif self.conn_type == 'redis': - from 
airflow.contrib.hooks.redis_hook import RedisHook - return RedisHook(redis_conn_id=self.conn_id) - elif self.conn_type == 'wasb': - from airflow.contrib.hooks.wasb_hook import WasbHook - return WasbHook(wasb_conn_id=self.conn_id) - elif self.conn_type == 'docker': - from airflow.hooks.docker_hook import DockerHook - return DockerHook(docker_conn_id=self.conn_id) - elif self.conn_type == 'azure_data_lake': - from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook - return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id) - elif self.conn_type == 'azure_cosmos': - from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook - return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id) - elif self.conn_type == 'cassandra': - from airflow.contrib.hooks.cassandra_hook import CassandraHook - return CassandraHook(cassandra_conn_id=self.conn_id) - elif self.conn_type == 'mongo': - from airflow.contrib.hooks.mongo_hook import MongoHook - return MongoHook(conn_id=self.conn_id) - elif self.conn_type == 'gcpcloudsql': - from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook - return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id) - except Exception: - pass - - def __repr__(self): - return self.conn_id - - def debug_info(self): - return ("id: {}. Host: {}, Port: {}, Schema: {}, " - "Login: {}, Password: {}, extra: {}". - format(self.conn_id, - self.host, - self.port, - self.schema, - self.login, - "XXXXXXXX" if self.password else None, - self.extra_dejson)) - - @property - def extra_dejson(self): - """Returns the extra property by deserializing json.""" - obj = {} - if self.extra: - try: - obj = json.loads(self.extra) - except Exception as e: - self.log.exception(e) - self.log.error("Failed parsing the json for conn_id %s", self.conn_id) - - return obj - - class DagPickle(Base): """ Dags can originate from different places (user repos, master repo, ...) 
diff --git a/airflow/models/connection.py b/airflow/models/connection.py new file mode 100644 index 0000000000..9a51a43c08 --- /dev/null +++ b/airflow/models/connection.py @@ -0,0 +1,268 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from builtins import bytes +from urllib.parse import urlparse, unquote, parse_qsl + +from sqlalchemy import Column, Integer, String, Boolean +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import synonym + +from airflow import LoggingMixin, AirflowException +from airflow.models import Base, ID_LEN, get_fernet + + +class Connection(Base, LoggingMixin): + """ + Placeholder to store information about different database instances + connection information. The idea here is that scripts use references to + database instances (conn_id) instead of hard coding hostname, logins and + passwords when using operators or hooks. 
+ """ + __tablename__ = "connection" + + id = Column(Integer(), primary_key=True) + conn_id = Column(String(ID_LEN)) + conn_type = Column(String(500)) + host = Column(String(500)) + schema = Column(String(500)) + login = Column(String(500)) + _password = Column('password', String(5000)) + port = Column(Integer()) + is_encrypted = Column(Boolean, unique=False, default=False) + is_extra_encrypted = Column(Boolean, unique=False, default=False) + _extra = Column('extra', String(5000)) + + _types = [ + ('docker', 'Docker Registry',), + ('fs', 'File (path)'), + ('ftp', 'FTP',), + ('google_cloud_platform', 'Google Cloud Platform'), + ('hdfs', 'HDFS',), + ('http', 'HTTP',), + ('hive_cli', 'Hive Client Wrapper',), + ('hive_metastore', 'Hive Metastore Thrift',), + ('hiveserver2', 'Hive Server 2 Thrift',), + ('jdbc', 'Jdbc Connection',), + ('jenkins', 'Jenkins'), + ('mysql', 'MySQL',), + ('postgres', 'Postgres',), + ('oracle', 'Oracle',), + ('vertica', 'Vertica',), + ('presto', 'Presto',), + ('s3', 'S3',), + ('samba', 'Samba',), + ('sqlite', 'Sqlite',), + ('ssh', 'SSH',), + ('cloudant', 'IBM Cloudant',), + ('mssql', 'Microsoft SQL Server'), + ('mesos_framework-id', 'Mesos Framework ID'), + ('jira', 'JIRA',), + ('redis', 'Redis',), + ('wasb', 'Azure Blob Storage'), + ('databricks', 'Databricks',), + ('aws', 'Amazon Web Services',), + ('emr', 'Elastic MapReduce',), + ('snowflake', 'Snowflake',), + ('segment', 'Segment',), + ('azure_data_lake', 'Azure Data Lake'), + ('azure_cosmos', 'Azure CosmosDB'), + ('cassandra', 'Cassandra',), + ('qubole', 'Qubole'), + ('mongo', 'MongoDB'), + ('gcpcloudsql', 'Google Cloud SQL'), + ] + + def __init__( + self, conn_id=None, conn_type=None, + host=None, login=None, password=None, + schema=None, port=None, extra=None, + uri=None): + self.conn_id = conn_id + if uri: + self.parse_from_uri(uri) + else: + self.conn_type = conn_type + self.host = host + self.login = login + self.password = password + self.schema = schema + self.port = port + 
self.extra = extra + + def parse_from_uri(self, uri): + temp_uri = urlparse(uri) + hostname = temp_uri.hostname or '' + conn_type = temp_uri.scheme + if conn_type == 'postgresql': + conn_type = 'postgres' + self.conn_type = conn_type + self.host = unquote(hostname) if hostname else hostname + quoted_schema = temp_uri.path[1:] + self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema + self.login = unquote(temp_uri.username) \ + if temp_uri.username else temp_uri.username + self.password = unquote(temp_uri.password) \ + if temp_uri.password else temp_uri.password + self.port = temp_uri.port + if temp_uri.query: + self.extra = json.dumps(dict(parse_qsl(temp_uri.query))) + + def get_password(self): + if self._password and self.is_encrypted: + fernet = get_fernet() + if not fernet.is_encrypted: + raise AirflowException( + "Can't decrypt encrypted password for login={}, \ + FERNET_KEY configuration is missing".format(self.login)) + return fernet.decrypt(bytes(self._password, 'utf-8')).decode() + else: + return self._password + + def set_password(self, value): + if value: + fernet = get_fernet() + self._password = fernet.encrypt(bytes(value, 'utf-8')).decode() + self.is_encrypted = fernet.is_encrypted + + @declared_attr + def password(cls): + return synonym('_password', + descriptor=property(cls.get_password, cls.set_password)) + + def get_extra(self): + if self._extra and self.is_extra_encrypted: + fernet = get_fernet() + if not fernet.is_encrypted: + raise AirflowException( + "Can't decrypt `extra` params for login={},\ + FERNET_KEY configuration is missing".format(self.login)) + return fernet.decrypt(bytes(self._extra, 'utf-8')).decode() + else: + return self._extra + + def set_extra(self, value): + if value: + fernet = get_fernet() + self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode() + self.is_extra_encrypted = fernet.is_encrypted + else: + self._extra = value + self.is_extra_encrypted = False + + @declared_attr + def extra(cls): + return 
synonym('_extra', + descriptor=property(cls.get_extra, cls.set_extra)) + + def get_hook(self): + try: + if self.conn_type == 'mysql': + from airflow.hooks.mysql_hook import MySqlHook + return MySqlHook(mysql_conn_id=self.conn_id) + elif self.conn_type == 'google_cloud_platform': + from airflow.contrib.hooks.bigquery_hook import BigQueryHook + return BigQueryHook(bigquery_conn_id=self.conn_id) + elif self.conn_type == 'postgres': + from airflow.hooks.postgres_hook import PostgresHook + return PostgresHook(postgres_conn_id=self.conn_id) + elif self.conn_type == 'hive_cli': + from airflow.hooks.hive_hooks import HiveCliHook + return HiveCliHook(hive_cli_conn_id=self.conn_id) + elif self.conn_type == 'presto': + from airflow.hooks.presto_hook import PrestoHook + return PrestoHook(presto_conn_id=self.conn_id) + elif self.conn_type == 'hiveserver2': + from airflow.hooks.hive_hooks import HiveServer2Hook + return HiveServer2Hook(hiveserver2_conn_id=self.conn_id) + elif self.conn_type == 'sqlite': + from airflow.hooks.sqlite_hook import SqliteHook + return SqliteHook(sqlite_conn_id=self.conn_id) + elif self.conn_type == 'jdbc': + from airflow.hooks.jdbc_hook import JdbcHook + return JdbcHook(jdbc_conn_id=self.conn_id) + elif self.conn_type == 'mssql': + from airflow.hooks.mssql_hook import MsSqlHook + return MsSqlHook(mssql_conn_id=self.conn_id) + elif self.conn_type == 'oracle': + from airflow.hooks.oracle_hook import OracleHook + return OracleHook(oracle_conn_id=self.conn_id) + elif self.conn_type == 'vertica': + from airflow.contrib.hooks.vertica_hook import VerticaHook + return VerticaHook(vertica_conn_id=self.conn_id) + elif self.conn_type == 'cloudant': + from airflow.contrib.hooks.cloudant_hook import CloudantHook + return CloudantHook(cloudant_conn_id=self.conn_id) + elif self.conn_type == 'jira': + from airflow.contrib.hooks.jira_hook import JiraHook + return JiraHook(jira_conn_id=self.conn_id) + elif self.conn_type == 'redis': + from 
airflow.contrib.hooks.redis_hook import RedisHook + return RedisHook(redis_conn_id=self.conn_id) + elif self.conn_type == 'wasb': + from airflow.contrib.hooks.wasb_hook import WasbHook + return WasbHook(wasb_conn_id=self.conn_id) + elif self.conn_type == 'docker': + from airflow.hooks.docker_hook import DockerHook + return DockerHook(docker_conn_id=self.conn_id) + elif self.conn_type == 'azure_data_lake': + from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook + return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id) + elif self.conn_type == 'azure_cosmos': + from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook + return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id) + elif self.conn_type == 'cassandra': + from airflow.contrib.hooks.cassandra_hook import CassandraHook + return CassandraHook(cassandra_conn_id=self.conn_id) + elif self.conn_type == 'mongo': + from airflow.contrib.hooks.mongo_hook import MongoHook + return MongoHook(conn_id=self.conn_id) + elif self.conn_type == 'gcpcloudsql': + from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook + return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id) + except Exception: + pass + + def __repr__(self): + return self.conn_id + + def debug_info(self): + return ("id: {}. Host: {}, Port: {}, Schema: {}, " + "Login: {}, Password: {}, extra: {}". 
+ format(self.conn_id, + self.host, + self.port, + self.schema, + self.login, + "XXXXXXXX" if self.password else None, + self.extra_dejson)) + + @property + def extra_dejson(self): + """Returns the extra property by deserializing json.""" + obj = {} + if self.extra: + try: + obj = json.loads(self.extra) + except Exception as e: + self.log.exception(e) + self.log.error("Failed parsing the json for conn_id %s", self.conn_id) + + return obj diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 41c6915ee8..7bd67ce533 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -78,9 +78,8 @@ def wrapper(*args, **kwargs): @provide_session def merge_conn(conn, session=None): - from airflow import models - C = models.Connection - if not session.query(C).filter(C.conn_id == conn.conn_id).first(): + from airflow.models.connection import Connection + if not session.query(Connection).filter(Connection.conn_id == conn.conn_id).first(): session.add(conn) session.commit() @@ -89,133 +88,134 @@ def initdb(rbac=False): session = settings.Session() from airflow import models + from airflow.models.connection import Connection upgradedb() merge_conn( - models.Connection( + Connection( conn_id='airflow_db', conn_type='mysql', host='mysql', login='root', password='', schema='airflow')) merge_conn( - models.Connection( + Connection( conn_id='beeline_default', conn_type='beeline', port="10000", host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}", schema='default')) merge_conn( - models.Connection( + Connection( conn_id='bigquery_default', conn_type='google_cloud_platform', schema='default')) merge_conn( - models.Connection( + Connection( conn_id='local_mysql', conn_type='mysql', host='localhost', login='airflow', password='airflow', schema='airflow')) merge_conn( - models.Connection( + Connection( conn_id='presto_default', conn_type='presto', host='localhost', schema='hive', port=3400)) merge_conn( - models.Connection( + Connection( conn_id='google_cloud_default', 
conn_type='google_cloud_platform', schema='default',)) merge_conn( - models.Connection( + Connection( conn_id='hive_cli_default', conn_type='hive_cli', schema='default',)) merge_conn( - models.Connection( + Connection( conn_id='hiveserver2_default', conn_type='hiveserver2', host='localhost', schema='default', port=10000)) merge_conn( - models.Connection( + Connection( conn_id='metastore_default', conn_type='hive_metastore', host='localhost', extra="{\"authMechanism\": \"PLAIN\"}", port=9083)) merge_conn( - models.Connection( + Connection( conn_id='mongo_default', conn_type='mongo', host='mongo', port=27017)) merge_conn( - models.Connection( + Connection( conn_id='mysql_default', conn_type='mysql', login='root', schema='airflow', host='mysql')) merge_conn( - models.Connection( + Connection( conn_id='postgres_default', conn_type='postgres', login='postgres', password='airflow', schema='airflow', host='postgres')) merge_conn( - models.Connection( + Connection( conn_id='sqlite_default', conn_type='sqlite', host='/tmp/sqlite_default.db')) merge_conn( - models.Connection( + Connection( conn_id='http_default', conn_type='http', host='https://www.google.com/')) merge_conn( - models.Connection( + Connection( conn_id='mssql_default', conn_type='mssql', host='localhost', port=1433)) merge_conn( - models.Connection( + Connection( conn_id='vertica_default', conn_type='vertica', host='localhost', port=5433)) merge_conn( - models.Connection( + Connection( conn_id='wasb_default', conn_type='wasb', extra='{"sas_token": null}')) merge_conn( - models.Connection( + Connection( conn_id='webhdfs_default', conn_type='hdfs', host='localhost', port=50070)) merge_conn( - models.Connection( + Connection( conn_id='ssh_default', conn_type='ssh', host='localhost')) merge_conn( - models.Connection( + Connection( conn_id='sftp_default', conn_type='sftp', host='localhost', port=22, login='airflow', extra=''' {"key_file": "~/.ssh/id_rsa", "no_host_key_check": true} ''')) merge_conn( - 
models.Connection( + Connection( conn_id='fs_default', conn_type='fs', extra='{"path": "/"}')) merge_conn( - models.Connection( + Connection( conn_id='aws_default', conn_type='aws', extra='{"region_name": "us-east-1"}')) merge_conn( - models.Connection( + Connection( conn_id='spark_default', conn_type='spark', host='yarn', extra='{"queue": "root.default"}')) merge_conn( - models.Connection( + Connection( conn_id='druid_broker_default', conn_type='druid', host='druid-broker', port=8082, extra='{"endpoint": "druid/v2/sql"}')) merge_conn( - models.Connection( + Connection( conn_id='druid_ingest_default', conn_type='druid', host='druid-overlord', port=8081, extra='{"endpoint": "druid/indexer/v1/task"}')) merge_conn( - models.Connection( + Connection( conn_id='redis_default', conn_type='redis', host='redis', port=6379, extra='{"db": 0}')) merge_conn( - models.Connection( + Connection( conn_id='sqoop_default', conn_type='sqoop', host='rmdbs', extra='')) merge_conn( - models.Connection( + Connection( conn_id='emr_default', conn_type='emr', extra=''' { "Name": "default_job_flow_name", @@ -262,27 +262,27 @@ def initdb(rbac=False): } ''')) merge_conn( - models.Connection( + Connection( conn_id='databricks_default', conn_type='databricks', host='localhost')) merge_conn( - models.Connection( + Connection( conn_id='qubole_default', conn_type='qubole', host='localhost')) merge_conn( - models.Connection( + Connection( conn_id='segment_default', conn_type='segment', extra='{"write_key": "my-segment-write-key"}')), merge_conn( - models.Connection( + Connection( conn_id='azure_data_lake_default', conn_type='azure_data_lake', extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }')) merge_conn( - models.Connection( + Connection( conn_id='azure_cosmos_default', conn_type='azure_cosmos', extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }')) merge_conn( - models.Connection( + Connection( conn_id='cassandra_default', conn_type='cassandra', 
host='cassandra', port=9042)) diff --git a/airflow/www/app.py b/airflow/www/app.py index 98e3003908..d7f102249b 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -30,6 +30,7 @@ import airflow from airflow import configuration as conf from airflow import models, LoggingMixin +from airflow.models.connection import Connection from airflow.settings import Session from airflow.www.blueprints import routes @@ -107,7 +108,7 @@ def create_app(config=None, testing=False): av(vs.UserModelView( models.User, Session, name="Users", category="Admin")) av(vs.ConnectionModelView( - models.Connection, Session, name="Connections", category="Admin")) + Connection, Session, name="Connections", category="Admin")) av(vs.VariableView( models.Variable, Session, name="Variables", category="Admin")) av(vs.XComView( diff --git a/airflow/www/views.py b/airflow/www/views.py index 8792191e21..d9edbefabc 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -69,6 +69,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.models import XCom, DagRun +from airflow.models.connection import Connection from airflow.operators.subdag_operator import SubDagOperator from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, SCHEDULER_DEPS from airflow.utils import timezone @@ -388,7 +389,7 @@ def chart_data(self): csv = request.args.get('csv') == "true" chart = session.query(models.Chart).filter_by(id=chart_id).first() db = session.query( - models.Connection).filter_by(conn_id=chart.conn_id).first() + Connection).filter_by(conn_id=chart.conn_id).first() payload = { "state": "ERROR", @@ -2216,8 +2217,7 @@ class QueryView(wwwutils.DataProfilingMixin, BaseView): @wwwutils.gzipped @provide_session def query(self, session=None): - dbs = session.query(models.Connection).order_by( - models.Connection.conn_id).all() + dbs = session.query(Connection).order_by(Connection.conn_id).all() session.expunge_all() db_choices = list( ((db.conn_id, 
db.conn_id) for db in dbs if db.get_hook())) @@ -2335,8 +2335,8 @@ class SlaMissModelView(wwwutils.SuperUserMixin, ModelViewOnly): def _connection_ids(session=None): return [(c.conn_id, c.conn_id) for c in ( session - .query(models.Connection.conn_id) - .group_by(models.Connection.conn_id))] + .query(Connection.conn_id) + .group_by(Connection.conn_id))] class ChartModelView(wwwutils.DataProfilingMixin, AirflowModelView): @@ -2986,7 +2986,7 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView): 'extra__google_cloud_platform__scope': StringField('Scopes (comma separated)'), } form_choices = { - 'conn_type': models.Connection._types + 'conn_type': Connection._types } def on_model_change(self, form, model, is_created): diff --git a/airflow/www_rbac/forms.py b/airflow/www_rbac/forms.py index 61c34888e3..0a36a90ef3 100644 --- a/airflow/www_rbac/forms.py +++ b/airflow/www_rbac/forms.py @@ -22,7 +22,7 @@ from __future__ import print_function from __future__ import unicode_literals -from airflow import models +from airflow.models.connection import Connection from airflow.utils import timezone from flask_appbuilder.forms import DynamicForm @@ -93,7 +93,7 @@ class ConnectionForm(DynamicForm): widget=BS3TextFieldWidget()) conn_type = SelectField( lazy_gettext('Conn Type'), - choices=models.Connection._types, + choices=Connection._types, widget=Select2Widget()) host = StringField( lazy_gettext('Host'), diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index 49a9a734cc..31efc4459b 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -55,6 +55,7 @@ from airflow.api.common.experimental.mark_tasks import (set_dag_run_state_to_success, set_dag_run_state_to_failed) from airflow.models import XCom, DagRun +from airflow.models.connection import Connection from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, SCHEDULER_DEPS from airflow.utils import timezone from airflow.utils.dates import infer_time_unit, 
scale_time_units @@ -1898,7 +1899,7 @@ def action_muldelete(self, items): class ConnectionModelView(AirflowModelView): route_base = '/connection' - datamodel = AirflowModelView.CustomSQLAInterface(models.Connection) + datamodel = AirflowModelView.CustomSQLAInterface(Connection) base_permissions = ['can_add', 'can_list', 'can_edit', 'can_delete'] diff --git a/docs/concepts.rst b/docs/concepts.rst index 8a497e499c..eac7a8a7f1 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -303,7 +303,7 @@ Hooks Hooks are interfaces to external platforms and databases like Hive, S3, MySQL, Postgres, HDFS, and Pig. Hooks implement a common interface when possible, and act as a building block for operators. They also use -the ``airflow.models.Connection`` model to retrieve hostnames +the ``airflow.models.connection.Connection`` model to retrieve hostnames and authentication information. Hooks keep authentication code and information out of pipelines, centralized in the metadata database. diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py index addee85109..f842b44e90 100644 --- a/tests/contrib/hooks/test_aws_hook.py +++ b/tests/contrib/hooks/test_aws_hook.py @@ -23,7 +23,7 @@ import boto3 from airflow import configuration -from airflow.models import Connection +from airflow.models.connection import Connection from airflow.contrib.hooks.aws_hook import AwsHook try: diff --git a/tests/contrib/hooks/test_azure_cosmos_hook.py b/tests/contrib/hooks/test_azure_cosmos_hook.py index 653242a34b..4c926b5d89 100644 --- a/tests/contrib/hooks/test_azure_cosmos_hook.py +++ b/tests/contrib/hooks/test_azure_cosmos_hook.py @@ -1,202 +1,202 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - - -import json -import unittest -import uuid - -from airflow.exceptions import AirflowException -from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook - -from airflow import configuration -from airflow import models -from airflow.utils import db - -import logging - -try: - from unittest import mock - -except ImportError: - try: - import mock - except ImportError: - mock = None - - -class TestAzureCosmosDbHook(unittest.TestCase): - - # Set up an environment to test with - def setUp(self): - # set up some test variables - self.test_end_point = 'https://test_endpoint:443' - self.test_master_key = 'magic_test_key' - self.test_database_name = 'test_database_name' - self.test_collection_name = 'test_collection_name' - self.test_database_default = 'test_database_default' - self.test_collection_default = 'test_collection_default' - configuration.load_test_config() - db.merge_conn( - models.Connection( - conn_id='azure_cosmos_test_key_id', - conn_type='azure_cosmos', - login=self.test_end_point, - password=self.test_master_key, - extra=json.dumps({'database_name': self.test_database_default, - 'collection_name': self.test_collection_default}) - ) - ) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_create_database(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - 
self.cosmos.create_database(self.test_database_name) - expected_calls = [mock.call().CreateDatabase({'id': self.test_database_name})] - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_create_database_exception(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.assertRaises(AirflowException, self.cosmos.create_database, None) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_create_container_exception(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.assertRaises(AirflowException, self.cosmos.create_collection, None) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_create_container(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.cosmos.create_collection(self.test_collection_name, self.test_database_name) - expected_calls = [mock.call().CreateContainer( - 'dbs/test_database_name', - {'id': self.test_collection_name})] - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_create_container_default(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.cosmos.create_collection(self.test_collection_name) - expected_calls = [mock.call().CreateContainer( - 'dbs/test_database_default', - {'id': self.test_collection_name})] - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_upsert_document_default(self, cosmos_mock): - test_id = str(uuid.uuid4()) - 
cosmos_mock.return_value.CreateItem.return_value = {'id': test_id} - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - returned_item = self.cosmos.upsert_document({'id': test_id}) - expected_calls = [mock.call().CreateItem( - 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, - {'id': test_id})] - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - logging.getLogger().info(returned_item) - self.assertEqual(returned_item['id'], test_id) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_upsert_document(self, cosmos_mock): - test_id = str(uuid.uuid4()) - cosmos_mock.return_value.CreateItem.return_value = {'id': test_id} - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - returned_item = self.cosmos.upsert_document( - {'data1': 'somedata'}, - database_name=self.test_database_name, - collection_name=self.test_collection_name, - document_id=test_id) - - expected_calls = [mock.call().CreateItem( - 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name, - {'data1': 'somedata', 'id': test_id})] - - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - logging.getLogger().info(returned_item) - self.assertEqual(returned_item['id'], test_id) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_insert_documents(self, cosmos_mock): - test_id1 = str(uuid.uuid4()) - test_id2 = str(uuid.uuid4()) - test_id3 = str(uuid.uuid4()) - documents = [ - {'id': test_id1, 'data': 'data1'}, - {'id': test_id2, 'data': 'data2'}, - {'id': test_id3, 'data': 'data3'}] - - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - returned_item = self.cosmos.insert_documents(documents) - expected_calls = [ - mock.call().CreateItem( - 'dbs/' + self.test_database_default 
+ '/colls/' + self.test_collection_default, - {'data': 'data1', 'id': test_id1}), - mock.call().CreateItem( - 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, - {'data': 'data2', 'id': test_id2}), - mock.call().CreateItem( - 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, - {'data': 'data3', 'id': test_id3})] - logging.getLogger().info(returned_item) - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_delete_database(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.cosmos.delete_database(self.test_database_name) - expected_calls = [mock.call().DeleteDatabase('dbs/test_database_name')] - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_delete_database_exception(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.assertRaises(AirflowException, self.cosmos.delete_database, None) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_delete_container_exception(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.assertRaises(AirflowException, self.cosmos.delete_collection, None) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_delete_container(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.cosmos.delete_collection(self.test_collection_name, self.test_database_name) - expected_calls = [mock.call().DeleteContainer('dbs/test_database_name/colls/test_collection_name')] - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': 
self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - - @mock.patch('azure.cosmos.cosmos_client.CosmosClient') - def test_delete_container_default(self, cosmos_mock): - self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') - self.cosmos.delete_collection(self.test_collection_name) - expected_calls = [mock.call().DeleteContainer('dbs/test_database_default/colls/test_collection_name')] - cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) - cosmos_mock.assert_has_calls(expected_calls) - - -if __name__ == '__main__': - unittest.main() +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + + +import json +import unittest +import uuid + +from airflow.exceptions import AirflowException +from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook + +from airflow import configuration +from airflow.models.connection import Connection +from airflow.utils import db + +import logging + +try: + from unittest import mock + +except ImportError: + try: + import mock + except ImportError: + mock = None + + +class TestAzureCosmosDbHook(unittest.TestCase): + + # Set up an environment to test with + def setUp(self): + # set up some test variables + self.test_end_point = 'https://test_endpoint:443' + self.test_master_key = 'magic_test_key' + self.test_database_name = 'test_database_name' + self.test_collection_name = 'test_collection_name' + self.test_database_default = 'test_database_default' + self.test_collection_default = 'test_collection_default' + configuration.load_test_config() + db.merge_conn( + Connection( + conn_id='azure_cosmos_test_key_id', + conn_type='azure_cosmos', + login=self.test_end_point, + password=self.test_master_key, + extra=json.dumps({'database_name': self.test_database_default, + 'collection_name': self.test_collection_default}) + ) + ) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_database(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.create_database(self.test_database_name) + expected_calls = [mock.call().CreateDatabase({'id': self.test_database_name})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_database_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.create_database, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def 
test_create_container_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.create_collection, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_container(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.create_collection(self.test_collection_name, self.test_database_name) + expected_calls = [mock.call().CreateContainer( + 'dbs/test_database_name', + {'id': self.test_collection_name})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_create_container_default(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.create_collection(self.test_collection_name) + expected_calls = [mock.call().CreateContainer( + 'dbs/test_database_default', + {'id': self.test_collection_name})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_upsert_document_default(self, cosmos_mock): + test_id = str(uuid.uuid4()) + cosmos_mock.return_value.CreateItem.return_value = {'id': test_id} + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + returned_item = self.cosmos.upsert_document({'id': test_id}) + expected_calls = [mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'id': test_id})] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + logging.getLogger().info(returned_item) + self.assertEqual(returned_item['id'], test_id) + + 
@mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_upsert_document(self, cosmos_mock): + test_id = str(uuid.uuid4()) + cosmos_mock.return_value.CreateItem.return_value = {'id': test_id} + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + returned_item = self.cosmos.upsert_document( + {'data1': 'somedata'}, + database_name=self.test_database_name, + collection_name=self.test_collection_name, + document_id=test_id) + + expected_calls = [mock.call().CreateItem( + 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name, + {'data1': 'somedata', 'id': test_id})] + + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + logging.getLogger().info(returned_item) + self.assertEqual(returned_item['id'], test_id) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_insert_documents(self, cosmos_mock): + test_id1 = str(uuid.uuid4()) + test_id2 = str(uuid.uuid4()) + test_id3 = str(uuid.uuid4()) + documents = [ + {'id': test_id1, 'data': 'data1'}, + {'id': test_id2, 'data': 'data2'}, + {'id': test_id3, 'data': 'data3'}] + + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + returned_item = self.cosmos.insert_documents(documents) + expected_calls = [ + mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'data': 'data1', 'id': test_id1}), + mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'data': 'data2', 'id': test_id2}), + mock.call().CreateItem( + 'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default, + {'data': 'data3', 'id': test_id3})] + logging.getLogger().info(returned_item) + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + 
@mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_database(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.delete_database(self.test_database_name) + expected_calls = [mock.call().DeleteDatabase('dbs/test_database_name')] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_database_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.delete_database, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_container_exception(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.assertRaises(AirflowException, self.cosmos.delete_collection, None) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_container(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.delete_collection(self.test_collection_name, self.test_database_name) + expected_calls = [mock.call().DeleteContainer('dbs/test_database_name/colls/test_collection_name')] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + cosmos_mock.assert_has_calls(expected_calls) + + @mock.patch('azure.cosmos.cosmos_client.CosmosClient') + def test_delete_container_default(self, cosmos_mock): + self.cosmos = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') + self.cosmos.delete_collection(self.test_collection_name) + expected_calls = [mock.call().DeleteContainer('dbs/test_database_default/colls/test_collection_name')] + cosmos_mock.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) + 
cosmos_mock.assert_has_calls(expected_calls) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/contrib/hooks/test_azure_data_lake_hook.py b/tests/contrib/hooks/test_azure_data_lake_hook.py index af26f85d99..797d038b88 100644 --- a/tests/contrib/hooks/test_azure_data_lake_hook.py +++ b/tests/contrib/hooks/test_azure_data_lake_hook.py @@ -23,7 +23,7 @@ import unittest from airflow import configuration -from airflow import models +from airflow.models.connection import Connection from airflow.utils import db try: @@ -40,7 +40,7 @@ class TestAzureDataLakeHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='adl_test_key', conn_type='azure_data_lake', login='client_id', diff --git a/tests/contrib/hooks/test_azure_fileshare_hook.py b/tests/contrib/hooks/test_azure_fileshare_hook.py index 5803cd83f7..43c075e946 100644 --- a/tests/contrib/hooks/test_azure_fileshare_hook.py +++ b/tests/contrib/hooks/test_azure_fileshare_hook.py @@ -23,8 +23,8 @@ import unittest from airflow import configuration -from airflow import models from airflow.contrib.hooks.azure_fileshare_hook import AzureFileShareHook +from airflow.models.connection import Connection from airflow.utils import db @@ -42,13 +42,13 @@ class TestAzureFileshareHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='wasb_test_key', conn_type='wasb', login='login', password='key' ) ) db.merge_conn( - models.Connection( + Connection( conn_id='wasb_test_sas_token', conn_type='wasb', login='login', extra=json.dumps({'sas_token': 'token'}) ) diff --git a/tests/contrib/hooks/test_cassandra_hook.py b/tests/contrib/hooks/test_cassandra_hook.py index 73dac4f3b4..687d3169bb 100644 --- a/tests/contrib/hooks/test_cassandra_hook.py +++ b/tests/contrib/hooks/test_cassandra_hook.py @@ -27,7 +27,7 @@ from cassandra.policies import ( TokenAwarePolicy, 
RoundRobinPolicy, DCAwareRoundRobinPolicy, WhiteListRoundRobinPolicy ) -from airflow import models +from airflow.models.connection import Connection from airflow.utils import db @@ -35,12 +35,12 @@ class CassandraHookTest(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='cassandra_test', conn_type='cassandra', host='host-1,host-2', port='9042', schema='test_keyspace', extra='{"load_balancing_policy":"TokenAwarePolicy"}')) db.merge_conn( - models.Connection( + Connection( conn_id='cassandra_default_with_schema', conn_type='cassandra', host='cassandra', port='9042', schema='s')) diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index 597c881929..4ebc78c664 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -31,7 +31,7 @@ SUBMIT_RUN_ENDPOINT ) from airflow.exceptions import AirflowException -from airflow.models import Connection +from airflow.models.connection import Connection from airflow.utils import db try: diff --git a/tests/contrib/hooks/test_discord_webhook_hook.py b/tests/contrib/hooks/test_discord_webhook_hook.py index c6b3eac00f..e41d88cf45 100644 --- a/tests/contrib/hooks/test_discord_webhook_hook.py +++ b/tests/contrib/hooks/test_discord_webhook_hook.py @@ -20,7 +20,8 @@ import json import unittest -from airflow import configuration, models, AirflowException +from airflow import configuration, AirflowException +from airflow.models.connection import Connection from airflow.utils import db from airflow.contrib.hooks.discord_webhook_hook import DiscordWebhookHook @@ -50,7 +51,7 @@ class TestDiscordWebhookHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='default-discord-webhook', host='https://discordapp.com/api/', extra='{"webhook_endpoint": "webhooks/00000/some-discord-token_000"}') diff 
--git a/tests/contrib/hooks/test_imap_hook.py b/tests/contrib/hooks/test_imap_hook.py index b4e4ff3ed7..579fda3c79 100644 --- a/tests/contrib/hooks/test_imap_hook.py +++ b/tests/contrib/hooks/test_imap_hook.py @@ -22,8 +22,9 @@ from mock import Mock, patch, mock_open -from airflow import configuration, models +from airflow import configuration from airflow.contrib.hooks.imap_hook import ImapHook +from airflow.models.connection import Connection from airflow.utils import db imaplib_string = 'airflow.contrib.hooks.imap_hook.imaplib' @@ -56,7 +57,7 @@ def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='imap_default', host='imap_server_address', login='imap_user', diff --git a/tests/contrib/hooks/test_jdbc_hook.py b/tests/contrib/hooks/test_jdbc_hook.py index 3f708997d9..1779b781b7 100644 --- a/tests/contrib/hooks/test_jdbc_hook.py +++ b/tests/contrib/hooks/test_jdbc_hook.py @@ -26,7 +26,7 @@ from airflow import configuration from airflow.hooks.jdbc_hook import JdbcHook -from airflow import models +from airflow.models.connection import Connection from airflow.utils import db jdbc_conn_mock = Mock( @@ -38,7 +38,7 @@ class TestJdbcHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='jdbc_default', conn_type='jdbc', host='jdbc://localhost/', port=443, extra=json.dumps({"extra__jdbc__drv_path": "/path1/test.jar,/path2/t.jar2", diff --git a/tests/contrib/hooks/test_jira_hook.py b/tests/contrib/hooks/test_jira_hook.py index 378c379d55..26aae945a0 100644 --- a/tests/contrib/hooks/test_jira_hook.py +++ b/tests/contrib/hooks/test_jira_hook.py @@ -25,7 +25,7 @@ from airflow import configuration from airflow.contrib.hooks.jira_hook import JiraHook -from airflow import models +from airflow.models.connection import Connection from airflow.utils import db jira_client_mock = Mock( @@ -37,7 +37,7 @@ class TestJiraHook(unittest.TestCase): def 
setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='jira_default', conn_type='jira', host='https://localhost/jira/', port=443, extra='{"verify": "False", "project": "AIRFLOW"}')) diff --git a/tests/contrib/hooks/test_openfaas_hook.py b/tests/contrib/hooks/test_openfaas_hook.py index 28f1158fbf..7649e6a320 100644 --- a/tests/contrib/hooks/test_openfaas_hook.py +++ b/tests/contrib/hooks/test_openfaas_hook.py @@ -20,7 +20,7 @@ import unittest import requests_mock -from airflow.models import Connection +from airflow.models.connection import Connection from airflow.contrib.hooks.openfaas_hook import OpenFaasHook from airflow.hooks.base_hook import BaseHook from airflow import configuration, AirflowException diff --git a/tests/contrib/hooks/test_sftp_hook.py b/tests/contrib/hooks/test_sftp_hook.py index ac4d78e9b1..bcdcb80dc3 100644 --- a/tests/contrib/hooks/test_sftp_hook.py +++ b/tests/contrib/hooks/test_sftp_hook.py @@ -25,8 +25,9 @@ import os import pysftp -from airflow import configuration, models +from airflow import configuration from airflow.contrib.hooks.sftp_hook import SFTPHook +from airflow.models.connection import Connection TMP_PATH = '/tmp' TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir' @@ -108,14 +109,14 @@ def test_get_mod_time(self): @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') def test_no_host_key_check_default(self, get_connection): - connection = models.Connection(login='login', host='host') + connection = Connection(login='login', host='host') get_connection.return_value = connection hook = SFTPHook() self.assertEqual(hook.no_host_key_check, False) @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') def test_no_host_key_check_enabled(self, get_connection): - connection = models.Connection( + connection = Connection( login='login', host='host', extra='{"no_host_key_check": true}') @@ -125,7 +126,7 @@ def test_no_host_key_check_enabled(self, get_connection): 
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') def test_no_host_key_check_disabled(self, get_connection): - connection = models.Connection( + connection = Connection( login='login', host='host', extra='{"no_host_key_check": false}') @@ -135,7 +136,7 @@ def test_no_host_key_check_disabled(self, get_connection): @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): - connection = models.Connection( + connection = Connection( login='login', host='host', extra='{"no_host_key_check": "foo"}') @@ -145,7 +146,7 @@ def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') def test_no_host_key_check_ignore(self, get_connection): - connection = models.Connection( + connection = Connection( login='login', host='host', extra='{"ignore_hostkey_verification": true}') @@ -155,7 +156,7 @@ def test_no_host_key_check_ignore(self, get_connection): @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') def test_no_host_key_check_no_ignore(self, get_connection): - connection = models.Connection( + connection = Connection( login='login', host='host', extra='{"ignore_hostkey_verification": false}') diff --git a/tests/contrib/hooks/test_slack_webhook_hook.py b/tests/contrib/hooks/test_slack_webhook_hook.py index a26204f39e..1ceed70676 100644 --- a/tests/contrib/hooks/test_slack_webhook_hook.py +++ b/tests/contrib/hooks/test_slack_webhook_hook.py @@ -20,7 +20,8 @@ import json import unittest -from airflow import configuration, models +from airflow import configuration +from airflow.models.connection import Connection from airflow.utils import db from airflow.contrib.hooks.slack_webhook_hook import SlackWebhookHook @@ -49,7 +50,7 @@ class TestSlackWebhookHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + 
Connection( conn_id='slack-webhook-default', extra='{"webhook_token": "your_token_here"}') ) diff --git a/tests/contrib/hooks/test_spark_jdbc_hook.py b/tests/contrib/hooks/test_spark_jdbc_hook.py index a6202dcfb9..f7979f2551 100644 --- a/tests/contrib/hooks/test_spark_jdbc_hook.py +++ b/tests/contrib/hooks/test_spark_jdbc_hook.py @@ -19,7 +19,8 @@ # import unittest -from airflow import configuration, models +from airflow import configuration +from airflow.models.connection import Connection from airflow.utils import db from airflow.contrib.hooks.spark_jdbc_hook import SparkJDBCHook @@ -67,13 +68,13 @@ class TestSparkJDBCHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='spark-default', conn_type='spark', host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}') ) db.merge_conn( - models.Connection( + Connection( conn_id='jdbc-default', conn_type='postgres', host='localhost', schema='default', port=5432, login='user', password='supersecret', diff --git a/tests/contrib/hooks/test_spark_sql_hook.py b/tests/contrib/hooks/test_spark_sql_hook.py index f76768efcd..04d5439b9b 100644 --- a/tests/contrib/hooks/test_spark_sql_hook.py +++ b/tests/contrib/hooks/test_spark_sql_hook.py @@ -24,7 +24,8 @@ from mock import patch, call -from airflow import configuration, models +from airflow import configuration +from airflow.models.connection import Connection from airflow.utils import db from airflow.contrib.hooks.spark_sql_hook import SparkSqlHook @@ -53,7 +54,7 @@ def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='spark_default', conn_type='spark', host='yarn://yarn-master') ) diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py index 1bdcda56ec..a31e79cdcc 100644 --- a/tests/contrib/hooks/test_spark_submit_hook.py +++ 
b/tests/contrib/hooks/test_spark_submit_hook.py @@ -20,7 +20,8 @@ import six import unittest -from airflow import configuration, models, AirflowException +from airflow import configuration, AirflowException +from airflow.models.connection import Connection from airflow.utils import db from mock import patch, call @@ -72,13 +73,13 @@ def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='spark_yarn_cluster', conn_type='spark', host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}') ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_k8s_cluster', conn_type='spark', host='k8s://https://k8s-master', extra='{"spark-home": "/opt/spark", ' + @@ -86,43 +87,43 @@ def setUp(self): '"namespace": "mynamespace"}') ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_default_mesos', conn_type='spark', host='mesos://host', port=5050) ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_home_set', conn_type='spark', host='yarn://yarn-master', extra='{"spark-home": "/opt/myspark"}') ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_home_not_set', conn_type='spark', host='yarn://yarn-master') ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_binary_set', conn_type='spark', host='yarn', extra='{"spark-binary": "custom-spark-submit"}') ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_binary_and_home_set', conn_type='spark', host='yarn', extra='{"spark-home": "/path/to/spark_home", ' + '"spark-binary": "custom-spark-submit"}') ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_standalone_cluster', conn_type='spark', host='spark://spark-standalone-master:6066', extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "cluster"}') ) db.merge_conn( - models.Connection( + Connection( conn_id='spark_standalone_cluster_client_mode', conn_type='spark', host='spark://spark-standalone-master:6066', 
extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "client"}') diff --git a/tests/contrib/hooks/test_sqoop_hook.py b/tests/contrib/hooks/test_sqoop_hook.py index 8bef5a4937..64c7e0c8a5 100644 --- a/tests/contrib/hooks/test_sqoop_hook.py +++ b/tests/contrib/hooks/test_sqoop_hook.py @@ -22,9 +22,10 @@ import json import unittest -from airflow import configuration, models +from airflow import configuration from airflow.contrib.hooks.sqoop_hook import SqoopHook from airflow.exceptions import AirflowException +from airflow.models.connection import Connection from airflow.utils import db from mock import patch, call @@ -86,7 +87,7 @@ class TestSqoopHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='sqoop_test', conn_type='sqoop', schema='schema', host='rmdbs', port=5050, extra=json.dumps(self._config_json) ) diff --git a/tests/contrib/hooks/test_ssh_hook.py b/tests/contrib/hooks/test_ssh_hook.py index 64b015d09e..8058984892 100644 --- a/tests/contrib/hooks/test_ssh_hook.py +++ b/tests/contrib/hooks/test_ssh_hook.py @@ -19,8 +19,8 @@ import unittest from airflow import configuration +from airflow.models.connection import Connection from airflow.utils import db -from airflow import models try: from unittest import mock @@ -132,7 +132,7 @@ def test_tunnel_without_password(self, ssh_mock): def test_conn_with_extra_parameters(self): db.merge_conn( - models.Connection( + Connection( conn_id='ssh_with_extra', host='localhost', conn_type='ssh', diff --git a/tests/contrib/hooks/test_wasb_hook.py b/tests/contrib/hooks/test_wasb_hook.py index 88481440e7..07013528dd 100644 --- a/tests/contrib/hooks/test_wasb_hook.py +++ b/tests/contrib/hooks/test_wasb_hook.py @@ -24,8 +24,8 @@ from collections import namedtuple from airflow import configuration, AirflowException -from airflow import models from airflow.contrib.hooks.wasb_hook import WasbHook +from airflow.models.connection import Connection 
from airflow.utils import db try: @@ -42,13 +42,13 @@ class TestWasbHook(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='wasb_test_key', conn_type='wasb', login='login', password='key' ) ) db.merge_conn( - models.Connection( + Connection( conn_id='wasb_test_sas_token', conn_type='wasb', login='login', extra=json.dumps({'sas_token': 'token'}) ) diff --git a/tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py b/tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py index e6e1abe374..14439514f8 100644 --- a/tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py +++ b/tests/contrib/operators/test_azure_cosmos_insertdocument_operator.py @@ -26,7 +26,7 @@ from airflow.contrib.operators.azure_cosmos_operator import AzureCosmosInsertDocumentOperator from airflow import configuration -from airflow import models +from airflow.models.connection import Connection from airflow.utils import db try: @@ -50,7 +50,7 @@ def setUp(self): self.test_collection_name = 'test_collection_name' configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='azure_cosmos_test_key_id', conn_type='azure_cosmos', login=self.test_end_point, diff --git a/tests/contrib/operators/test_gcp_base.py b/tests/contrib/operators/test_gcp_base.py index 7e786c5b4b..2df3e39ebd 100644 --- a/tests/contrib/operators/test_gcp_base.py +++ b/tests/contrib/operators/test_gcp_base.py @@ -22,6 +22,7 @@ import unittest from airflow import models, settings, configuration, AirflowException +from airflow.models.connection import Connection from airflow.utils.timezone import datetime DEFAULT_DATE = datetime(2015, 1, 1) @@ -97,8 +98,8 @@ def _gcp_authenticate(self): def update_connection_with_key_path(self): session = settings.Session() try: - conn = session.query(models.Connection).filter( - models.Connection.conn_id == 'google_cloud_default')[0] + conn = 
session.query(Connection).filter( + Connection.conn_id == 'google_cloud_default')[0] extras = conn.extra_dejson extras[KEYPATH_EXTRA] = self.full_key_path if extras.get(KEYFILE_DICT_EXTRA): @@ -117,8 +118,8 @@ def update_connection_with_key_path(self): def update_connection_with_dictionary(self): session = settings.Session() try: - conn = session.query(models.Connection).filter( - models.Connection.conn_id == 'google_cloud_default')[0] + conn = session.query(Connection).filter( + Connection.conn_id == 'google_cloud_default')[0] extras = conn.extra_dejson with open(self.full_key_path, "r") as f: content = json.load(f) diff --git a/tests/contrib/operators/test_gcp_sql_operator.py b/tests/contrib/operators/test_gcp_sql_operator.py index 9f631493e0..8c2a4aa95c 100644 --- a/tests/contrib/operators/test_gcp_sql_operator.py +++ b/tests/contrib/operators/test_gcp_sql_operator.py @@ -31,7 +31,7 @@ CloudSqlInstanceDatabaseCreateOperator, CloudSqlInstanceDatabasePatchOperator, \ CloudSqlInstanceExportOperator, CloudSqlInstanceImportOperator, \ CloudSqlInstanceDatabaseDeleteOperator, CloudSqlQueryOperator -from airflow.models import Connection +from airflow.models.connection import Connection from tests.contrib.operators.test_gcp_base import BaseGcpIntegrationTestCase, \ GCP_CLOUDSQL_KEY, SKIP_TEST_WARNING diff --git a/tests/contrib/operators/test_jira_operator_test.py b/tests/contrib/operators/test_jira_operator_test.py index 2509038a36..5d15832525 100644 --- a/tests/contrib/operators/test_jira_operator_test.py +++ b/tests/contrib/operators/test_jira_operator_test.py @@ -25,7 +25,7 @@ from airflow import DAG, configuration from airflow.contrib.operators.jira_operator import JiraOperator -from airflow import models +from airflow.models.connection import Connection from airflow.utils import db from airflow.utils import timezone @@ -58,7 +58,7 @@ def setUp(self): dag = DAG('test_dag_id', default_args=args) self.dag = dag db.merge_conn( - models.Connection( + Connection( 
conn_id='jira_default', conn_type='jira', host='https://localhost/jira/', port=443, extra='{"verify": "False", "project": "AIRFLOW"}')) diff --git a/tests/contrib/operators/test_qubole_operator.py b/tests/contrib/operators/test_qubole_operator.py index c0894c0ba7..43d20d2750 100644 --- a/tests/contrib/operators/test_qubole_operator.py +++ b/tests/contrib/operators/test_qubole_operator.py @@ -21,7 +21,8 @@ import unittest from datetime import datetime -from airflow.models import DAG, Connection, TaskInstance +from airflow.models import DAG, TaskInstance +from airflow.models.connection import Connection from airflow.utils import db from airflow.contrib.hooks.qubole_hook import QuboleHook diff --git a/tests/contrib/sensors/test_datadog_sensor.py b/tests/contrib/sensors/test_datadog_sensor.py index 75fca65033..d5d9309666 100644 --- a/tests/contrib/sensors/test_datadog_sensor.py +++ b/tests/contrib/sensors/test_datadog_sensor.py @@ -23,8 +23,8 @@ from mock import patch from airflow import configuration -from airflow import models from airflow.contrib.sensors.datadog_sensor import DatadogSensor +from airflow.models.connection import Connection from airflow.utils import db at_least_one_event = [{'alert_type': 'info', @@ -64,7 +64,7 @@ class TestDatadogSensor(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='datadog_default', conn_type='datadog', login='login', password='password', extra=json.dumps({'api_key': 'api_key', 'app_key': 'app_key'}) diff --git a/tests/contrib/sensors/test_imap_attachment_sensor.py b/tests/contrib/sensors/test_imap_attachment_sensor.py index 297de02a99..66e3cfc788 100644 --- a/tests/contrib/sensors/test_imap_attachment_sensor.py +++ b/tests/contrib/sensors/test_imap_attachment_sensor.py @@ -21,8 +21,9 @@ from mock import patch, Mock -from airflow import configuration, models +from airflow import configuration from airflow.contrib.sensors.imap_attachment_sensor import 
ImapAttachmentSensor +from airflow.models.connection import Connection from airflow.utils import db imap_hook_string = 'airflow.contrib.sensors.imap_attachment_sensor.ImapHook' @@ -33,7 +34,7 @@ class TestImapAttachmentSensor(unittest.TestCase): def setUp(self): configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='imap_test', host='base_url', login='user', diff --git a/tests/contrib/sensors/test_jira_sensor_test.py b/tests/contrib/sensors/test_jira_sensor_test.py index 32aa235851..edbcb4b57b 100644 --- a/tests/contrib/sensors/test_jira_sensor_test.py +++ b/tests/contrib/sensors/test_jira_sensor_test.py @@ -24,8 +24,8 @@ from mock import patch from airflow import DAG, configuration -from airflow import models from airflow.contrib.sensors.jira_sensor import JiraTicketSensor +from airflow.models.connection import Connection from airflow.utils import db, timezone DEFAULT_DATE = timezone.datetime(2017, 1, 1) @@ -57,7 +57,7 @@ def setUp(self): dag = DAG('test_dag_id', default_args=args) self.dag = dag db.merge_conn( - models.Connection( + Connection( conn_id='jira_default', conn_type='jira', host='https://localhost/jira/', port=443, extra='{"verify": "False", "project": "AIRFLOW"}')) diff --git a/tests/contrib/sensors/test_mongo_sensor.py b/tests/contrib/sensors/test_mongo_sensor.py index 6a78b7d146..e2110002ba 100644 --- a/tests/contrib/sensors/test_mongo_sensor.py +++ b/tests/contrib/sensors/test_mongo_sensor.py @@ -24,7 +24,7 @@ from airflow import configuration from airflow.contrib.hooks.mongo_hook import MongoHook from airflow.contrib.sensors.mongo_sensor import MongoSensor -from airflow.models import Connection +from airflow.models.connection import Connection from airflow.utils import db, timezone diff --git a/tests/contrib/sensors/test_qubole_sensor.py b/tests/contrib/sensors/test_qubole_sensor.py index 9a0365d021..bc9ad8f742 100644 --- a/tests/contrib/sensors/test_qubole_sensor.py +++ 
b/tests/contrib/sensors/test_qubole_sensor.py @@ -25,7 +25,8 @@ from airflow.contrib.sensors.qubole_sensor import QuboleFileSensor, QubolePartitionSensor from airflow.exceptions import AirflowException -from airflow.models import DAG, Connection +from airflow.models import DAG +from airflow.models.connection import Connection from airflow.utils import db DAG_ID = "qubole_test_dag" diff --git a/tests/core.py b/tests/core.py index 3cf4c4e18f..efae7f4b1e 100644 --- a/tests/core.py +++ b/tests/core.py @@ -48,6 +48,7 @@ from airflow import jobs, models, DAG, utils, macros, settings, exceptions from airflow.models import BaseOperator +from airflow.models.connection import Connection from airflow.operators.bash_operator import BashOperator from airflow.operators.check_operator import CheckOperator, ValueCheckOperator from airflow.operators.dagrun_operator import TriggerDagRunOperator @@ -1322,8 +1323,8 @@ def test_cli_connections_add_delete(self): for index in range(1, 6): conn_id = 'new%s' % index result = (session - .query(models.Connection) - .filter(models.Connection.conn_id == conn_id) + .query(Connection) + .filter(Connection.conn_id == conn_id) .first()) result = (result.conn_id, result.conn_type, result.host, result.port, result.get_extra()) @@ -1368,8 +1369,8 @@ def test_cli_connections_add_delete(self): # Check deletions for index in range(1, 7): conn_id = 'new%s' % index - result = (session.query(models.Connection) - .filter(models.Connection.conn_id == conn_id) + result = (session.query(Connection) + .filter(Connection.conn_id == conn_id) .first()) self.assertTrue(result is None) @@ -2520,9 +2521,9 @@ def test_using_unix_socket_env_var(self): self.assertIsNone(c.port) def test_param_setup(self): - c = models.Connection(conn_id='local_mysql', conn_type='mysql', - host='localhost', login='airflow', - password='airflow', schema='airflow') + c = Connection(conn_id='local_mysql', conn_type='mysql', + host='localhost', login='airflow', + password='airflow', 
schema='airflow') self.assertEqual('localhost', c.host) self.assertEqual('airflow', c.schema) self.assertEqual('airflow', c.login) @@ -2614,9 +2615,9 @@ def test_get_client(self): @mock.patch('airflow.hooks.hdfs_hook.HDFSHook.get_connections') def test_get_autoconfig_client(self, mock_get_connections, MockAutoConfigClient): - c = models.Connection(conn_id='hdfs', conn_type='hdfs', - host='localhost', port=8020, login='foo', - extra=json.dumps({'autoconfig': True})) + c = Connection(conn_id='hdfs', conn_type='hdfs', + host='localhost', port=8020, login='foo', + extra=json.dumps({'autoconfig': True})) mock_get_connections.return_value = [c] HDFSHook(hdfs_conn_id='hdfs').get_conn() MockAutoConfigClient.assert_called_once_with(effective_user='foo', @@ -2630,10 +2631,10 @@ def test_get_autoconfig_client_no_conn(self, MockAutoConfigClient): @mock.patch('airflow.hooks.hdfs_hook.HDFSHook.get_connections') def test_get_ha_client(self, mock_get_connections): - c1 = models.Connection(conn_id='hdfs_default', conn_type='hdfs', - host='localhost', port=8020) - c2 = models.Connection(conn_id='hdfs_default', conn_type='hdfs', - host='localhost2', port=8020) + c1 = Connection(conn_id='hdfs_default', conn_type='hdfs', + host='localhost', port=8020) + c2 = Connection(conn_id='hdfs_default', conn_type='hdfs', + host='localhost2', port=8020) mock_get_connections.return_value = [c1, c2] client = HDFSHook().get_conn() self.assertIsInstance(client, snakebite.client.HAClient) diff --git a/tests/hooks/test_docker_hook.py b/tests/hooks/test_docker_hook.py index dd7ed4d44d..f4bc6b4f46 100644 --- a/tests/hooks/test_docker_hook.py +++ b/tests/hooks/test_docker_hook.py @@ -20,8 +20,8 @@ import unittest from airflow import configuration -from airflow import models from airflow.exceptions import AirflowException +from airflow.models.connection import Connection from airflow.utils import db try: @@ -43,7 +43,7 @@ class DockerHookTest(unittest.TestCase): def setUp(self): 
configuration.load_test_config() db.merge_conn( - models.Connection( + Connection( conn_id='docker_default', conn_type='docker', host='some.docker.registry.com', @@ -52,7 +52,7 @@ def setUp(self): ) ) db.merge_conn( - models.Connection( + Connection( conn_id='docker_with_extras', conn_type='docker', host='some.docker.registry.com', @@ -148,7 +148,7 @@ def test_conn_with_extra_config_passes_parameters(self, _): def test_conn_with_broken_config_missing_username_fails(self, _): db.merge_conn( - models.Connection( + Connection( conn_id='docker_without_username', conn_type='docker', host='some.docker.registry.com', @@ -165,7 +165,7 @@ def test_conn_with_broken_config_missing_username_fails(self, _): def test_conn_with_broken_config_missing_host_fails(self, _): db.merge_conn( - models.Connection( + Connection( conn_id='docker_without_host', conn_type='docker', login='some_user', diff --git a/tests/hooks/test_http_hook.py b/tests/hooks/test_http_hook.py index e64e59fbca..9d11b048a4 100644 --- a/tests/hooks/test_http_hook.py +++ b/tests/hooks/test_http_hook.py @@ -20,9 +20,10 @@ import tenacity -from airflow import configuration, models +from airflow import configuration from airflow.exceptions import AirflowException from airflow.hooks.http_hook import HttpHook +from airflow.models.connection import Connection try: from unittest import mock @@ -34,7 +35,7 @@ def get_airflow_connection(conn_id=None): - return models.Connection( + return Connection( conn_id='http_default', conn_type='http', host='test:8080/', @@ -43,7 +44,7 @@ def get_airflow_connection(conn_id=None): def get_airflow_connection_with_port(conn_id=None): - return models.Connection( + return Connection( conn_id='http_default', conn_type='http', host='test.com', @@ -247,8 +248,8 @@ def run_and_return(session, prepped_request, extra_options, **kwargs): @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection') def test_http_connection(self, mock_get_connection): - c = 
models.Connection(conn_id='http_default', conn_type='http', - host='localhost', schema='http') + c = Connection(conn_id='http_default', conn_type='http', + host='localhost', schema='http') mock_get_connection.return_value = c hook = HttpHook() hook.get_conn({}) @@ -256,8 +257,8 @@ def test_http_connection(self, mock_get_connection): @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection') def test_https_connection(self, mock_get_connection): - c = models.Connection(conn_id='http_default', conn_type='http', - host='localhost', schema='https') + c = Connection(conn_id='http_default', conn_type='http', + host='localhost', schema='https') mock_get_connection.return_value = c hook = HttpHook() hook.get_conn({}) @@ -265,8 +266,8 @@ def test_https_connection(self, mock_get_connection): @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection') def test_host_encoded_http_connection(self, mock_get_connection): - c = models.Connection(conn_id='http_default', conn_type='http', - host='http://localhost') + c = Connection(conn_id='http_default', conn_type='http', + host='http://localhost') mock_get_connection.return_value = c hook = HttpHook() hook.get_conn({}) @@ -274,8 +275,8 @@ def test_host_encoded_http_connection(self, mock_get_connection): @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection') def test_host_encoded_https_connection(self, mock_get_connection): - c = models.Connection(conn_id='http_default', conn_type='http', - host='https://localhost') + c = Connection(conn_id='http_default', conn_type='http', + host='https://localhost') mock_get_connection.return_value = c hook = HttpHook() hook.get_conn({}) diff --git a/tests/hooks/test_mysql_hook.py b/tests/hooks/test_mysql_hook.py index 22e1874ed2..ed7855959f 100644 --- a/tests/hooks/test_mysql_hook.py +++ b/tests/hooks/test_mysql_hook.py @@ -24,8 +24,8 @@ import MySQLdb.cursors -from airflow import models from airflow.hooks.mysql_hook import MySqlHook +from airflow.models.connection import Connection 
SSL_DICT = { 'cert': '/tmp/client-cert.pem', @@ -39,7 +39,7 @@ class TestMySqlHookConn(unittest.TestCase): def setUp(self): super(TestMySqlHookConn, self).setUp() - self.connection = models.Connection( + self.connection = Connection( login='login', password='password', host='host', diff --git a/tests/models.py b/tests/models.py index 41e4e881e7..921f8a52cf 100644 --- a/tests/models.py +++ b/tests/models.py @@ -41,7 +41,6 @@ from airflow import AirflowException, configuration, models, settings from airflow.exceptions import AirflowDagCycleException, AirflowSkipException from airflow.jobs import BackfillJob -from airflow.models import Connection from airflow.models import DAG, TaskInstance as TI from airflow.models import DagModel, DagRun, DagStat from airflow.models import KubeResourceVersion, KubeWorkerIdentifier @@ -49,6 +48,7 @@ from airflow.models import State as ST from airflow.models import XCom from airflow.models import clear_task_instances +from airflow.models.connection import Connection from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py index b58a4523c9..500fcf3d99 100644 --- a/tests/www_rbac/test_views.py +++ b/tests/www_rbac/test_views.py @@ -37,6 +37,7 @@ from airflow import models, settings from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.models import DAG, DagRun, TaskInstance +from airflow.models.connection import Connection from airflow.operators.dummy_operator import DummyOperator from airflow.settings import Session from airflow.utils import dates, timezone @@ -120,7 +121,7 @@ def setUp(self): } def tearDown(self): - self.clear_table(models.Connection) + self.clear_table(Connection) super(TestConnectionModelView, self).tearDown() def test_create_connection(self): 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> Refactor: Move Connection out of models.py
> ------------------------------------------
>
>                 Key: AIRFLOW-3458
>                 URL: https://issues.apache.org/jira/browse/AIRFLOW-3458
>             Project: Apache Airflow
>          Issue Type: Task
>          Components: models
>    Affects Versions: 1.10.1
>            Reporter: Fokko Driesprong
>            Priority: Major
>             Fix For: 2.0.0
>
>



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)