http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/storage/models.py ---------------------------------------------------------------------- diff --git a/aria/storage/models.py b/aria/storage/models.py index d24ad75..6302e66 100644 --- a/aria/storage/models.py +++ b/aria/storage/models.py @@ -36,16 +36,30 @@ classes: * ProviderContext - provider context implementation model. * Plugin - plugin implementation model. """ - +from collections import namedtuple from datetime import datetime -from types import NoneType -from .structures import Field, IterPointerField, Model, uuid_generator, PointerField +from sqlalchemy.ext.declarative.base import declared_attr + +from .structures import ( + SQLModelBase, + Column, + Integer, + Text, + DateTime, + Boolean, + Enum, + String, + Float, + List, + Dict, + foreign_key, + one_to_many_relationship, + relationship_to_self, + orm) __all__ = ( - 'Model', 'Blueprint', - 'Snapshot', 'Deployment', 'DeploymentUpdateStep', 'DeploymentUpdate', @@ -59,66 +73,192 @@ __all__ = ( 'Plugin', ) -# todo: sort this, maybe move from mgr or move from aria??? -ACTION_TYPES = () -ENTITY_TYPES = () + +#pylint: disable=no-self-argument -class Blueprint(Model): +class Blueprint(SQLModelBase): """ - A Model which represents a blueprint + Blueprint model representation. """ - plan = Field(type=dict) - id = Field(type=basestring, default=uuid_generator) - description = Field(type=(basestring, NoneType)) - created_at = Field(type=datetime) - updated_at = Field(type=datetime) - main_file_name = Field(type=basestring) + __tablename__ = 'blueprints' + name = Column(Text, index=True) + created_at = Column(DateTime, nullable=False, index=True) + main_file_name = Column(Text, nullable=False) + plan = Column(Dict, nullable=False) + updated_at = Column(DateTime) + description = Column(Text) -class Snapshot(Model): + +class Deployment(SQLModelBase): """ - A Model which represents a snapshot + Deployment model representation. """ - CREATED = 'created' + __tablename__ = 'deployments' + + _private_fields = ['blueprint_id'] + + blueprint_id = foreign_key(Blueprint.id) + + name = Column(Text, index=True) + created_at = Column(DateTime, nullable=False, index=True) + description = Column(Text) + inputs = Column(Dict) + groups = Column(Dict) + permalink = Column(Text) + policy_triggers = Column(Dict) + policy_types = Column(Dict) + outputs = Column(Dict) + scaling_groups = Column(Dict) + updated_at = Column(DateTime) + workflows = Column(Dict) + + @declared_attr + def blueprint(cls): + return one_to_many_relationship(cls, Blueprint, cls.blueprint_id) + + +class Execution(SQLModelBase): + """ + Execution model representation. + """ + __tablename__ = 'executions' + + TERMINATED = 'terminated' FAILED = 'failed' - CREATING = 'creating' - UPLOADED = 'uploaded' - END_STATES = [CREATED, FAILED, UPLOADED] + CANCELLED = 'cancelled' + PENDING = 'pending' + STARTED = 'started' + CANCELLING = 'cancelling' + FORCE_CANCELLING = 'force_cancelling' - id = Field(type=basestring, default=uuid_generator) - created_at = Field(type=datetime) - status = Field(type=basestring) - error = Field(type=basestring, default=None) + STATES = [TERMINATED, FAILED, CANCELLED, PENDING, STARTED, CANCELLING, FORCE_CANCELLING] + END_STATES = [TERMINATED, FAILED, CANCELLED] + ACTIVE_STATES = [state for state in STATES if state not in END_STATES] + VALID_TRANSITIONS = { + PENDING: [STARTED, CANCELLED], + STARTED: END_STATES + [CANCELLING], + CANCELLING: END_STATES + } -class Deployment(Model): + @orm.validates('status') + def validate_status(self, key, value): + """Validation function that verifies execution status transitions are OK""" + try: + current_status = getattr(self, key) + except AttributeError: + return + valid_transitions = Execution.VALID_TRANSITIONS.get(current_status, []) + if all([current_status is not None, + current_status != value, + value not in valid_transitions]): + raise ValueError('Cannot change execution status from {current} to {new}'.format( + current=current_status, + new=value)) + return value + + deployment_id = foreign_key(Deployment.id) + blueprint_id = foreign_key(Blueprint.id) + _private_fields = ['deployment_id', 'blueprint_id'] + + created_at = Column(DateTime, index=True) + started_at = Column(DateTime, nullable=True, index=True) + ended_at = Column(DateTime, nullable=True, index=True) + error = Column(Text, nullable=True) + is_system_workflow = Column(Boolean, nullable=False, default=False) + parameters = Column(Dict) + status = Column(Enum(*STATES, name='execution_status'), default=PENDING) + workflow_name = Column(Text, nullable=False) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + @declared_attr + def blueprint(cls): + return one_to_many_relationship(cls, Blueprint, cls.blueprint_id) + + def __str__(self): + return '<{0} id=`{1}` (status={2})>'.format( + self.__class__.__name__, + self.id, + self.status + ) + + +class DeploymentUpdate(SQLModelBase): """ - A Model which represents a deployment + Deployment update model representation. """ - id = Field(type=basestring, default=uuid_generator) - description = Field(type=(basestring, NoneType)) - created_at = Field(type=datetime) - updated_at = Field(type=datetime) - blueprint_id = Field(type=basestring) - workflows = Field(type=dict) - inputs = Field(type=dict, default=lambda: {}) - policy_types = Field(type=dict, default=lambda: {}) - policy_triggers = Field(type=dict, default=lambda: {}) - groups = Field(type=dict, default=lambda: {}) - outputs = Field(type=dict, default=lambda: {}) - scaling_groups = Field(type=dict, default=lambda: {}) - - -class DeploymentUpdateStep(Model): + __tablename__ = 'deployment_updates' + + deployment_id = foreign_key(Deployment.id) + execution_id = foreign_key(Execution.id, nullable=True) + _private_fields = ['execution_id', 'deployment_id'] + + created_at = Column(DateTime, nullable=False, index=True) + deployment_plan = Column(Dict, nullable=False) + deployment_update_node_instances = Column(Dict) + deployment_update_deployment = Column(Dict) + deployment_update_nodes = Column(Dict) + modified_entity_ids = Column(Dict) + state = Column(Text) + + @declared_attr + def execution(cls): + return one_to_many_relationship(cls, Execution, cls.execution_id) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + def to_dict(self, suppress_error=False, **kwargs): + dep_update_dict = super(DeploymentUpdate, self).to_dict(suppress_error) + # Taking care of the fact the DeploymentSteps are objects + dep_update_dict['steps'] = [step.to_dict() for step in self.steps] + return dep_update_dict + + +class DeploymentUpdateStep(SQLModelBase): """ - A Model which represents a deployment update step + Deployment update step model representation. """ - id = Field(type=basestring, default=uuid_generator) - action = Field(type=basestring, choices=ACTION_TYPES) - entity_type = Field(type=basestring, choices=ENTITY_TYPES) - entity_id = Field(type=basestring) - supported = Field(type=bool, default=True) + __tablename__ = 'deployment_update_steps' + _action_types = namedtuple('ACTION_TYPES', 'ADD, REMOVE, MODIFY') + ACTION_TYPES = _action_types(ADD='add', REMOVE='remove', MODIFY='modify') + _entity_types = namedtuple( + 'ENTITY_TYPES', + 'NODE, RELATIONSHIP, PROPERTY, OPERATION, WORKFLOW, OUTPUT, DESCRIPTION, GROUP, ' + 'POLICY_TYPE, POLICY_TRIGGER, PLUGIN') + ENTITY_TYPES = _entity_types( + NODE='node', + RELATIONSHIP='relationship', + PROPERTY='property', + OPERATION='operation', + WORKFLOW='workflow', + OUTPUT='output', + DESCRIPTION='description', + GROUP='group', + POLICY_TYPE='policy_type', + POLICY_TRIGGER='policy_trigger', + PLUGIN='plugin' + ) + + deployment_update_id = foreign_key(DeploymentUpdate.id) + _private_fields = ['deployment_update_id'] + + action = Column(Enum(*ACTION_TYPES, name='action_type'), nullable=False) + entity_id = Column(Text, nullable=False) + entity_type = Column(Enum(*ENTITY_TYPES, name='entity_type'), nullable=False) + + @declared_attr + def deployment_update(cls): + return one_to_many_relationship(cls, + DeploymentUpdate, + cls.deployment_update_id, + backreference='steps') def __hash__(self): return hash((self.id, self.entity_id)) @@ -148,265 +288,225 @@ class DeploymentUpdateStep(Model): return False -class DeploymentUpdate(Model): +class DeploymentModification(SQLModelBase): """ - A Model which represents a deployment update + Deployment modification model representation. """ - INITIALIZING = 'initializing' - SUCCESSFUL = 'successful' - UPDATING = 'updating' - FINALIZING = 'finalizing' - EXECUTING_WORKFLOW = 'executing_workflow' - FAILED = 'failed' + __tablename__ = 'deployment_modifications' - STATES = [ - INITIALIZING, - SUCCESSFUL, - UPDATING, - FINALIZING, - EXECUTING_WORKFLOW, - FAILED, - ] - - # '{0}-{1}'.format(kwargs['deployment_id'], uuid4()) - id = Field(type=basestring, default=uuid_generator) - deployment_id = Field(type=basestring) - state = Field(type=basestring, choices=STATES, default=INITIALIZING) - deployment_plan = Field() - deployment_update_nodes = Field(default=None) - deployment_update_node_instances = Field(default=None) - deployment_update_deployment = Field(default=None) - modified_entity_ids = Field(default=None) - execution_id = Field(type=basestring) - steps = IterPointerField(type=DeploymentUpdateStep, default=()) - - -class Execution(Model): - """ - A Model which represents an execution - """ + STARTED = 'started' + FINISHED = 'finished' + ROLLEDBACK = 'rolledback' - class _Validation(object): - - @staticmethod - def execution_status_transition_validation(_, value, instance): - """Validation function that verifies execution status transitions are OK""" - try: - current_status = instance.status - except AttributeError: - return - valid_transitions = Execution.VALID_TRANSITIONS.get(current_status, []) - if current_status != value and value not in valid_transitions: - raise ValueError('Cannot change execution status from {current} to {new}'.format( - current=current_status, - new=value)) + STATES = [STARTED, FINISHED, ROLLEDBACK] + END_STATES = [FINISHED, ROLLEDBACK] - TERMINATED = 'terminated' - FAILED = 'failed' - CANCELLED = 'cancelled' - PENDING = 'pending' - STARTED = 'started' - CANCELLING = 'cancelling' - STATES = ( - TERMINATED, - FAILED, - CANCELLED, - PENDING, - STARTED, - CANCELLING, - ) - END_STATES = [TERMINATED, FAILED, CANCELLED] - ACTIVE_STATES = [state for state in STATES if state not in END_STATES] - VALID_TRANSITIONS = { - PENDING: [STARTED, CANCELLED], - STARTED: END_STATES + [CANCELLING], - CANCELLING: END_STATES - } + deployment_id = foreign_key(Deployment.id) + _private_fields = ['deployment_id'] - id = Field(type=basestring, default=uuid_generator) - status = Field(type=basestring, choices=STATES, - validation_func=_Validation.execution_status_transition_validation) - deployment_id = Field(type=basestring) - workflow_id = Field(type=basestring) - blueprint_id = Field(type=basestring) - created_at = Field(type=datetime, default=datetime.utcnow) - started_at = Field(type=datetime, default=None) - ended_at = Field(type=datetime, default=None) - error = Field(type=basestring, default=None) - parameters = Field() + context = Column(Dict) + created_at = Column(DateTime, nullable=False, index=True) + ended_at = Column(DateTime, index=True) + modified_nodes = Column(Dict) + node_instances = Column(Dict) + status = Column(Enum(*STATES, name='deployment_modification_status')) + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, + Deployment, + cls.deployment_id, + backreference='modifications') -class Relationship(Model): + +class Node(SQLModelBase): """ - A Model which represents a relationship + Node model representation. """ - id = Field(type=basestring, default=uuid_generator) - source_id = Field(type=basestring) - target_id = Field(type=basestring) - source_interfaces = Field(type=dict) - source_operations = Field(type=dict) - target_interfaces = Field(type=dict) - target_operations = Field(type=dict) - type = Field(type=basestring) - type_hierarchy = Field(type=list) - properties = Field(type=dict) - - -class Node(Model): + __tablename__ = 'nodes' + + # See base class for an explanation on these properties + is_id_unique = False + + name = Column(Text, index=True) + _private_fields = ['deployment_id', 'host_id'] + deployment_id = foreign_key(Deployment.id) + host_id = foreign_key('nodes.id', nullable=True) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + deploy_number_of_instances = Column(Integer, nullable=False) + # TODO: This probably should be a foreign key, but there's no guarantee + # in the code, currently, that the host will be created beforehand + max_number_of_instances = Column(Integer, nullable=False) + min_number_of_instances = Column(Integer, nullable=False) + number_of_instances = Column(Integer, nullable=False) + planned_number_of_instances = Column(Integer, nullable=False) + plugins = Column(Dict) + plugins_to_install = Column(Dict) + properties = Column(Dict) + operations = Column(Dict) + type = Column(Text, nullable=False, index=True) + type_hierarchy = Column(List) + + @declared_attr + def host(cls): + return relationship_to_self(cls, cls.host_id, cls.id) + + +class Relationship(SQLModelBase): """ - A Model which represents a node + Relationship model representation. """ - id = Field(type=basestring, default=uuid_generator) - blueprint_id = Field(type=basestring) - type = Field(type=basestring) - type_hierarchy = Field() - number_of_instances = Field(type=int) - planned_number_of_instances = Field(type=int) - deploy_number_of_instances = Field(type=int) - host_id = Field(type=basestring, default=None) - properties = Field(type=dict) - operations = Field(type=dict) - plugins = Field(type=list, default=()) - relationships = IterPointerField(type=Relationship) - plugins_to_install = Field(type=list, default=()) - min_number_of_instances = Field(type=int) - max_number_of_instances = Field(type=int) - - def relationships_by_target(self, target_id): - """ - Retreives all of the relationship by target. - :param target_id: the node id of the target of the relationship - :yields: a relationship which target and node with the specified target_id - """ - for relationship in self.relationships: - if relationship.target_id == target_id: - yield relationship - # todo: maybe add here Exception if isn't exists (didn't yield one's) + __tablename__ = 'relationships' + _private_fields = ['source_node_id', 'target_node_id'] -class RelationshipInstance(Model): - """ - A Model which represents a relationship instance - """ - id = Field(type=basestring, default=uuid_generator) - target_id = Field(type=basestring) - target_name = Field(type=basestring) - source_id = Field(type=basestring) - source_name = Field(type=basestring) - type = Field(type=basestring) - relationship = PointerField(type=Relationship) + source_node_id = foreign_key(Node.id) + target_node_id = foreign_key(Node.id) + + @declared_attr + def source_node(cls): + return one_to_many_relationship(cls, + Node, + cls.source_node_id, + 'outbound_relationships') + + @declared_attr + def target_node(cls): + return one_to_many_relationship(cls, + Node, + cls.target_node_id, + 'inbound_relationships') + source_interfaces = Column(Dict) + source_operations = Column(Dict, nullable=False) + target_interfaces = Column(Dict) + target_operations = Column(Dict, nullable=False) + type = Column(String, nullable=False) + type_hierarchy = Column(List) + properties = Column(Dict) -class NodeInstance(Model): + +class NodeInstance(SQLModelBase): """ - A Model which represents a node instance + Node instance model representation. """ - # todo: add statuses - UNINITIALIZED = 'uninitialized' - INITIALIZING = 'initializing' - CREATING = 'creating' - CONFIGURING = 'configuring' - STARTING = 'starting' - DELETED = 'deleted' - STOPPING = 'stopping' - DELETING = 'deleting' - STATES = ( - UNINITIALIZED, - INITIALIZING, - CREATING, - CONFIGURING, - STARTING, - DELETED, - STOPPING, - DELETING - ) + __tablename__ = 'node_instances' - id = Field(type=basestring, default=uuid_generator) - deployment_id = Field(type=basestring) - runtime_properties = Field(type=dict) - state = Field(type=basestring, choices=STATES, default=UNINITIALIZED) - version = Field(type=(basestring, NoneType)) - relationship_instances = IterPointerField(type=RelationshipInstance) - node = PointerField(type=Node) - host_id = Field(type=basestring, default=None) - scaling_groups = Field(default=()) - - def relationships_by_target(self, target_id): - """ - Retreives all of the relationship by target. - :param target_id: the instance id of the target of the relationship - :yields: a relationship instance which target and node with the specified target_id - """ - for relationship_instance in self.relationship_instances: - if relationship_instance.target_id == target_id: - yield relationship_instance - # todo: maybe add here Exception if isn't exists (didn't yield one's) + node_id = foreign_key(Node.id) + deployment_id = foreign_key(Deployment.id) + host_id = foreign_key('node_instances.id', nullable=True) + + _private_fields = ['node_id', 'host_id'] + + name = Column(Text, index=True) + runtime_properties = Column(Dict) + scaling_groups = Column(Dict) + state = Column(Text, nullable=False) + version = Column(Integer, default=1) + + @declared_attr + def deployment(cls): + return one_to_many_relationship(cls, Deployment, cls.deployment_id) + + @declared_attr + def node(cls): + return one_to_many_relationship(cls, Node, cls.node_id) + @declared_attr + def host(cls): + return relationship_to_self(cls, cls.host_id, cls.id) -class DeploymentModification(Model): + +class RelationshipInstance(SQLModelBase): """ - A Model which represents a deployment modification + Relationship instance model representation. """ - STARTED = 'started' - FINISHED = 'finished' - ROLLEDBACK = 'rolledback' - END_STATES = [FINISHED, ROLLEDBACK] + __tablename__ = 'relationship_instances' + + relationship_id = foreign_key(Relationship.id) + source_node_instance_id = foreign_key(NodeInstance.id) + target_node_instance_id = foreign_key(NodeInstance.id) + + _private_fields = ['relationship_storage_id', + 'source_node_instance_id', + 'target_node_instance_id'] - id = Field(type=basestring, default=uuid_generator) - deployment_id = Field(type=basestring) - modified_nodes = Field(type=(dict, NoneType)) - added_and_related = IterPointerField(type=NodeInstance) - removed_and_related = IterPointerField(type=NodeInstance) - extended_and_related = IterPointerField(type=NodeInstance) - reduced_and_related = IterPointerField(type=NodeInstance) - # before_modification = IterPointerField(type=NodeInstance) - status = Field(type=basestring, choices=(STARTED, FINISHED, ROLLEDBACK)) - created_at = Field(type=datetime) - ended_at = Field(type=(datetime, NoneType)) - context = Field() - - -class ProviderContext(Model): + @declared_attr + def source_node_instance(cls): + return one_to_many_relationship(cls, + NodeInstance, + cls.source_node_instance_id, + 'outbound_relationship_instances') + + @declared_attr + def target_node_instance(cls): + return one_to_many_relationship(cls, + NodeInstance, + cls.target_node_instance_id, + 'inbound_relationship_instances') + + @declared_attr + def relationship(cls): + return one_to_many_relationship(cls, Relationship, cls.relationship_id) + + +class ProviderContext(SQLModelBase): """ - A Model which represents a provider context + Provider context model representation. """ - id = Field(type=basestring, default=uuid_generator) - context = Field(type=dict) - name = Field(type=basestring) + __tablename__ = 'provider_context' + + name = Column(Text, nullable=False) + context = Column(Dict, nullable=False) -class Plugin(Model): +class Plugin(SQLModelBase): """ - A Model which represents a plugin + Plugin model representation. """ - id = Field(type=basestring, default=uuid_generator) - package_name = Field(type=basestring) - archive_name = Field(type=basestring) - package_source = Field(type=dict) - package_version = Field(type=basestring) - supported_platform = Field(type=basestring) - distribution = Field(type=basestring) - distribution_version = Field(type=basestring) - distribution_release = Field(type=basestring) - wheels = Field() - excluded_wheels = Field() - supported_py_versions = Field(type=list) - uploaded_at = Field(type=datetime) - - -class Task(Model): + __tablename__ = 'plugins' + + archive_name = Column(Text, nullable=False, index=True) + distribution = Column(Text) + distribution_release = Column(Text) + distribution_version = Column(Text) + excluded_wheels = Column(Dict) + package_name = Column(Text, nullable=False, index=True) + package_source = Column(Text) + package_version = Column(Text) + supported_platform = Column(Dict) + supported_py_versions = Column(Dict) + uploaded_at = Column(DateTime, nullable=False, index=True) + wheels = Column(Dict, nullable=False) + + +class Task(SQLModelBase): """ A Model which represents an task """ - class _Validation(object): + __tablename__ = 'task' + node_instance_id = foreign_key(NodeInstance.id, nullable=True) + relationship_instance_id = foreign_key(RelationshipInstance.id, nullable=True) + execution_id = foreign_key(Execution.id, nullable=True) + + _private_fields = ['node_instance_id', + 'relationship_instance_id', + 'execution_id'] - @staticmethod - def validate_max_attempts(_, value, *args): - """Validates that max attempts is either -1 or a positive number""" - if value < 1 and value != Task.INFINITE_RETRIES: - raise ValueError('Max attempts can be either -1 (infinite) or any positive number. ' - 'Got {value}'.format(value=value)) + @declared_attr + def node_instance(cls): + return one_to_many_relationship(cls, NodeInstance, cls.node_instance_id) + + @declared_attr + def relationship_instance(cls): + return one_to_many_relationship(cls, + RelationshipInstance, + cls.relationship_instance_id) PENDING = 'pending' RETRYING = 'retrying' @@ -422,23 +522,51 @@ class Task(Model): SUCCESS, FAILED, ) + WAIT_STATES = [PENDING, RETRYING] END_STATES = [SUCCESS, FAILED] + + @orm.validates('max_attempts') + def validate_max_attempts(self, _, value): # pylint: disable=no-self-use + """Validates that max attempts is either -1 or a positive number""" + if value < 1 and value != Task.INFINITE_RETRIES: + raise ValueError('Max attempts can be either -1 (infinite) or any positive number. ' + 'Got {value}'.format(value=value)) + return value + INFINITE_RETRIES = -1 - id = Field(type=basestring, default=uuid_generator) - status = Field(type=basestring, choices=STATES, default=PENDING) - execution_id = Field(type=basestring) - due_at = Field(type=datetime, default=datetime.utcnow) - started_at = Field(type=datetime, default=None) - ended_at = Field(type=datetime, default=None) - max_attempts = Field(type=int, default=1, validation_func=_Validation.validate_max_attempts) - retry_count = Field(type=int, default=0) - retry_interval = Field(type=(int, float), default=0) - ignore_failure = Field(type=bool, default=False) + status = Column(Enum(*STATES), name='status', default=PENDING) + + due_at = Column(DateTime, default=datetime.utcnow) + started_at = Column(DateTime, default=None) + ended_at = Column(DateTime, default=None) + max_attempts = Column(Integer, default=1) + retry_count = Column(Integer, default=0) + retry_interval = Column(Float, default=0) + ignore_failure = Column(Boolean, default=False) # Operation specific fields - name = Field(type=basestring) - operation_mapping = Field(type=basestring) - actor = Field() - inputs = Field(type=dict, default=lambda: {}) + name = Column(String) + operation_mapping = Column(String) + inputs = Column(Dict) + + @declared_attr + def execution(cls): + return one_to_many_relationship(cls, Execution, cls.execution_id) + + @property + def actor(self): + """ + Return the actor of the task + :return: + """ + return self.node_instance or self.relationship_instance + + @classmethod + def as_node_instance(cls, instance_id, **kwargs): + return cls(node_instance_id=instance_id, **kwargs) + + @classmethod + def as_relationship_instance(cls, instance_id, **kwargs): + return cls(relationship_instance_id=instance_id, **kwargs)
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/storage/sql_mapi.py ---------------------------------------------------------------------- diff --git a/aria/storage/sql_mapi.py b/aria/storage/sql_mapi.py new file mode 100644 index 0000000..cde40c2 --- /dev/null +++ b/aria/storage/sql_mapi.py @@ -0,0 +1,382 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +SQLAlchemy based MAPI +""" + +from sqlalchemy.exc import SQLAlchemyError + +from aria.utils.collections import OrderedDict +from aria.storage import ( + api, + exceptions +) + + +class SQLAlchemyModelAPI(api.ModelAPI): + """ + SQL based MAPI. + """ + + def __init__(self, + engine, + session, + **kwargs): + super(SQLAlchemyModelAPI, self).__init__(**kwargs) + self._engine = engine + self._session = session + + def get(self, entry_id, include=None, **kwargs): + """Return a single result based on the model class and element ID + """ + query = self._get_query(include, {'id': entry_id}) + result = query.first() + + if not result: + raise exceptions.StorageError( + 'Requested {0} with ID `{1}` was not found' + .format(self.model_cls.__name__, entry_id) + ) + return result + + def get_by_name(self, entry_name, include=None, **kwargs): + assert hasattr(self.model_cls, 'name') + result = self.list(include=include, filters={'name': entry_name}) + if not result: + raise exceptions.StorageError( + 'Requested {0} with NAME `{1}` was not found' + .format(self.model_cls.__name__, entry_name) + ) + elif len(result) > 1: + raise exceptions.StorageError( + 'Requested {0} with NAME `{1}` returned more than 1 value' + .format(self.model_cls.__name__, entry_name) + ) + else: + return result[0] + + def list(self, + include=None, + filters=None, + pagination=None, + sort=None, + **kwargs): + query = self._get_query(include, filters, sort) + + results, total, size, offset = self._paginate(query, pagination) + + return ListResult( + items=results, + metadata=dict(total=total, + size=size, + offset=offset) + ) + + def iter(self, + include=None, + filters=None, + sort=None, + **kwargs): + """Return a (possibly empty) list of `model_class` results + """ + return iter(self._get_query(include, filters, sort)) + + def put(self, entry, **kwargs): + """Create a `model_class` instance from a serializable `model` object + + :param entry: A dict with relevant kwargs, or an instance of a class + that has a `to_dict` method, and whose attributes match the columns + of `model_class` (might also my just an instance of `model_class`) + :return: An instance of `model_class` + """ + self._session.add(entry) + self._safe_commit() + return entry + + def delete(self, entry, **kwargs): + """Delete a single result based on the model class and element ID + """ + self._load_relationships(entry) + self._session.delete(entry) + self._safe_commit() + return entry + + def update(self, entry, **kwargs): + """Add `instance` to the DB session, and attempt to commit + + :return: The updated instance + """ + return self.put(entry) + + def refresh(self, entry): + """Reload the instance with fresh information from the DB + + :param entry: Instance to be re-loaded from the DB + :return: The refreshed instance + """ + self._session.refresh(entry) + self._load_relationships(entry) + return entry + + def _destroy_connection(self): + pass + + def _establish_connection(self): + pass + + def create(self, checkfirst=True, **kwargs): + self.model_cls.__table__.create(self._engine, checkfirst=checkfirst) + + def drop(self): + """ + Drop the table from the storage. + :return: + """ + self.model_cls.__table__.drop(self._engine) + + def _safe_commit(self): + """Try to commit changes in the session. Roll back if exception raised + Excepts SQLAlchemy errors and rollbacks if they're caught + """ + try: + self._session.commit() + except (SQLAlchemyError, ValueError) as e: + self._session.rollback() + raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) + + def _get_base_query(self, include, joins): + """Create the initial query from the model class and included columns + + :param include: A (possibly empty) list of columns to include in + the query + :return: An SQLAlchemy AppenderQuery object + """ + # If only some columns are included, query through the session object + if include: + # Make sure that attributes come before association proxies + include.sort(key=lambda x: x.is_clause_element) + query = self._session.query(*include) + else: + # If all columns should be returned, query directly from the model + query = self._session.query(self.model_cls) + + if not self._skip_joining(joins, include): + for join_table in joins: + query = query.join(join_table) + + return query + + @staticmethod + def _get_joins(model_class, columns): + """Get a list of all the tables on which we need to join + + :param columns: A set of all columns involved in the query + """ + joins = [] # Using a list instead of a set because order is important + for column_name in columns: + column = getattr(model_class, column_name) + while not column.is_attribute: + column = column.remote_attr + if column.is_attribute: + join_class = column.class_ + else: + join_class = column.local_attr.class_ + + # Don't add the same class more than once + if join_class not in joins: + joins.append(join_class) + return joins + + @staticmethod + def _skip_joining(joins, include): + """Dealing with an edge case where the only included column comes from + an other table. In this case, we mustn't join on the same table again + + :param joins: A list of tables on which we're trying to join + :param include: The list of + :return: True if we need to skip joining + """ + if not joins: + return True + join_table_names = [t.__tablename__ for t in joins] + + if len(include) != 1: + return False + + column = include[0] + if column.is_clause_element: + table_name = column.element.table.name + else: + table_name = column.class_.__tablename__ + return table_name in join_table_names + + @staticmethod + def _sort_query(query, sort=None): + """Add sorting clauses to the query + + :param query: Base SQL query + :param sort: An optional dictionary where keys are column names to + sort by, and values are the order (asc/desc) + :return: An SQLAlchemy AppenderQuery object + """ + if sort: + for column, order in sort.items(): + if order == 'desc': + column = column.desc() + query = query.order_by(column) + return query + + def _filter_query(self, query, filters): + """Add filter clauses to the query + + :param query: Base SQL query + :param filters: An optional dictionary where keys are column names to + filter by, and values are values applicable for those columns (or lists + of such values) + :return: An SQLAlchemy AppenderQuery object + """ + return self._add_value_filter(query, filters) + + @staticmethod + def _add_value_filter(query, filters): + for column, value in filters.items(): + if isinstance(value, (list, tuple)): + query = query.filter(column.in_(value)) + else: + query = query.filter(column == value) + + return query + + def _get_query(self, + include=None, + filters=None, + sort=None): + """Get an SQL query object based on the params passed + + :param model_class: SQL DB table class + :param include: An optional list of columns to include in the query + :param filters: An optional dictionary where keys are column names to + filter by, and values are values applicable for those columns (or lists + of such values) + :param sort: An optional dictionary where keys are column names to + sort by, and values are the order (asc/desc) + :return: A sorted and filtered query with only the relevant + columns + """ + include, filters, sort, joins = self._get_joins_and_converted_columns( + include, filters, sort + ) + + query = self._get_base_query(include, joins) + query = self._filter_query(query, filters) + query = self._sort_query(query, sort) + return query + + def _get_joins_and_converted_columns(self, + include, + filters, + sort): + """Get a list of tables on which we need to join and the converted + `include`, `filters` and `sort` arguments (converted to actual SQLA + column/label objects instead of column names) + """ + include = include or [] + filters = filters or dict() + sort = sort or OrderedDict() + + all_columns = set(include) | set(filters.keys()) | set(sort.keys()) + joins = self._get_joins(self.model_cls, all_columns) + + include, filters, sort = self._get_columns_from_field_names( + include, filters, sort + ) + return include, filters, sort, joins + + def _get_columns_from_field_names(self, + include, + filters, + sort): + """Go over the optional parameters (include, filters, sort), and + replace column names with actual SQLA column objects + """ + include = [self._get_column(c) for c in include] + filters = dict((self._get_column(c), filters[c]) for c in filters) + sort = OrderedDict((self._get_column(c), sort[c]) for c in sort) + + return include, filters, sort + + def _get_column(self, column_name): + """Return the column on which an action (filtering, sorting, etc.) + would need to be performed. Can be either an attribute of the class, + or an association proxy linked to a relationship the class has + """ + column = getattr(self.model_cls, column_name) + if column.is_attribute: + return column + else: + # We need to get to the underlying attribute, so we move on to the + # next remote_attr until we reach one + while not column.remote_attr.is_attribute: + column = column.remote_attr + # Put a label on the remote attribute with the name of the column + return column.remote_attr.label(column_name) + + @staticmethod + def _paginate(query, pagination): + """Paginate the query by size and offset + + :param query: Current SQLAlchemy query object + :param pagination: An optional dict with size and offset keys + :return: A tuple with four elements: + - res ults: `size` items starting from `offset` + - the total count of items + - `size` [default: 0] + - `offset` [default: 0] + """ + if pagination: + size = pagination.get('size', 0) + offset = pagination.get('offset', 0) + total = query.order_by(None).count() # Fastest way to count + results = query.limit(size).offset(offset).all() + return results, total, size, offset + else: + results = query.all() + return results, len(results), 0, 0 + + @staticmethod + def _load_relationships(instance): + """A helper method used to overcome a problem where the relationships + that rely on joins aren't being loaded automatically + """ + for rel in instance.__mapper__.relationships: + getattr(instance, rel.key) + + +class ListResult(object): + """ + a ListResult contains results about the requested items. + """ + def __init__(self, items, metadata): + self.items = items + self.metadata = metadata + + def __len__(self): + return len(self.items) + + def __iter__(self): + return iter(self.items) + + def __getitem__(self, item): + return self.items[item] http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/storage/structures.py ---------------------------------------------------------------------- diff --git a/aria/storage/structures.py b/aria/storage/structures.py index b02366e..8dbd2a9 100644 --- a/aria/storage/structures.py +++ b/aria/storage/structures.py @@ -27,281 +27,218 @@ classes: * Model - abstract model implementation. """ import json -from itertools import count -from uuid import uuid4 - -from .exceptions import StorageError -from ..logger import LoggerMixin -from ..utils.validation import ValidatorMixin - -__all__ = ( - 'uuid_generator', - 'Field', - 'IterField', - 'PointerField', - 'IterPointerField', - 'Model', - 'Storage', + +from sqlalchemy.ext.mutable import Mutable +from sqlalchemy.orm import relationship, backref +from sqlalchemy.ext.declarative import declarative_base +# pylint: disable=unused-import +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy import ( + schema, + VARCHAR, + ARRAY, + Column, + Integer, + Text, + DateTime, + Boolean, + Enum, + String, + PickleType, + Float, + TypeDecorator, + ForeignKey, + orm, ) +from aria.storage import exceptions + +Model = declarative_base() -def uuid_generator(): - """ - wrapper function which generates ids - """ - return str(uuid4()) +def foreign_key(foreign_key_column, nullable=False): + """Return a ForeignKey object with the relevant -class Field(ValidatorMixin): + :param foreign_key_column: Unique id column in the parent table + :param nullable: Should the column be allowed to remain empty """ - A single field implementation + return Column( + ForeignKey(foreign_key_column, ondelete='CASCADE'), + nullable=nullable + ) + + +def one_to_many_relationship(child_class, + parent_class, + foreign_key_column, + backreference=None): + """Return a one-to-many SQL relationship object + Meant to be used from inside the *child* object + + :param parent_class: Class of the parent table + :param child_class: Class of the child table + :param foreign_key_column: The column of the foreign key + :param backreference: The name to give to the reference to the child """ - NO_DEFAULT = 'NO_DEFAULT' - - try: - # python 3 syntax - _next_id = count().__next__ - except AttributeError: - # python 2 syntax - _next_id = count().next - _ATTRIBUTE_NAME = '_cache_{0}'.format - - def __init__( - self, - type=None, - choices=(), - validation_func=None, - default=NO_DEFAULT, - **kwargs): - """ - Simple field manager. + backreference = backreference or child_class.__tablename__ + return relationship( + parent_class, + primaryjoin=lambda: parent_class.id == foreign_key_column, + # The following line make sure that when the *parent* is + # deleted, all its connected children are deleted as well + backref=backref(backreference, cascade='all') + ) - :param type: possible type of the field. - :param choices: a set of possible field values. - :param default: default field value. - :param kwargs: kwargs to be passed to next in line classes. - """ - self.type = type - self.choices = choices - self.default = default - self.validation_func = validation_func - super(Field, self).__init__(**kwargs) - - def __get__(self, instance, owner): - if instance is None: - return self - field_name = self._field_name(instance) - try: - return getattr(instance, self._ATTRIBUTE_NAME(field_name)) - except AttributeError as exc: - if self.default == self.NO_DEFAULT: - raise AttributeError( - str(exc).replace(self._ATTRIBUTE_NAME(field_name), field_name)) - - default_value = self.default() if callable(self.default) else self.default - setattr(instance, self._ATTRIBUTE_NAME(field_name), default_value) - return default_value - - def __set__(self, instance, value): - field_name = self._field_name(instance) - self.validate_value(field_name, value, instance) - setattr(instance, self._ATTRIBUTE_NAME(field_name), value) - - def validate_value(self, name, value, instance): - """ - Validates the value of the field. - :param name: the name of the field. - :param value: the value of the field. - :param instance: the instance containing the field. - """ - if self.default != self.NO_DEFAULT and value == self.default: - return - if self.type: - self.validate_instance(name, value, self.type) - if self.choices: - self.validate_in_choice(name, value, self.choices) - if self.validation_func: - self.validation_func(name, value, instance) - - def _field_name(self, instance): - """ - retrieves the field name from the instance. - - :param Field instance: the instance which holds the field. - :return: name of the field - :rtype: basestring - """ - for name, member in vars(instance.__class__).iteritems(): - if member is self: - return name +def relationship_to_self(self_cls, parent_key, self_key): + return relationship( + self_cls, + foreign_keys=parent_key, + remote_side=self_key + ) -class IterField(Field): +class _MutableType(TypeDecorator): """ - Represents an iterable field. + Dict representation of type. """ - def __init__(self, **kwargs): - """ - Simple iterable field manager. - This field type don't have choices option. - - :param kwargs: kwargs to be passed to next in line classes. - """ - super(IterField, self).__init__(choices=(), **kwargs) + @property + def python_type(self): + raise NotImplementedError - def validate_value(self, name, values, *args): - """ - Validates the value of each iterable value. + def process_literal_param(self, value, dialect): + pass - :param name: the name of the field. - :param values: the values of the field. - """ - for value in values: - self.validate_instance(name, value, self.type) + impl = VARCHAR + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + return value -class PointerField(Field): - """ - A single pointer field implementation. - - Any PointerField points via id to another document. - """ + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value - def __init__(self, type, **kwargs): - assert issubclass(type, Model) - super(PointerField, self).__init__(type=type, **kwargs) +class _DictType(_MutableType): + @property + def python_type(self): + return dict -class IterPointerField(IterField, PointerField): - """ - An iterable pointers field. - Any IterPointerField points via id to other documents. - """ - pass +class _ListType(_MutableType): + @property + def python_type(self): + return list -class Model(object): +class _MutableDict(Mutable, dict): """ - Base class for all of the storage models. + Enables tracking for dict values. """ - id = None + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." - def __init__(self, **fields): - """ - Abstract class for any model in the storage. - The Initializer creates attributes according to the (keyword arguments) that given - Each value is validated according to the Field. - Each model has to have and ID Field. + if not isinstance(value, _MutableDict): + if isinstance(value, dict): + return _MutableDict(value) - :param fields: each item is validated and transformed into instance attributes. - """ - self._assert_model_have_id_field(**fields) - missing_fields, unexpected_fields = self._setup_fields(fields) + # this call will raise ValueError + try: + return Mutable.coerce(key, value) + except ValueError as e: + raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) + else: + return value - if missing_fields: - raise StorageError( - 'Model {name} got missing keyword arguments: {fields}'.format( - name=self.__class__.__name__, fields=missing_fields)) + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." - if unexpected_fields: - raise StorageError( - 'Model {name} got unexpected keyword arguments: {fields}'.format( - name=self.__class__.__name__, fields=unexpected_fields)) + dict.__setitem__(self, key, value) + self.changed() - def __repr__(self): - return '{name}(fields={0})'.format(sorted(self.fields), name=self.__class__.__name__) + def __delitem__(self, key): + "Detect dictionary del events and emit change events." - def __eq__(self, other): - return ( - isinstance(other, self.__class__) and - self.fields_dict == other.fields_dict) + dict.__delitem__(self, key) + self.changed() - @property - def fields(self): - """ - Iterates over the fields of the model. - :yields: the class's field name - """ - for name, field in vars(self.__class__).items(): - if isinstance(field, Field): - yield name - @property - def fields_dict(self): - """ - Transforms the instance attributes into a dict. +class _MutableList(Mutable, list): - :return: all fields in dict format. - :rtype dict - """ - return dict((name, getattr(self, name)) for name in self.fields) + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." - @property - def json(self): - """ - Transform the dict of attributes into json - :return: - """ - return json.dumps(self.fields_dict) + if not isinstance(value, _MutableList): + if isinstance(value, list): + return _MutableList(value) - @classmethod - def _assert_model_have_id_field(cls, **fields_initializer_values): - if not getattr(cls, 'id', None): - raise StorageError('Model {cls.__name__} must have id field'.format(cls=cls)) - - if cls.id.default == cls.id.NO_DEFAULT and 'id' not in fields_initializer_values: - raise StorageError( - 'Model {cls.__name__} is missing required ' - 'keyword-only argument: "id"'.format(cls=cls)) - - def _setup_fields(self, input_fields): - missing = [] - for field_name in self.fields: + # this call will raise ValueError try: - field_obj = input_fields.pop(field_name) - setattr(self, field_name, field_obj) - except KeyError: - field = getattr(self.__class__, field_name) - if field.default == field.NO_DEFAULT: - missing.append(field_name) + return Mutable.coerce(key, value) + except ValueError as e: + raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) + else: + return value + + def __setitem__(self, key, value): + list.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + list.__delitem__(self, key) + - unexpected_fields = input_fields.keys() - return missing, unexpected_fields +Dict = _MutableDict.as_mutable(_DictType) +List = _MutableList.as_mutable(_ListType) -class Storage(LoggerMixin): +class SQLModelBase(Model): """ - Represents the storage + Abstract base class for all SQL models that allows [de]serialization """ - def __init__(self, driver, items=(), **kwargs): - super(Storage, self).__init__(**kwargs) - self.driver = driver - self.registered = {} - for item in items: - self.register(item) - self.logger.debug('{name} object is ready: {0!r}'.format( - self, name=self.__class__.__name__)) + # SQLAlchemy syntax + __abstract__ = True - def __repr__(self): - return '{name}(driver={self.driver})'.format( - name=self.__class__.__name__, self=self) + # This would be overridden once the models are created. Created for pylint. + __table__ = None + + _private_fields = [] + + id = Column(Integer, primary_key=True, autoincrement=True) - def __getattr__(self, item): - try: - return self.registered[item] - except KeyError: - return super(Storage, self).__getattribute__(item) + def to_dict(self, suppress_error=False): + """Return a dict representation of the model - def setup(self): + :param suppress_error: If set to True, sets `None` to attributes that + it's unable to retrieve (e.g., if a relationship wasn't established + yet, and so it's impossible to access a property through it) """ - Setup and create all storage items + if suppress_error: + res = dict() + for field in self.fields(): + try: + field_value = getattr(self, field) + except AttributeError: + field_value = None + res[field] = field_value + else: + # Can't simply call here `self.to_response()` because inheriting + # class might override it, but we always need the same code here + res = dict((f, getattr(self, f)) for f in self.fields()) + return res + + @classmethod + def fields(cls): + """Return the list of field names for this table + + Mostly for backwards compatibility in the code (that uses `fields`) """ - for name, api in self.registered.iteritems(): - try: - api.create() - self.logger.debug( - 'setup {name} in storage {self!r}'.format(name=name, self=self)) - except StorageError: - pass + return set(cls.__table__.columns.keys()) - set(cls._private_fields) + + def __repr__(self): + return '<{0} id=`{1}`>'.format(self.__class__.__name__, self.id) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/utils/application.py ---------------------------------------------------------------------- diff --git a/aria/utils/application.py b/aria/utils/application.py index b1a7fcc..113e054 100644 --- a/aria/utils/application.py +++ b/aria/utils/application.py @@ -117,7 +117,7 @@ class StorageManager(LoggerMixin): updated_at=now, main_file_name=main_file_name, ) - self.model_storage.blueprint.store(blueprint) + self.model_storage.blueprint.put(blueprint) self.logger.debug('created blueprint model storage entry') def create_nodes_storage(self): @@ -138,7 +138,7 @@ class StorageManager(LoggerMixin): scalable = node_copy.pop('capabilities')['scalable']['properties'] for index, relationship in enumerate(node_copy['relationships']): relationship = self.model_storage.relationship.model_cls(**relationship) - self.model_storage.relationship.store(relationship) + self.model_storage.relationship.put(relationship) node_copy['relationships'][index] = relationship node_copy = self.model_storage.node.model_cls( @@ -149,7 +149,7 @@ class StorageManager(LoggerMixin): max_number_of_instances=scalable['max_instances'], number_of_instances=scalable['current_instances'], **node_copy) - self.model_storage.node.store(node_copy) + self.model_storage.node.put(node_copy) def create_deployment_storage(self): """ @@ -190,7 +190,7 @@ class StorageManager(LoggerMixin): created_at=now, updated_at=now ) - self.model_storage.deployment.store(deployment) + self.model_storage.deployment.put(deployment) self.logger.debug('created deployment model storage entry') def create_node_instances_storage(self): @@ -213,7 +213,7 @@ class StorageManager(LoggerMixin): type=relationship_instance['type'], target_id=relationship_instance['target_id']) relationship_instances.append(relationship_instance_model) - self.model_storage.relationship_instance.store(relationship_instance_model) + self.model_storage.relationship_instance.put(relationship_instance_model) node_instance_model = self.model_storage.node_instance.model_cls( node=node_model, @@ -224,7 +224,7 @@ class StorageManager(LoggerMixin): version='1.0', relationship_instances=relationship_instances) - self.model_storage.node_instance.store(node_instance_model) + self.model_storage.node_instance.put(node_instance_model) self.logger.debug('created node-instances model storage entries') def create_plugin_storage(self, plugin_id, source): @@ -258,7 +258,7 @@ class StorageManager(LoggerMixin): supported_py_versions=plugin.get('supported_python_versions'), uploaded_at=now ) - self.model_storage.plugin.store(plugin) + self.model_storage.plugin.put(plugin) self.logger.debug('created plugin model storage entry') http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/requirements.txt ---------------------------------------------------------------------- diff --git a/requirements.txt b/requirements.txt index e6d5393..7e87c67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ Jinja2==2.8 shortuuid==0.4.3 CacheControl[filecache]==0.11.6 clint==0.5.1 +SQLAlchemy==1.1.4 \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/mock/context.py ---------------------------------------------------------------------- diff --git a/tests/mock/context.py b/tests/mock/context.py index 5fda07e..1904140 100644 --- a/tests/mock/context.py +++ b/tests/mock/context.py @@ -15,23 +15,53 @@ from aria import application_model_storage from aria.orchestrator import context +from aria.storage.sql_mapi import SQLAlchemyModelAPI from . import models -from ..storage import InMemoryModelDriver -def simple(**kwargs): - storage = application_model_storage(InMemoryModelDriver()) - storage.setup() - storage.blueprint.store(models.get_blueprint()) - storage.deployment.store(models.get_deployment()) +def simple(api_kwargs, **kwargs): + model_storage = application_model_storage(SQLAlchemyModelAPI, api_kwargs=api_kwargs) + blueprint = models.get_blueprint() + model_storage.blueprint.put(blueprint) + deployment = models.get_deployment(blueprint) + model_storage.deployment.put(deployment) + + ################################################################################# + # Creating a simple deployment with node -> node as a graph + + dependency_node = models.get_dependency_node(deployment) + model_storage.node.put(dependency_node) + storage_dependency_node = model_storage.node.get(dependency_node.id) + + dependency_node_instance = models.get_dependency_node_instance(storage_dependency_node) + model_storage.node_instance.put(dependency_node_instance) + storage_dependency_node_instance = model_storage.node_instance.get(dependency_node_instance.id) + + dependent_node = models.get_dependent_node(deployment) + model_storage.node.put(dependent_node) + storage_dependent_node = model_storage.node.get(dependent_node.id) + + dependent_node_instance = models.get_dependent_node_instance(storage_dependent_node) + model_storage.node_instance.put(dependent_node_instance) + storage_dependent_node_instance = model_storage.node_instance.get(dependent_node_instance.id) + + relationship = models.get_relationship(storage_dependent_node, storage_dependency_node) + model_storage.relationship.put(relationship) + storage_relationship = model_storage.relationship.get(relationship.id) + relationship_instance = models.get_relationship_instance( + relationship=storage_relationship, + target_instance=storage_dependency_node_instance, + source_instance=storage_dependent_node_instance + ) + model_storage.relationship_instance.put(relationship_instance) + final_kwargs = dict( name='simple_context', - model_storage=storage, + model_storage=model_storage, resource_storage=None, - deployment_id=models.DEPLOYMENT_ID, - workflow_id=models.WORKFLOW_ID, - execution_id=models.EXECUTION_ID, + deployment_id=deployment.id, + workflow_name=models.WORKFLOW_NAME, task_max_attempts=models.TASK_MAX_ATTEMPTS, task_retry_interval=models.TASK_RETRY_INTERVAL ) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/mock/models.py ---------------------------------------------------------------------- diff --git a/tests/mock/models.py b/tests/mock/models.py index 327b0b9..e2e3d2f 100644 --- a/tests/mock/models.py +++ b/tests/mock/models.py @@ -19,24 +19,24 @@ from aria.storage import models from . import operations -DEPLOYMENT_ID = 'test_deployment_id' -BLUEPRINT_ID = 'test_blueprint_id' -WORKFLOW_ID = 'test_workflow_id' -EXECUTION_ID = 'test_execution_id' +DEPLOYMENT_NAME = 'test_deployment_id' +BLUEPRINT_NAME = 'test_blueprint_id' +WORKFLOW_NAME = 'test_workflow_id' +EXECUTION_NAME = 'test_execution_id' TASK_RETRY_INTERVAL = 1 TASK_MAX_ATTEMPTS = 1 -DEPENDENCY_NODE_ID = 'dependency_node' -DEPENDENCY_NODE_INSTANCE_ID = 'dependency_node_instance' -DEPENDENT_NODE_ID = 'dependent_node' -DEPENDENT_NODE_INSTANCE_ID = 'dependent_node_instance' +DEPENDENCY_NODE_NAME = 'dependency_node' +DEPENDENCY_NODE_INSTANCE_NAME = 'dependency_node_instance' +DEPENDENT_NODE_NAME = 'dependent_node' +DEPENDENT_NODE_INSTANCE_NAME = 'dependent_node_instance' +RELATIONSHIP_NAME = 'relationship' +RELATIONSHIP_INSTANCE_NAME = 'relationship_instance' -def get_dependency_node(): +def get_dependency_node(deployment): return models.Node( - id=DEPENDENCY_NODE_ID, - host_id=DEPENDENCY_NODE_ID, - blueprint_id=BLUEPRINT_ID, + name=DEPENDENCY_NODE_NAME, type='test_node_type', type_hierarchy=[], number_of_instances=1, @@ -44,28 +44,28 @@ def get_dependency_node(): deploy_number_of_instances=1, properties={}, operations=dict((key, {}) for key in operations.NODE_OPERATIONS), - relationships=[], min_number_of_instances=1, max_number_of_instances=1, + deployment_id=deployment.id ) -def get_dependency_node_instance(dependency_node=None): +def get_dependency_node_instance(dependency_node): return models.NodeInstance( - id=DEPENDENCY_NODE_INSTANCE_ID, - host_id=DEPENDENCY_NODE_INSTANCE_ID, - deployment_id=DEPLOYMENT_ID, + name=DEPENDENCY_NODE_INSTANCE_NAME, runtime_properties={'ip': '1.1.1.1'}, version=None, - relationship_instances=[], - node=dependency_node or get_dependency_node() + node_id=dependency_node.id, + deployment_id=dependency_node.deployment.id, + state='', + scaling_groups={} ) def get_relationship(source=None, target=None): return models.Relationship( - source_id=source.id if source is not None else DEPENDENT_NODE_ID, - target_id=target.id if target is not None else DEPENDENCY_NODE_ID, + source_node_id=source.id, + target_node_id=target.id, source_interfaces={}, source_operations=dict((key, {}) for key in operations.RELATIONSHIP_OPERATIONS), target_interfaces={}, @@ -76,23 +76,18 @@ def get_relationship(source=None, target=None): ) -def get_relationship_instance(source_instance=None, target_instance=None, relationship=None): +def get_relationship_instance(source_instance, target_instance, relationship): return models.RelationshipInstance( - target_id=target_instance.id if target_instance else DEPENDENCY_NODE_INSTANCE_ID, - target_name='test_target_name', - source_id=source_instance.id if source_instance else DEPENDENT_NODE_INSTANCE_ID, - source_name='test_source_name', - type='some_type', - relationship=relationship or get_relationship(target_instance.node - if target_instance else None) + relationship_id=relationship.id, + target_node_instance_id=target_instance.id, + source_node_instance_id=source_instance.id, ) -def get_dependent_node(relationship=None): +def get_dependent_node(deployment): return models.Node( - id=DEPENDENT_NODE_ID, - host_id=DEPENDENT_NODE_ID, - blueprint_id=BLUEPRINT_ID, + name=DEPENDENT_NODE_NAME, + deployment_id=deployment.id, type='test_node_type', type_hierarchy=[], number_of_instances=1, @@ -100,21 +95,20 @@ def get_dependent_node(relationship=None): deploy_number_of_instances=1, properties={}, operations=dict((key, {}) for key in operations.NODE_OPERATIONS), - relationships=[relationship or get_relationship()], min_number_of_instances=1, max_number_of_instances=1, ) -def get_dependent_node_instance(relationship_instance=None, dependent_node=None): +def get_dependent_node_instance(dependent_node): return models.NodeInstance( - id=DEPENDENT_NODE_INSTANCE_ID, - host_id=DEPENDENT_NODE_INSTANCE_ID, - deployment_id=DEPLOYMENT_ID, + name=DEPENDENT_NODE_INSTANCE_NAME, runtime_properties={}, version=None, - relationship_instances=[relationship_instance or get_relationship_instance()], - node=dependent_node or get_dependency_node() + node_id=dependent_node.id, + deployment_id=dependent_node.deployment.id, + state='', + scaling_groups={} ) @@ -122,7 +116,7 @@ def get_blueprint(): now = datetime.now() return models.Blueprint( plan={}, - id=BLUEPRINT_ID, + name=BLUEPRINT_NAME, description=None, created_at=now, updated_at=now, @@ -130,25 +124,31 @@ def get_blueprint(): ) -def get_execution(): +def get_execution(deployment): return models.Execution( - id=EXECUTION_ID, + deployment_id=deployment.id, + blueprint_id=deployment.blueprint.id, status=models.Execution.STARTED, - deployment_id=DEPLOYMENT_ID, - workflow_id=WORKFLOW_ID, - blueprint_id=BLUEPRINT_ID, + workflow_name=WORKFLOW_NAME, started_at=datetime.utcnow(), parameters=None ) -def get_deployment(): +def get_deployment(blueprint): now = datetime.utcnow() return models.Deployment( - id=DEPLOYMENT_ID, - description=None, + name=DEPLOYMENT_NAME, + blueprint_id=blueprint.id, + description='', created_at=now, updated_at=now, - blueprint_id=BLUEPRINT_ID, - workflows={} + workflows={}, + inputs={}, + groups={}, + permalink='', + policy_triggers={}, + policy_types={}, + outputs={}, + scaling_groups={}, ) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/orchestrator/context/test_operation.py ---------------------------------------------------------------------- diff --git a/tests/orchestrator/context/test_operation.py b/tests/orchestrator/context/test_operation.py index 6b3e28d..b5f52a3 100644 --- a/tests/orchestrator/context/test_operation.py +++ b/tests/orchestrator/context/test_operation.py @@ -23,7 +23,7 @@ from aria.orchestrator import context from aria.orchestrator.workflows import api from aria.orchestrator.workflows.executor import thread -from tests import mock +from tests import mock, storage from . import ( op_path, op_name, @@ -34,8 +34,10 @@ global_test_holder = {} @pytest.fixture -def ctx(): - return mock.context.simple() +def ctx(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) @pytest.fixture @@ -50,14 +52,13 @@ def executor(): def test_node_operation_task_execution(ctx, executor): operation_name = 'aria.interfaces.lifecycle.create' - node = mock.models.get_dependency_node() + node = ctx.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) node.operations[operation_name] = { 'operation': op_path(my_operation, module_path=__name__) } - node_instance = mock.models.get_dependency_node_instance(node) - ctx.model.node.store(node) - ctx.model.node_instance.store(node_instance) + ctx.model.node.update(node) + node_instance = ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) inputs = {'putput': True} @@ -90,26 +91,19 @@ def test_node_operation_task_execution(ctx, executor): def test_relationship_operation_task_execution(ctx, executor): operation_name = 'aria.interfaces.relationship_lifecycle.postconfigure' - - dependency_node = mock.models.get_dependency_node() - dependency_node_instance = mock.models.get_dependency_node_instance() - relationship = mock.models.get_relationship(target=dependency_node) + relationship = ctx.model.relationship.list()[0] relationship.source_operations[operation_name] = { 'operation': op_path(my_operation, module_path=__name__) } - relationship_instance = mock.models.get_relationship_instance( - target_instance=dependency_node_instance, - relationship=relationship) - dependent_node = mock.models.get_dependent_node() - dependent_node_instance = mock.models.get_dependent_node_instance( - relationship_instance=relationship_instance, - dependent_node=dependency_node) - ctx.model.node.store(dependency_node) - ctx.model.node_instance.store(dependency_node_instance) - ctx.model.relationship.store(relationship) - ctx.model.relationship_instance.store(relationship_instance) - ctx.model.node.store(dependent_node) - ctx.model.node_instance.store(dependent_node_instance) + ctx.model.relationship.update(relationship) + relationship_instance = ctx.model.relationship_instance.list()[0] + + dependency_node = ctx.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) + dependency_node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + dependent_node = ctx.model.node.get_by_name(mock.models.DEPENDENT_NODE_NAME) + dependent_node_instance = \ + ctx.model.node_instance.get_by_name(mock.models.DEPENDENT_NODE_INSTANCE_NAME) inputs = {'putput': True} @@ -146,11 +140,49 @@ def test_relationship_operation_task_execution(ctx, executor): assert operation_context.source_node_instance == dependent_node_instance +def test_invalid_task_operation_id(ctx, executor): + """ + Checks that the right id is used. The task created with id == 1, thus running the task on + node_instance with id == 2. will check that indeed the node_instance uses the correct id. + :param ctx: + :param executor: + :return: + """ + operation_name = 'aria.interfaces.lifecycle.create' + other_node_instance, node_instance = ctx.model.node_instance.list() + assert other_node_instance.id == 1 + assert node_instance.id == 2 + + node = node_instance.node + node.operations[operation_name] = { + 'operation': op_path(get_node_instance_id, module_path=__name__) + + } + ctx.model.node.update(node) + + @workflow + def basic_workflow(graph, **_): + graph.add_tasks( + api.task.OperationTask.node_instance(name=operation_name, instance=node_instance) + ) + + execute(workflow_func=basic_workflow, workflow_context=ctx, executor=executor) + + op_node_instance_id = global_test_holder[op_name(node_instance, operation_name)] + assert op_node_instance_id == node_instance.id + assert op_node_instance_id != other_node_instance.id + + @operation def my_operation(ctx, **_): global_test_holder[ctx.name] = ctx +@operation +def get_node_instance_id(ctx, **_): + global_test_holder[ctx.name] = ctx.node_instance.id + + @pytest.fixture(autouse=True) def cleanup(): global_test_holder.clear() http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/orchestrator/context/test_toolbelt.py ---------------------------------------------------------------------- diff --git a/tests/orchestrator/context/test_toolbelt.py b/tests/orchestrator/context/test_toolbelt.py index 547e62b..da46696 100644 --- a/tests/orchestrator/context/test_toolbelt.py +++ b/tests/orchestrator/context/test_toolbelt.py @@ -21,7 +21,7 @@ from aria.orchestrator.workflows import api from aria.orchestrator.workflows.executor import thread from aria.orchestrator.context.toolbelt import RelationshipToolBelt -from tests import mock +from tests import mock, storage from . import ( op_path, op_name, @@ -32,8 +32,10 @@ global_test_holder = {} @pytest.fixture -def workflow_context(): - return mock.context.simple() +def workflow_context(tmpdir): + context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield context + storage.release_sqlite_storage(context.model) @pytest.fixture @@ -45,63 +47,39 @@ def executor(): result.close() -def _create_simple_model_in_storage(workflow_context): - dependency_node = mock.models.get_dependency_node() - dependency_node_instance = mock.models.get_dependency_node_instance( - dependency_node=dependency_node) - relationship = mock.models.get_relationship(target=dependency_node) - relationship_instance = mock.models.get_relationship_instance( - target_instance=dependency_node_instance, relationship=relationship) - dependent_node = mock.models.get_dependent_node() - dependent_node_instance = mock.models.get_dependent_node_instance( - relationship_instance=relationship_instance, dependent_node=dependency_node) - workflow_context.model.node.store(dependency_node) - workflow_context.model.node_instance.store(dependency_node_instance) - workflow_context.model.relationship.store(relationship) - workflow_context.model.relationship_instance.store(relationship_instance) - workflow_context.model.node.store(dependent_node) - workflow_context.model.node_instance.store(dependent_node_instance) - return dependency_node, dependency_node_instance, \ - dependent_node, dependent_node_instance, \ - relationship, relationship_instance +def _get_elements(workflow_context): + dependency_node = workflow_context.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME) + dependency_node.host_id = dependency_node.id + workflow_context.model.node.update(dependency_node) + dependency_node_instance = workflow_context.model.node_instance.get_by_name( + mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + dependency_node_instance.host_id = dependency_node_instance.id + workflow_context.model.node_instance.update(dependency_node_instance) -def test_host_ip(workflow_context, executor): - operation_name = 'aria.interfaces.lifecycle.create' - dependency_node, dependency_node_instance, _, _, _, _ = \ - _create_simple_model_in_storage(workflow_context) - dependency_node.operations[operation_name] = { - 'operation': op_path(host_ip, module_path=__name__) - - } - workflow_context.model.node.store(dependency_node) - inputs = {'putput': True} - - @workflow - def basic_workflow(graph, **_): - graph.add_tasks( - api.task.OperationTask.node_instance( - instance=dependency_node_instance, - name=operation_name, - inputs=inputs - ) - ) + dependent_node = workflow_context.model.node.get_by_name(mock.models.DEPENDENT_NODE_NAME) + dependent_node.host_id = dependency_node.id + workflow_context.model.node.update(dependent_node) - execute(workflow_func=basic_workflow, workflow_context=workflow_context, executor=executor) + dependent_node_instance = workflow_context.model.node_instance.get_by_name( + mock.models.DEPENDENT_NODE_INSTANCE_NAME) + dependent_node_instance.host_id = dependent_node_instance.id + workflow_context.model.node_instance.update(dependent_node_instance) - assert global_test_holder.get('host_ip') == \ - dependency_node_instance.runtime_properties.get('ip') + relationship = workflow_context.model.relationship.list()[0] + relationship_instance = workflow_context.model.relationship_instance.list()[0] + return dependency_node, dependency_node_instance, dependent_node, dependent_node_instance, \ + relationship, relationship_instance -def test_dependent_node_instances(workflow_context, executor): +def test_host_ip(workflow_context, executor): operation_name = 'aria.interfaces.lifecycle.create' - dependency_node, dependency_node_instance, _, dependent_node_instance, _, _ = \ - _create_simple_model_in_storage(workflow_context) + dependency_node, dependency_node_instance, _, _, _, _ = _get_elements(workflow_context) dependency_node.operations[operation_name] = { - 'operation': op_path(dependent_nodes, module_path=__name__) + 'operation': op_path(host_ip, module_path=__name__) } - workflow_context.model.node.store(dependency_node) + workflow_context.model.node.put(dependency_node) inputs = {'putput': True} @workflow @@ -116,18 +94,18 @@ def test_dependent_node_instances(workflow_context, executor): execute(workflow_func=basic_workflow, workflow_context=workflow_context, executor=executor) - assert list(global_test_holder.get('dependent_node_instances', [])) == \ - list([dependent_node_instance]) + assert global_test_holder.get('host_ip') == \ + dependency_node_instance.runtime_properties.get('ip') def test_relationship_tool_belt(workflow_context, executor): operation_name = 'aria.interfaces.relationship_lifecycle.postconfigure' _, _, _, _, relationship, relationship_instance = \ - _create_simple_model_in_storage(workflow_context) + _get_elements(workflow_context) relationship.source_operations[operation_name] = { 'operation': op_path(relationship_operation, module_path=__name__) } - workflow_context.model.relationship.store(relationship) + workflow_context.model.relationship.put(relationship) inputs = {'putput': True} @@ -152,17 +130,13 @@ def test_wrong_model_toolbelt(): with pytest.raises(RuntimeError): context.toolbelt(None) + @operation(toolbelt=True) def host_ip(toolbelt, **_): global_test_holder['host_ip'] = toolbelt.host_ip @operation(toolbelt=True) -def dependent_nodes(toolbelt, **_): - global_test_holder['dependent_node_instances'] = list(toolbelt.dependent_node_instances) - - -@operation(toolbelt=True) def relationship_operation(ctx, toolbelt, **_): global_test_holder[ctx.name] = toolbelt http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/orchestrator/context/test_workflow.py ---------------------------------------------------------------------- diff --git a/tests/orchestrator/context/test_workflow.py b/tests/orchestrator/context/test_workflow.py index 258f0c5..496c1ff 100644 --- a/tests/orchestrator/context/test_workflow.py +++ b/tests/orchestrator/context/test_workflow.py @@ -19,20 +19,19 @@ import pytest from aria import application_model_storage from aria.orchestrator import context - +from aria.storage.sql_mapi import SQLAlchemyModelAPI +from tests import storage as test_storage from tests.mock import models -from tests.storage import InMemoryModelDriver class TestWorkflowContext(object): def test_execution_creation_on_workflow_context_creation(self, storage): - self._create_ctx(storage) - execution = storage.execution.get(models.EXECUTION_ID) - assert execution.id == models.EXECUTION_ID - assert execution.deployment_id == models.DEPLOYMENT_ID - assert execution.workflow_id == models.WORKFLOW_ID - assert execution.blueprint_id == models.BLUEPRINT_ID + ctx = self._create_ctx(storage) + execution = storage.execution.get(ctx.execution.id) # pylint: disable=no-member + assert execution.deployment == storage.deployment.get_by_name(models.DEPLOYMENT_NAME) + assert execution.workflow_name == models.WORKFLOW_NAME + assert execution.blueprint == storage.blueprint.get_by_name(models.BLUEPRINT_NAME) assert execution.status == storage.execution.model_cls.PENDING assert execution.parameters == {} assert execution.created_at <= datetime.utcnow() @@ -43,13 +42,17 @@ class TestWorkflowContext(object): @staticmethod def _create_ctx(storage): + """ + + :param storage: + :return WorkflowContext: + """ return context.workflow.WorkflowContext( name='simple_context', model_storage=storage, resource_storage=None, - deployment_id=models.DEPLOYMENT_ID, - workflow_id=models.WORKFLOW_ID, - execution_id=models.EXECUTION_ID, + deployment_id=storage.deployment.get_by_name(models.DEPLOYMENT_NAME).id, + workflow_name=models.WORKFLOW_NAME, task_max_attempts=models.TASK_MAX_ATTEMPTS, task_retry_interval=models.TASK_RETRY_INTERVAL ) @@ -57,8 +60,10 @@ class TestWorkflowContext(object): @pytest.fixture(scope='function') def storage(): - result = application_model_storage(InMemoryModelDriver()) - result.setup() - result.blueprint.store(models.get_blueprint()) - result.deployment.store(models.get_deployment()) - return result + api_kwargs = test_storage.get_sqlite_api_kwargs() + workflow_storage = application_model_storage(SQLAlchemyModelAPI, api_kwargs=api_kwargs) + workflow_storage.blueprint.put(models.get_blueprint()) + blueprint = workflow_storage.blueprint.get_by_name(models.BLUEPRINT_NAME) + workflow_storage.deployment.put(models.get_deployment(blueprint)) + yield workflow_storage + test_storage.release_sqlite_storage(workflow_storage)
