Repository: incubator-ariatosca Updated Branches: refs/heads/ARIA-105-integrate-modeling aa01cd4e9 -> dd99f0fbc (forced update) refs/heads/ARIA-122-Create-central-instantiation-module c7e28595d -> 8cd3113fa (forced update)
initial commit Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/8cd3113f Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/8cd3113f Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/8cd3113f Branch: refs/heads/ARIA-122-Create-central-instantiation-module Commit: 8cd3113fa8b8b873ff83921298c78a18eec5e5d7 Parents: c0d76ad Author: max-orlov <[email protected]> Authored: Thu Mar 9 15:22:23 2017 +0200 Committer: max-orlov <[email protected]> Committed: Thu Mar 9 15:41:59 2017 +0200 ---------------------------------------------------------------------- aria/orchestrator/instantiation.py | 173 ++++++++++++++++++++++ aria/storage/modeling/structure.py | 4 +- aria/storage/modeling/template_elements.py | 183 +++++------------------- tests/conftest.py | 3 +- tests/orchestrator/test_instantiation.py | 79 ++++++++++ 5 files changed, 290 insertions(+), 152 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/8cd3113f/aria/orchestrator/instantiation.py ---------------------------------------------------------------------- diff --git a/aria/orchestrator/instantiation.py b/aria/orchestrator/instantiation.py new file mode 100644 index 0000000..8b36257 --- /dev/null +++ b/aria/orchestrator/instantiation.py @@ -0,0 +1,173 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from aria.parser import validation +from aria.storage.modeling import utils +from aria.utils.collections import deepcopy_with_locators +from aria.storage.modeling import model + + [email protected](instance_cls=model.ServiceInstance) +def instantiate_service(self, instance_cls, context, container): + service_instance = instance_cls() + context.modeling.instance = service_instance + + service_instance.description = deepcopy_with_locators(self.description) + + if self.metadata is not None: + service_instance.metadata = self.metadata.instantiate(context, container) + + for node_template in self.node_templates.itervalues(): + for _ in range(node_template.default_instances): + node = node_template.instantiate(context, container) + service_instance.nodes[node.id] = node + + utils.instantiate_dict(context, self, service_instance.groups, self.group_templates) + utils.instantiate_dict(context, self, service_instance.policies, self.policy_templates) + utils.instantiate_dict(context, self, service_instance.operations, self.operation_templates) + + if self.substitution_template is not None: + service_instance.substitution = self.substitution_template.instantiate(context, + container) + + utils.instantiate_dict(context, self, service_instance.inputs, self.inputs) + utils.instantiate_dict(context, self, service_instance.outputs, self.outputs) + + for name, the_input in context.modeling.inputs.iteritems(): + if name not in service_instance.inputs: + context.validation.report('input "%s" is not supported' % name) + else: + service_instance.inputs[name].value = the_input + + return service_instance + + [email protected](instance_cls=model.Artifact) +def instantiate_artifact(self, instance_cls, context, container): + artifact = instance_cls(self.name, self.type_name, self.source_path) + artifact.description = deepcopy_with_locators(self.description) + artifact.target_path = self.target_path + artifact.repository_url = self.repository_url + artifact.repository_credential = self.repository_credential + utils.instantiate_dict(context, container, artifact.properties, self.properties) + return artifact + + [email protected](instance_cls=model.Capability) +def instantiate_capability(self, instance_cls, context, container): + capability = instance_cls(self.name, self.type_name) + capability.min_occurrences = self.min_occurrences + capability.max_occurrences = self.max_occurrences + utils.instantiate_dict(context, container, capability.properties, self.properties) + return capability + + [email protected](instance_cls=model.Interface) +def instantiate_interface(self, instance_cls, context, container): + interface = instance_cls(self.name, self.type_name) + interface.description = deepcopy_with_locators(self.description) + utils.instantiate_dict(context, container, interface.inputs, self.inputs) + utils.instantiate_dict(context, container, interface.operations, self.operation_templates) + return interface + + [email protected](instance_cls=model.Operation) +def instantiate_operation(self, instance_cls, context, container): + operation = instance_cls(self.name) + operation.description = deepcopy_with_locators(self.description) + operation.implementation = self.implementation + operation.dependencies = self.dependencies + operation.executor = self.executor + operation.max_retries = self.max_retries + operation.retry_interval = self.retry_interval + utils.instantiate_dict(context, container, operation.inputs, self.inputs) + return operation + + [email protected](instance_cls=model.Policy) +def instantiate_policy(self, instance_cls, context, **kwargs): + policy = instance_cls(self.name, self.type_name) + utils.instantiate_dict(context, self, policy.properties, self.properties) + for node_template_name in self.target_node_template_names: + policy.target_node_ids.extend( + context.modeling.instance.get_node_ids(node_template_name)) + for group_template_name in self.target_group_template_names: + policy.target_group_ids.extend( + context.modeling.instance.get_group_ids(group_template_name)) + return policy + + [email protected](instance_cls=model.GroupPolicy) +def instantiate_group_policy(self, instance_cls, context, container): + group_policy = instance_cls(self.name, self.type_name) + group_policy.description = deepcopy_with_locators(self.description) + utils.instantiate_dict(context, container, group_policy.properties, self.properties) + utils.instantiate_dict(context, container, group_policy.triggers, self.triggers) + return group_policy + + [email protected](instance_cls=model.GroupPolicyTrigger) +def instantiate_group_policy_trigger(self, instance_cls, context, container): + group_policy_trigger = instance_cls(self.name, self.implementation) + group_policy_trigger.description = deepcopy_with_locators(self.description) + utils.instantiate_dict(context, container, group_policy_trigger.properties, self.properties) + return group_policy_trigger + + [email protected](instance_cls=model.Mapping) +def instantiate_mapping(self, instance_cls, context, container): + nodes = context.modeling.instance.find_nodes(self.node_template_name) + if len(nodes) == 0: + context.validation.report( + 'mapping "%s" refer to node template "%s" but there are no ' + 'node instances' % (self.mapped_name, + self.node_template_name), + level=validation.Issue.BETWEEN_INSTANCES) + return None + return instance_cls(self.mapped_name, nodes[0].id, self.name) + + [email protected](instance_cls=model.Substitution) +def instantiate_substitution(self, instance_cls, context, container): + substitution = instance_cls(self.node_type_name) + utils.instantiate_dict(context, container, substitution.capabilities, + self.capability_templates) + utils.instantiate_dict(context, container, substitution.requirements, + self.requirement_templates) + return substitution + + [email protected](instance_cls=model.Node) +def instantiate_node(self, instance_cls, context, **kwargs): + node = instance_cls(context, self.type_name, self.name) + utils.instantiate_dict(context, node, node.properties, self.properties) + utils.instantiate_dict(context, node, node.interfaces, self.interface_templates) + utils.instantiate_dict(context, node, node.artifacts, self.artifact_templates) + utils.instantiate_dict(context, node, node.capabilities, self.capability_templates) + return node + + [email protected](instance_cls=model.Group) +def instantiate_node(self, instance_cls, context, **kwargs): + group = instance_cls(context, self.type_name, self.name) + utils.instantiate_dict(context, self, group.properties, self.properties) + utils.instantiate_dict(context, self, group.interfaces, self.interface_templates) + utils.instantiate_dict(context, self, group.policies, self.policy_templates) + for member_node_template_name in self.member_node_template_names: + group.member_node_ids += \ + context.modeling.instance.get_node_ids(member_node_template_name) + for member_group_template_name in self.member_group_template_names: + group.member_group_ids += \ + context.modeling.instance.get_group_ids(member_group_template_name) + return group http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/8cd3113f/aria/storage/modeling/structure.py ---------------------------------------------------------------------- diff --git a/aria/storage/modeling/structure.py b/aria/storage/modeling/structure.py index eacdb44..4ee6951 100644 --- a/aria/storage/modeling/structure.py +++ b/aria/storage/modeling/structure.py @@ -88,8 +88,8 @@ class ModelElementBase(ElementBase): All model elements can be instantiated into :class:`ServiceInstance` elements. """ - - def instantiate(self, context, container): + @classmethod + def instantiate(cls, *args, **kwargs): raise NotImplementedError http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/8cd3113f/aria/storage/modeling/template_elements.py ---------------------------------------------------------------------- diff --git a/aria/storage/modeling/template_elements.py b/aria/storage/modeling/template_elements.py index 4212b15..15b93af 100644 --- a/aria/storage/modeling/template_elements.py +++ b/aria/storage/modeling/template_elements.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from types import FunctionType +from functools import partial +from types import FunctionType, MethodType from sqlalchemy import ( Column, @@ -41,7 +42,26 @@ from . import ( # region Element templates -class ServiceTemplateBase(structure.ModelMixin): +class TemplateBase(structure.ModelMixin): + + @classmethod + def instantiates(cls, func=None, instance_cls=None, override=False): + if not override: + try: + cls.instantiate() + except NotImplementedError: + pass + else: + # TODO: raise proper exception + raise BaseException("instantiation method was already registered for {0}" + .format(cls.__name__)) + if func is None: + return partial(cls.instantiates, instance_cls=instance_cls, override=override) + + cls.instantiate = MethodType(partial(func, instance_cls=instance_cls), None, cls) + + +class ServiceTemplateBase(TemplateBase): __tablename__ = 'service_template' @@ -97,39 +117,6 @@ class ServiceTemplateBase(structure.ModelMixin): ('outputs', formatting.as_raw_dict(self.outputs)), ('operation_templates', formatting.as_raw_list(self.operation_templates)))) - def instantiate(self, context, container): - service_instance = instance_elements.ServiceInstanceBase() - context.modeling.instance = service_instance - - service_instance.description = deepcopy_with_locators(self.description) - - if self.metadata is not None: - service_instance.metadata = self.metadata.instantiate(context, container) - - for node_template in self.node_templates.itervalues(): - for _ in range(node_template.default_instances): - node = node_template.instantiate(context, container) - service_instance.nodes[node.id] = node - - utils.instantiate_dict(context, self, service_instance.groups, self.group_templates) - utils.instantiate_dict(context, self, service_instance.policies, self.policy_templates) - utils.instantiate_dict(context, self, service_instance.operations, self.operation_templates) - - if self.substitution_template is not None: - service_instance.substitution = self.substitution_template.instantiate(context, - container) - - utils.instantiate_dict(context, self, service_instance.inputs, self.inputs) - utils.instantiate_dict(context, self, service_instance.outputs, self.outputs) - - for name, the_input in context.modeling.inputs.iteritems(): - if name not in service_instance.inputs: - context.validation.report('input "%s" is not supported' % name) - else: - service_instance.inputs[name].value = the_input - - return service_instance - def validate(self, context): if self.metadata is not None: self.metadata.validate(context) @@ -172,7 +159,7 @@ class ServiceTemplateBase(structure.ModelMixin): utils.dump_dict_values(context, self.operation_templates, 'Operation templates') -class InterfaceTemplateBase(structure.ModelMixin): +class InterfaceTemplateBase(TemplateBase): __tablename__ = 'interface_template' __private_fields__ = ['node_template_fk', @@ -225,13 +212,6 @@ class InterfaceTemplateBase(structure.ModelMixin): # TODO fix self.properties reference ('operation_templates', formatting.as_raw_list(self.operation_templates)))) - def instantiate(self, context, container): - interface = instance_elements.InterfaceBase(self.name, self.type_name) - interface.description = deepcopy_with_locators(self.description) - utils.instantiate_dict(context, container, interface.inputs, self.inputs) - utils.instantiate_dict(context, container, interface.operations, self.operation_templates) - return interface - def validate(self, context): if self.type_name: if context.modeling.interface_types.get_descendant(self.type_name) is None: @@ -256,7 +236,7 @@ class InterfaceTemplateBase(structure.ModelMixin): utils.dump_dict_values(context, self.operation_templates, 'Operation templates') -class OperationTemplateBase(structure.ModelMixin): +class OperationTemplateBase(TemplateBase): __tablename__ = 'operation_template' __private_fields__ = ['service_template_fk', @@ -315,17 +295,6 @@ class OperationTemplateBase(structure.ModelMixin): ('retry_interval', self.retry_interval), ('inputs', formatting.as_raw_dict(self.inputs)))) - def instantiate(self, context, container): - operation = instance_elements.OperationBase(self.name) - operation.description = deepcopy_with_locators(self.description) - operation.implementation = self.implementation - operation.dependencies = self.dependencies - operation.executor = self.executor - operation.max_retries = self.max_retries - operation.retry_interval = self.retry_interval - utils.instantiate_dict(context, container, operation.inputs, self.inputs) - return operation - def validate(self, context): utils.validate_dict_values(context, self.inputs) @@ -351,7 +320,7 @@ class OperationTemplateBase(structure.ModelMixin): dump_parameters(context, self.inputs, 'Inputs') -class ArtifactTemplateBase(structure.ModelMixin): +class ArtifactTemplateBase(TemplateBase): """ A file associated with a :class:`NodeTemplate`. @@ -411,15 +380,6 @@ class ArtifactTemplateBase(structure.ModelMixin): ('repository_credential', formatting.as_agnostic(self.repository_credential)), ('properties', formatting.as_raw_dict(self.properties.iteritems())))) - def instantiate(self, context, container): - artifact = instance_elements.ArtifactBase(self.name, self.type_name, self.source_path) - artifact.description = deepcopy_with_locators(self.description) - artifact.target_path = self.target_path - artifact.repository_url = self.repository_url - artifact.repository_credential = self.repository_credential - utils.instantiate_dict(context, container, artifact.properties, self.properties) - return artifact - def validate(self, context): if context.modeling.artifact_types.get_descendant(self.type_name) is None: context.validation.report('artifact "%s" has an unknown type: %s' @@ -448,7 +408,7 @@ class ArtifactTemplateBase(structure.ModelMixin): dump_parameters(context, self.properties) -class PolicyTemplateBase(structure.ModelMixin): +class PolicyTemplateBase(TemplateBase): """ Policies can be applied to zero or more :class:`NodeTemplate` or :class:`GroupTemplate` instances. @@ -513,17 +473,6 @@ class PolicyTemplateBase(structure.ModelMixin): ('target_node_template_names', self.target_node_template_names), ('target_group_template_names', self.target_group_template_names))) - def instantiate(self, context, *args, **kwargs): - policy = instance_elements.PolicyBase(self.name, self.type_name) - utils.instantiate_dict(context, self, policy.properties, self.properties) - for node_template_name in self.target_node_template_names: - policy.target_node_ids.extend( - context.modeling.instance.get_node_ids(node_template_name)) - for group_template_name in self.target_group_template_names: - policy.target_group_ids.extend( - context.modeling.instance.get_group_ids(group_template_name)) - return policy - def validate(self, context): if context.modeling.policy_types.get_descendant(self.type_name) is None: context.validation.report('policy template "%s" has an unknown type: %s' @@ -550,7 +499,7 @@ class PolicyTemplateBase(structure.ModelMixin): (str(context.style.node(v)) for v in self.target_group_template_names))) -class GroupPolicyTemplateBase(structure.ModelMixin): +class GroupPolicyTemplateBase(TemplateBase): """ Policies applied to groups. @@ -594,13 +543,6 @@ class GroupPolicyTemplateBase(structure.ModelMixin): ('properties', formatting.as_raw_dict(self.properties)), ('triggers', formatting.as_raw_list(self.triggers)))) - def instantiate(self, context, container): - group_policy = instance_elements.GroupPolicyBase(self.name, self.type_name) - group_policy.description = deepcopy_with_locators(self.description) - utils.instantiate_dict(context, container, group_policy.properties, self.properties) - utils.instantiate_dict(context, container, group_policy.triggers, self.triggers) - return group_policy - def validate(self, context): if context.modeling.policy_types.get_descendant(self.type_name) is None: context.validation.report('group policy "%s" has an unknown type: %s' @@ -624,7 +566,7 @@ class GroupPolicyTemplateBase(structure.ModelMixin): utils.dump_dict_values(context, self.triggers, 'Triggers') -class GroupPolicyTriggerTemplateBase(structure.ModelMixin): +class GroupPolicyTriggerTemplateBase(TemplateBase): """ Triggers for :class:`GroupPolicyTemplate`. @@ -674,14 +616,6 @@ class GroupPolicyTriggerTemplateBase(structure.ModelMixin): ('implementation', self.implementation), ('properties', formatting.as_raw_dict(self.properties)))) - def instantiate(self, context, container): - group_policy_trigger = instance_elements.GroupPolicyTriggerBase(self.name, - self.implementation) - group_policy_trigger.description = deepcopy_with_locators(self.description) - utils.instantiate_dict(context, container, group_policy_trigger.properties, - self.properties) - return group_policy_trigger - def validate(self, context): utils.validate_dict_values(context, self.properties) @@ -697,7 +631,7 @@ class GroupPolicyTriggerTemplateBase(structure.ModelMixin): dump_parameters(context, self.properties) -class MappingTemplateBase(structure.ModelMixin): +class MappingTemplateBase(TemplateBase): """ Used by :class:`SubstitutionTemplate` to map a capability or a requirement to a node. @@ -719,17 +653,6 @@ class MappingTemplateBase(structure.ModelMixin): ('node_template_name', self.node_template_name), ('name', self.name))) - def instantiate(self, context, *args, **kwargs): - nodes = context.modeling.instance.find_nodes(self.node_template_name) - if len(nodes) == 0: - context.validation.report( - 'mapping "%s" refer to node template "%s" but there are no ' - 'node instances' % (self.mapped_name, - self.node_template_name), - level=validation.Issue.BETWEEN_INSTANCES) - return None - return instance_elements.MappingBase(self.mapped_name, nodes[0].id, self.name) - def validate(self, context): if self.node_template_name not in context.modeling.model.node_templates: context.validation.report('mapping "%s" refers to an unknown node template: %s' @@ -744,7 +667,7 @@ class MappingTemplateBase(structure.ModelMixin): context.style.node(self.name))) -class SubstitutionTemplateBase(structure.ModelMixin): +class SubstitutionTemplateBase(TemplateBase): """ Used to substitute a single node for the entire deployment. @@ -780,13 +703,6 @@ class SubstitutionTemplateBase(structure.ModelMixin): ('capability_templates', formatting.as_raw_list(self.capability_templates)), ('requirement_templates', formatting.as_raw_list(self.requirement_templates)))) - def instantiate(self, context, container): - substitution = instance_elements.SubstitutionBase(self.node_type_name) - utils.instantiate_dict(context, container, substitution.capabilities, - self.capability_templates) - utils.instantiate_dict(context, container, substitution.requirements, - self.requirement_templates) - return substitution def validate(self, context): if context.modeling.node_types.get_descendant(self.node_type_name) is None: @@ -815,7 +731,7 @@ class SubstitutionTemplateBase(structure.ModelMixin): # region Node templates -class NodeTemplateBase(structure.ModelMixin): +class NodeTemplateBase(TemplateBase): __tablename__ = 'node_template' __private_fields__ = ['service_template_fk', @@ -892,14 +808,6 @@ class NodeTemplateBase(structure.ModelMixin): ('capability_templates', formatting.as_raw_list(self.capability_templates)), ('requirement_templates', formatting.as_raw_list(self.requirement_templates)))) - def instantiate(self, context, *args, **kwargs): - node = instance_elements.NodeBase(context, self.type_name, self.name) - utils.instantiate_dict(context, node, node.properties, self.properties) - utils.instantiate_dict(context, node, node.interfaces, self.interface_templates) - utils.instantiate_dict(context, node, node.artifacts, self.artifact_templates) - utils.instantiate_dict(context, node, node.capabilities, self.capability_templates) - return node - def validate(self, context): if context.modeling.node_types.get_descendant(self.type_name) is None: context.validation.report('node template "%s" has an unknown type: %s' @@ -939,7 +847,7 @@ class NodeTemplateBase(structure.ModelMixin): utils.dump_list_values(context, self.requirement_templates, 'Requirement templates') -class GroupTemplateBase(structure.ModelMixin): +class GroupTemplateBase(TemplateBase): """ A template for creating zero or more :class:`Group` instances. @@ -1001,19 +909,6 @@ class GroupTemplateBase(structure.ModelMixin): ('member_node_template_names', self.member_node_template_names), ('member_group_template_names', self.member_group_template_names1))) - def instantiate(self, context, *args, **kwargs): - group = instance_elements.GroupBase(context, self.type_name, self.name) - utils.instantiate_dict(context, self, group.properties, self.properties) - utils.instantiate_dict(context, self, group.interfaces, self.interface_templates) - utils.instantiate_dict(context, self, group.policies, self.policy_templates) - for member_node_template_name in self.member_node_template_names: - group.member_node_ids += \ - context.modeling.instance.get_node_ids(member_node_template_name) - for member_group_template_name in self.member_group_template_names: - group.member_group_ids += \ - context.modeling.instance.get_group_ids(member_group_template_name) - return group - def validate(self, context): if context.modeling.group_types.get_descendant(self.type_name) is None: context.validation.report('group template "%s" has an unknown type: %s' @@ -1048,7 +943,7 @@ class GroupTemplateBase(structure.ModelMixin): # region Relationship templates -class RequirementTemplateBase(structure.ModelMixin): +class RequirementTemplateBase(TemplateBase): """ A requirement for a :class:`NodeTemplate`. During instantiation will be matched with a capability of another @@ -1094,9 +989,6 @@ class RequirementTemplateBase(structure.ModelMixin): return cls.many_to_one_relationship('node_template') # endregion - def instantiate(self, context, container): - raise NotImplementedError - def find_target(self, context, source_node_template): # We might already have a specific node template, so we'll just verify it if self.target_node_template_name is not None: @@ -1214,7 +1106,7 @@ class RequirementTemplateBase(structure.ModelMixin): self.relationship_template.dump(context) -class CapabilityTemplateBase(structure.ModelMixin): +class CapabilityTemplateBase(TemplateBase): """ A capability of a :class:`NodeTemplate`. Nodes expose zero or more capabilities that can be matched with :class:`Requirement` instances of other nodes. @@ -1299,13 +1191,6 @@ class CapabilityTemplateBase(structure.ModelMixin): ('valid_source_node_type_names', self.valid_source_node_type_names), ('properties', formatting.as_raw_dict(self.properties)))) - def instantiate(self, context, container): - capability = instance_elements.CapabilityBase(self.name, self.type_name) - capability.min_occurrences = self.min_occurrences - capability.max_occurrences = self.max_occurrences - utils.instantiate_dict(context, container, capability.properties, self.properties) - return capability - def validate(self, context): if context.modeling.capability_types.get_descendant(self.type_name) is None: context.validation.report('capability "%s" refers to an unknown type: %s' http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/8cd3113f/tests/conftest.py ---------------------------------------------------------------------- diff --git a/tests/conftest.py b/tests/conftest.py index c501eeb..9338bd6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,8 @@ import aria @pytest.fixture(scope='session', autouse=True) def install_aria_extensions(): - aria.install_aria_extensions() + pass + # aria.install_aria_extensions() @pytest.fixture(autouse=True) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/8cd3113f/tests/orchestrator/test_instantiation.py ---------------------------------------------------------------------- diff --git a/tests/orchestrator/test_instantiation.py b/tests/orchestrator/test_instantiation.py new file mode 100644 index 0000000..e159e04 --- /dev/null +++ b/tests/orchestrator/test_instantiation.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from aria.storage.modeling import model_base, structure + + [email protected](autouse=True) +def cleanup_instantiator(request): + instantiator = MockTemplate.instantiate + + def clear_instantiator(): + MockTemplate.instantiate = instantiator + request.addfinalizer(clear_instantiator) + + +class MockTemplate(model_base.template_elements.TemplateBase): + def __init__(self, template_name): + self.template_name = template_name + + +class MockInstnace(structure.ModelMixin): + def __init__(self, instance_name): + self.instance_name = instance_name + + +def test_base_instantiation(): + name = 'my_name' + + @MockTemplate.instantiates(instance_cls=MockInstnace) + def initiator(self, instance_cls): + return instance_cls(self.template_name) + + mock_template = MockTemplate(name) + mock_instance = mock_template.instantiate() + + assert mock_instance.instance_name == mock_template.template_name == name + + +def test_reinstantiate(): + + name = 'my_name' + + @MockTemplate.instantiates(instance_cls=MockInstnace) + def initiator(self, instance_cls): + return instance_cls(self.template_name) + + mock_template = MockTemplate(name) + mock_instance = mock_template.instantiate() + assert mock_instance.instance_name == mock_template.template_name == name + + def new_initiator(self, instance_cls): + return instance_cls('new_{0}'.format(self.template_name)) + + with pytest.raises(BaseException): + MockTemplate.instantiates(func=new_initiator, instance_cls=MockInstnace) + + mock_template = MockTemplate(name) + mock_instance = mock_template.instantiate() + assert mock_instance.instance_name == mock_template.template_name == name + + MockTemplate.instantiates(func=new_initiator, instance_cls=MockInstnace, override=True) + mock_template = MockTemplate(name) + mock_instance = mock_template.instantiate() + assert mock_template.template_name == name + assert mock_instance.instance_name == 'new_{0}'.format(name)
