potiuk commented on a change in pull request #8024: [AIRFLOW-4529] Add support for Azure Batch Service
URL: https://github.com/apache/airflow/pull/8024#discussion_r400620446
##
File path: airflow/providers/microsoft/azure/hooks/azure_batch.py
##
@@ -0,0 +1,348 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import time
+from datetime import timedelta
+from typing import Optional
+
+from azure.batch import BatchServiceClient, batch_auth, models as batch_models
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.utils import timezone
+
+
+class AzureBatchHook(BaseHook):
+"""
+Hook for Azure Batch APIs
+"""
+
def __init__(self, azure_batch_conn_id='azure_batch_default'):
    """
    :param azure_batch_conn_id: id of the Airflow connection whose extras
        hold the Azure Batch ``account_name``, ``account_key`` and
        ``account_url``
    :type azure_batch_conn_id: str
    """
    super().__init__()
    self.conn_id = azure_batch_conn_id
    # The Batch client is built eagerly: get_conn() raises AirflowException
    # right here if a required extra is missing from the connection.
    batch_client = self.get_conn()
    self.connection = batch_client
    # NOTE(review): this is a second lookup of the same connection record
    # (get_conn() already called _connection() internally).
    connection_record = self._connection()
    self.extra = connection_record.extra_dejson
+
+def _connection(self):
+"""
+Get connected to azure batch service
+"""
+conn = self.get_connection(self.conn_id)
+return conn
+
def get_conn(self):
    """
    Build an Azure Batch client from the extras of this hook's connection.

    The connection's extra JSON must provide ``account_name``,
    ``account_key`` and ``account_url``. The name and key are used to
    create shared-key credentials, and the url is passed as ``batch_url``.

    :return: Azure batch client
    :rtype: azure.batch.BatchServiceClient
    :raises AirflowException: if a required extra is missing or falsy
        (for example an empty string)
    """
    conn = self._connection()
    # Read the extras once instead of re-reading them for every parameter.
    extras = conn.extra_dejson

    def _get_required_param(name):
        """Extract required parameter from extra JSON, raise exception if not found"""
        value = extras.get(name)
        # Falsy values (e.g. "") are rejected as well as absent keys.
        if not value:
            raise AirflowException(
                'Extra connection option is missing required parameter: `{}`'.format(name)
            )
        return value

    batch_account_name = _get_required_param('account_name')
    batch_account_key = _get_required_param('account_key')
    batch_account_url = _get_required_param('account_url')
    credentials = batch_auth.SharedKeyCredentials(batch_account_name, batch_account_key)
    batch_client = BatchServiceClient(credentials, batch_url=batch_account_url)
    return batch_client
+
+def configure_pool(self,
+ pool_id: str,
+ vm_size: str,
+ display_name: Optional[str] = None,
+ target_dedicated_nodes: Optional[int] = None,
+ use_latest_image_and_sku: bool = False,
+ vm_publisher: Optional[str] = None,
+ vm_offer: Optional[str] = None,
+ sku_starts_with: Optional[str] = None,
+ **kwargs
+ ):
+"""
+Configures a pool
+
+:param pool_id: A string that uniquely identifies the Pool within the
Account
+:type pool_id: str
+
+:param vm_size: The size of virtual machines in the Pool.
+:type vm_size: str
+
+:param display_name: The display name for the Pool
+:type display_name: str
+
+:param target_dedicated_nodes: The desired number of dedicated Compute
Nodes in the Pool.
+:type target_dedicated_nodes: Optional[int]
+
+:param use_latest_image_and_sku: Whether to use the latest verified vm
image and sku
+:type use_latest_image_and_sku: bool
+
+:param vm_publisher: The publisher of the Azure Virtual Machines
Marketplace Image.
+For example, Canonical or MicrosoftWindowsServer.
+:type vm_publisher: Optional[str]
+
+:param vm_offer: The offer type of the Azure Virtual Machines
Marketplace Image.
+For example, UbuntuServer or WindowsServer.
+:type vm_offer: Optional[str]
+
+:param sku_starts_with: The start name of the sku to search
+:type sku_starts_with: Optional[str]
+"""
+if use_latest_image_and_sku:
+self.log.info('Using latest verified virtual machine image with
node agent sku')
+sku_to_use, image_ref_to_use = \
+
self._get_latest_verified_image_vm_and_sku(publisher=vm_publisher,
+ offer=vm_offer,
+