This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 72d0b15304e Add missing test for apache cassandra hooks (#43951)
72d0b15304e is described below
commit 72d0b15304e60352ee420e7a4faa370d18bf21a1
Author: yangyulely <[email protected]>
AuthorDate: Wed Nov 13 20:22:04 2024 +0800
Add missing test for apache cassandra hooks (#43951)
---
providers/tests/apache/cassandra/hooks/__init__.py | 17 ++
.../tests/apache/cassandra/hooks/test_cassandra.py | 188 +++++++++++++++++++++
tests/always/test_project_structure.py | 1 -
3 files changed, 205 insertions(+), 1 deletion(-)
diff --git a/providers/tests/apache/cassandra/hooks/__init__.py b/providers/tests/apache/cassandra/hooks/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/providers/tests/apache/cassandra/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/tests/apache/cassandra/hooks/test_cassandra.py b/providers/tests/apache/cassandra/hooks/test_cassandra.py
new file mode 100644
index 00000000000..6a373eabcd9
--- /dev/null
+++ b/providers/tests/apache/cassandra/hooks/test_cassandra.py
@@ -0,0 +1,188 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+from cassandra.auth import PlainTextAuthProvider
+from cassandra.cluster import Cluster, DCAwareRoundRobinPolicy,
TokenAwarePolicy
+
+from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
+
+
@pytest.fixture
def mock_cassandra_hook():
    """Return a ``CassandraHook`` built from a fake Airflow connection.

    ``BaseHook.get_connection`` is patched only for the duration of the hook's
    construction; the host, port, credentials, schema and extras defined here are
    what the hook is built from. The returned hook is a real ``CassandraHook``.
    """
    fake_extras = {
        "load_balancing_policy": "DCAwareRoundRobinPolicy",
        "load_balancing_policy_args": {
            "local_dc": "dc1",
            "used_hosts_per_remote_dc": 2,
        },
        "cql_version": "3.4.4",
        "ssl_options": {"ca_certs": "/path/to/certs"},
        "protocol_version": 4,
    }
    fake_conn = Mock(
        host="127.0.0.1,127.0.0.2",
        port="9042",
        login="user",
        password="pass",
        schema="test_keyspace",
        extra_dejson=fake_extras,
    )
    with patch(
        "airflow.providers.apache.cassandra.hooks.cassandra.BaseHook.get_connection",
        return_value=fake_conn,
    ):
        return CassandraHook("test_conn_id")
+
+
class TestCassandraHook:
    """Unit tests for ``CassandraHook``.

    Every test receives the hook built by the ``mock_cassandra_hook`` fixture. Tests
    that would otherwise open a real session replace ``get_conn`` or
    ``Cluster.connect`` with mocks.
    """

    def test_init_with_valid_connection(self, mock_cassandra_hook):
        cluster = mock_cassandra_hook.cluster
        assert isinstance(cluster, Cluster)
        assert cluster.contact_points == ["127.0.0.1", "127.0.0.2"]
        assert cluster.port == 9042
        assert isinstance(cluster.auth_provider, PlainTextAuthProvider)
        # Check the real type (not a class-name string) and that the policy args from
        # the connection extras were forwarded to the policy.
        assert isinstance(cluster.load_balancing_policy, DCAwareRoundRobinPolicy)
        assert cluster.load_balancing_policy.local_dc == "dc1"
        assert cluster.load_balancing_policy.used_hosts_per_remote_dc == 2
        assert cluster.cql_version == "3.4.4"
        assert cluster.ssl_options == {"ca_certs": "/path/to/certs"}
        assert cluster.protocol_version == 4
        assert mock_cassandra_hook.keyspace == "test_keyspace"

    def test_get_conn_session_exist(self, mock_cassandra_hook):
        existing_session = Mock(is_shutdown=False)
        mock_cassandra_hook.session = existing_session
        mock_cassandra_hook.cluster.connect = Mock()

        hook_session = mock_cassandra_hook.get_conn()

        # A live session must be reused rather than opening a new one.
        assert hook_session is existing_session
        mock_cassandra_hook.cluster.connect.assert_not_called()

    def test_get_conn_session_not_exist(self, mock_cassandra_hook):
        mock_cassandra_hook.cluster.connect = Mock()

        hook_session = mock_cassandra_hook.get_conn()

        mock_cassandra_hook.cluster.connect.assert_called_once_with("test_keyspace")
        assert hook_session is mock_cassandra_hook.cluster.connect.return_value
        # The new session is cached on the hook for subsequent calls.
        assert mock_cassandra_hook.session is hook_session

    def test_get_conn_session_shutdown(self, mock_cassandra_hook):
        mock_cassandra_hook.session = Mock(is_shutdown=True)
        mock_cassandra_hook.cluster.connect = Mock()

        hook_session = mock_cassandra_hook.get_conn()

        # A shut-down session must be replaced by a fresh connection.
        mock_cassandra_hook.cluster.connect.assert_called_once_with("test_keyspace")
        assert hook_session is mock_cassandra_hook.cluster.connect.return_value

    def test_get_cluster(self, mock_cassandra_hook):
        cluster = mock_cassandra_hook.get_cluster()
        assert cluster.contact_points == ["127.0.0.1", "127.0.0.2"]
        assert cluster.port == 9042
        assert isinstance(cluster.auth_provider, PlainTextAuthProvider)
        assert isinstance(cluster.load_balancing_policy, DCAwareRoundRobinPolicy)
        assert cluster.cql_version == "3.4.4"
        assert cluster.ssl_options == {"ca_certs": "/path/to/certs"}
        assert cluster.protocol_version == 4
        assert mock_cassandra_hook.keyspace == "test_keyspace"

    def test_shutdown_cluster(self, mock_cassandra_hook):
        mock_cassandra_hook.cluster = Mock(is_shutdown=False)

        mock_cassandra_hook.shutdown_cluster()

        mock_cassandra_hook.cluster.shutdown.assert_called_once()

    def test_shutdown_cluster_already_shutdown(self, mock_cassandra_hook):
        mock_cassandra_hook.cluster = Mock(is_shutdown=True)

        mock_cassandra_hook.shutdown_cluster()

        # Shutting down twice must not be attempted.
        mock_cassandra_hook.cluster.shutdown.assert_not_called()

    def test_get_lb_policy_dc_aware_round_robin_policy(self, mock_cassandra_hook):
        policy_args = {"local_dc": "dc1", "used_hosts_per_remote_dc": 2}
        lb_policy = mock_cassandra_hook.get_lb_policy("DCAwareRoundRobinPolicy", policy_args)

        assert isinstance(lb_policy, DCAwareRoundRobinPolicy)
        assert lb_policy.local_dc == "dc1"
        assert lb_policy.used_hosts_per_remote_dc == 2

    def test_get_lb_policy_token_aware_policy_dc_aware_round_robin_policy(self, mock_cassandra_hook):
        policy_args = {
            "child_load_balancing_policy": "DCAwareRoundRobinPolicy",
            "child_load_balancing_policy_args": {
                "local_dc": "dc1",
                "used_hosts_per_remote_dc": 3,
            },
        }
        lb_policy = mock_cassandra_hook.get_lb_policy("TokenAwarePolicy", policy_args)

        assert isinstance(lb_policy, TokenAwarePolicy)
        assert isinstance(lb_policy._child_policy, DCAwareRoundRobinPolicy)
        assert lb_policy._child_policy.local_dc == "dc1"
        assert lb_policy._child_policy.used_hosts_per_remote_dc == 3

    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook.get_conn")
    def test_table_exists(self, mock_get_conn, mock_cassandra_hook):
        mock_cluster_metadata = Mock()
        mock_cluster_metadata.keyspaces = {"test_keyspace": Mock(tables={"test_table": None})}
        mock_get_conn.return_value.cluster.metadata = mock_cluster_metadata

        result = mock_cassandra_hook.table_exists("test_keyspace.test_table")
        assert result
        mock_get_conn.assert_called_once()

        # Without a dot, the hook's own keyspace ("test_keyspace") is used.
        assert mock_cassandra_hook.table_exists("test_table")

    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook.get_conn")
    def test_table_not_exists(self, mock_get_conn, mock_cassandra_hook):
        mock_cluster_metadata = Mock()
        mock_cluster_metadata.keyspaces = {"test_keyspace": Mock(tables={"test_table": None})}
        mock_get_conn.return_value.cluster.metadata = mock_cluster_metadata

        assert not mock_cassandra_hook.table_exists("test_keyspace.other_table")
        assert not mock_cassandra_hook.table_exists("other_keyspace.test_table")

    def test_sanitize_input_valid(self, mock_cassandra_hook):
        input_string = "valid_table_name"
        sanitized_string = mock_cassandra_hook._sanitize_input(input_string)
        assert sanitized_string == input_string

    def test_sanitize_input_invalid(self, mock_cassandra_hook):
        with pytest.raises(ValueError, match=r"Invalid input: invalid_table_name_with_%_characters"):
            mock_cassandra_hook._sanitize_input("invalid_table_name_with_%_characters")

    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook.get_conn")
    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook._sanitize_input")
    def test_record_exists_table(self, mock_sanitize_input, mock_get_conn, mock_cassandra_hook):
        table = "test_table"
        keys = {"key1": "value1", "key2": "value2"}
        mock_sanitize_input.return_value = table
        # Any non-None row means the record exists.
        mock_get_conn.return_value.execute.return_value.one.return_value = Mock()

        result = mock_cassandra_hook.record_exists(table, keys)

        assert result
        mock_sanitize_input.assert_called_with(table)
        assert mock_sanitize_input.call_count == 2
        mock_get_conn.return_value.execute.assert_called_once_with(
            "SELECT * FROM test_table.test_table WHERE key1=%(key1)s AND key2=%(key2)s",
            {"key1": "value1", "key2": "value2"},
        )

    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook.get_conn")
    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook._sanitize_input")
    def test_record_exists_keyspace_table(self, mock_sanitize_input, mock_get_conn, mock_cassandra_hook):
        table = "test_keyspace.test_table"
        keys = {"key1": "value1", "key2": "value2"}
        mock_sanitize_input.return_value = "test_table"
        mock_get_conn.return_value.execute.return_value.one.return_value = Mock()

        result = mock_cassandra_hook.record_exists(table, keys)

        assert result
        mock_sanitize_input.assert_called_with("test_table")
        assert mock_sanitize_input.call_count == 3
        mock_get_conn.return_value.execute.assert_called_once_with(
            "SELECT * FROM test_table.test_table WHERE key1=%(key1)s AND key2=%(key2)s",
            {"key1": "value1", "key2": "value2"},
        )

    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook.get_conn")
    def test_record_exists_no_row(self, mock_get_conn, mock_cassandra_hook):
        # The query succeeds but returns no row, so the record does not exist.
        mock_get_conn.return_value.execute.return_value.one.return_value = None

        result = mock_cassandra_hook.record_exists("test_keyspace.test_table", {"key1": "value1"})

        assert not result
        mock_get_conn.return_value.execute.assert_called_once()

    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook.get_conn")
    @patch("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook._sanitize_input")
    def test_record_exists_exception(self, mock_sanitize_input, mock_get_conn, mock_cassandra_hook):
        table = "test_keyspace.test_table"
        keys = {"key1": "value1", "key2": "value2"}
        mock_sanitize_input.return_value = "test_table"
        mock_get_conn.return_value.execute.side_effect = Exception("Test exception")

        result = mock_cassandra_hook.record_exists(table, keys)

        # Query errors are reported as "record does not exist" rather than raised.
        assert not result
        mock_sanitize_input.assert_called_with("test_table")
        assert mock_sanitize_input.call_count == 3
        mock_get_conn.return_value.execute.assert_called_once_with(
            "SELECT * FROM test_table.test_table WHERE key1=%(key1)s AND key2=%(key2)s",
            {"key1": "value1", "key2": "value2"},
        )
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 78534accd18..8651bd968c2 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -75,7 +75,6 @@ class TestProjectStructure:
"providers/tests/amazon/aws/utils/test_rds.py",
"providers/tests/amazon/aws/utils/test_sagemaker.py",
"providers/tests/amazon/aws/waiters/test_base_waiter.py",
- "providers/tests/apache/cassandra/hooks/test_cassandra.py",
"providers/tests/apache/drill/operators/test_drill.py",
"providers/tests/apache/druid/operators/test_druid_check.py",
"providers/tests/apache/hdfs/hooks/test_hdfs.py",